From dac0838a80c6c36474166d84923e63509b1b8887 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 30 Jun 2025 14:24:24 +0200 Subject: [PATCH] WIP --- crates/agent/src/agent.rs | 6 +- crates/agent/src/agent2.rs | 55 +- crates/agent/src/context.rs | 4 +- crates/agent/src/context_store.rs | 6 +- crates/agent/src/thread.rs | 7134 +++++++------------ crates/agent/src/thread2.rs | 1449 ---- crates/agent/src/thread_store.rs | 125 +- crates/agent/src/tool_use.rs | 567 -- crates/agent_ui/src/active_thread.rs | 760 +- crates/agent_ui/src/agent_diff.rs | 16 +- crates/agent_ui/src/agent_model_selector.rs | 2 +- crates/agent_ui/src/agent_panel.rs | 147 +- crates/agent_ui/src/agent_ui.rs | 2 +- crates/agent_ui/src/context_picker.rs | 2 +- crates/agent_ui/src/context_strip.rs | 4 +- crates/agent_ui/src/profile_selector.rs | 2 +- 16 files changed, 3249 insertions(+), 7032 deletions(-) delete mode 100644 crates/agent/src/thread2.rs delete mode 100644 crates/agent/src/tool_use.rs diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 8deee53ae0b6fc19e1e53bbefc06a93dd46f0d0a..0adb0f0287d330a69a707ed492fdad26feb57bc1 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -5,17 +5,15 @@ pub mod context_server_tool; pub mod context_store; pub mod history_store; pub mod thread; -mod thread2; pub mod thread_store; -pub mod tool_use; mod zed_agent; pub use agent2::*; pub use context::{AgentContext, ContextId, ContextLoadResult}; pub use context_store::ContextStore; pub use thread::{ - LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, - ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, + LastRestoreCheckpoint, Message, MessageCrease, Thread, ThreadError, ThreadEvent, + ThreadFeedback, ThreadTitle, TokenUsageRatio, }; pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore}; pub use zed_agent::*; diff --git a/crates/agent/src/agent2.rs b/crates/agent/src/agent2.rs index c0b5042ffe07f8026fa9e140803dcd6997559f81..800b62f9156bd022e46edecfbce09ee37ff0f9d8 100644 --- a/crates/agent/src/agent2.rs +++ b/crates/agent/src/agent2.rs @@ -2,13 +2,42 @@ use anyhow::Result; use assistant_tool::{Tool, ToolResultOutput}; use futures::{channel::oneshot, future::BoxFuture, stream::BoxStream}; use gpui::SharedString; -use std::sync::Arc; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{self, Display}, + sync::Arc, +}; -#[derive(Debug, Clone)] -pub struct AgentThreadId(SharedString); +#[derive( + Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema, +)] +pub struct ThreadId(SharedString); + +impl ThreadId { + pub fn as_str(&self) -> &str { + &self.0 + } + + pub fn to_string(&self) -> String { + self.0.to_string() + } +} + +impl From<&str> for ThreadId { + fn from(value: &str) -> Self { + ThreadId(SharedString::from(value.to_string())) + } +} + +impl Display for ThreadId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} -#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] -pub struct AgentThreadMessageId(usize); +#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct MessageId(pub usize); #[derive(Debug, Clone)] pub struct AgentThreadToolCallId(SharedString); @@ -31,11 +60,11 @@ pub enum AgentThreadResponseEvent { pub enum AgentThreadMessage { User { - id: AgentThreadMessageId, + id: MessageId, chunks: Vec, }, Assistant { - id: AgentThreadMessageId, + id: MessageId, chunks: Vec, }, } @@ -56,20 +85,20 @@ pub enum AgentThreadAssistantMessageChunk { }, } -struct AgentThreadResponse { - user_message_id: AgentThreadMessageId, - events: BoxStream<'static, Result>, +pub struct AgentThreadResponse { + pub user_message_id: MessageId, + pub events: BoxStream<'static, Result>, } pub trait AgentThread { - fn id(&self) -> AgentThreadId; + fn id(&self) -> ThreadId; fn title(&self) -> BoxFuture<'static, Result>; fn summary(&self) -> BoxFuture<'static, Result>; fn messages(&self) -> BoxFuture<'static, Result>>; - fn truncate(&self, message_id: AgentThreadMessageId) -> BoxFuture<'static, Result<()>>; + fn truncate(&self, message_id: MessageId) -> BoxFuture<'static, Result<()>>; fn edit( &self, - message_id: AgentThreadMessageId, + message_id: MessageId, content: Vec, max_iterations: usize, ) -> BoxFuture<'static, Result>; diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index ddd13de491ecb0e7d143ae7a6c3e602858fb9b85..333599d2346f007fa97ccd744fc9512ae9034584 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -581,7 +581,7 @@ impl ThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.thread.read(cx).summary().or_default() + self.thread.read(cx).title().or_default() } fn load(self, cx: &App) -> Task>)>> { @@ -589,7 +589,7 @@ impl ThreadContextHandle { let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; let title = self .thread - .read_with(cx, |thread, _cx| thread.summary().or_default()) + .read_with(cx, |thread, _cx| thread.title().or_default()) .ok()?; let context = AgentContext::Thread(ThreadContext { title, diff --git a/crates/agent/src/context_store.rs b/crates/agent/src/context_store.rs index 60ba5527dcca22d81b7da62657c6abc00aa51607..662a8bb1ef5ad09de7ca2a29d6f5653bec12f23d 100644 --- a/crates/agent/src/context_store.rs +++ b/crates/agent/src/context_store.rs @@ -1,10 +1,11 @@ use crate::{ + MessageId, ThreadId, context::{ AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle, FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle, }, - thread::{MessageId, Thread, ThreadId}, + thread::Thread, thread_store::ThreadStore, }; use anyhow::{Context as _, Result, anyhow}; @@ -71,6 +72,7 @@ impl ContextStore { ) -> Vec { let existing_context = thread .messages() + .iter() .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id)) .flat_map(|message| { message @@ -441,7 +443,7 @@ impl ContextStore { match context { AgentContextHandle::Thread(thread_context) => { self.context_thread_ids - .remove(thread_context.thread.read(cx).id()); + .remove(&thread_context.thread.read(cx).id()); } AgentContextHandle::TextThread(text_thread_context) => { if let Some(path) = text_thread_context.context.read(cx).path() { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 32c376ca67fdfa492233ccb6ce1b2947abba0538..9054ea7f1983287e990ac9105f64716f4ce84187 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1,12 +1,8 @@ use crate::{ + AgentThread, AgentThreadUserMessageChunk, MessageId, ThreadId, agent_profile::AgentProfile, context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext}, - thread_store::{ - SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, - SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext, - ThreadStore, - }, - tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, + thread_store::{SharedProjectContext, ThreadStore}, }; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; @@ -15,7 +11,7 @@ use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; use collections::{HashMap, HashSet}; use feature_flags::{self, FeatureFlagAppExt}; -use futures::{FutureExt, StreamExt as _, future::Shared}; +use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared}; use git::repository::DiffType; use gpui::{ AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, @@ -26,8 +22,7 @@ use language_model::{ LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason, - TokenUsage, + ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage, }; use postage::stream::Stream as _; use project::{ @@ -36,7 +31,6 @@ use project::{ }; use prompt_store::{ModelContext, PromptBuilder}; use proto::Plan; -use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use std::{ @@ -47,66 +41,8 @@ use std::{ }; use thiserror::Error; use util::{ResultExt as _, post_inc}; -use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; -const MAX_RETRY_ATTEMPTS: u8 = 3; -const BASE_RETRY_DELAY_SECS: u64 = 5; - -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, -)] -pub struct ThreadId(Arc); - -impl ThreadId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for ThreadId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<&str> for ThreadId { - fn from(value: &str) -> Self { - Self(value.into()) - } -} - -/// The ID of the user prompt that initiated a request. -/// -/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] -pub struct PromptId(Arc); - -impl PromptId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for PromptId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -pub struct MessageId(pub(crate) usize); - -impl MessageId { - fn post_inc(&mut self) -> Self { - Self(post_inc(&mut self.0)) - } - - pub fn as_usize(&self) -> usize { - self.0 - } -} - /// Stored information that can be used to resurrect a context crease when creating an editor for a past message. #[derive(Clone, Debug)] pub struct MessageCrease { @@ -117,105 +53,38 @@ pub struct MessageCrease { pub context: Option, } +pub enum MessageTool { + Pending { + tool: Arc, + input: serde_json::Value, + }, + NeedsConfirmation { + tool: Arc, + input_json: serde_json::Value, + confirm_tx: oneshot::Sender, + }, + Confirmed { + card: AnyToolCard, + }, + Declined { + tool: Arc, + input_json: serde_json::Value, + }, +} + /// A message in a [`Thread`]. -#[derive(Debug, Clone)] pub struct Message { pub id: MessageId, pub role: Role, - pub segments: Vec, + pub thinking: String, + pub text: String, + pub tools: Vec, pub loaded_context: LoadedContext, pub creases: Vec, pub is_hidden: bool, pub ui_only: bool, } -impl Message { - /// Returns whether the message contains any meaningful text that should be displayed - /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`]) - pub fn should_display_content(&self) -> bool { - self.segments.iter().all(|segment| segment.should_display()) - } - - pub fn push_thinking(&mut self, text: &str, signature: Option) { - if let Some(MessageSegment::Thinking { - text: segment, - signature: current_signature, - }) = self.segments.last_mut() - { - if let Some(signature) = signature { - *current_signature = Some(signature); - } - segment.push_str(text); - } else { - self.segments.push(MessageSegment::Thinking { - text: text.to_string(), - signature, - }); - } - } - - pub fn push_redacted_thinking(&mut self, data: String) { - self.segments.push(MessageSegment::RedactedThinking(data)); - } - - pub fn push_text(&mut self, text: &str) { - if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() { - segment.push_str(text); - } else { - self.segments.push(MessageSegment::Text(text.to_string())); - } - } - - pub fn to_string(&self) -> String { - let mut result = String::new(); - - if !self.loaded_context.text.is_empty() { - result.push_str(&self.loaded_context.text); - } - - for segment in &self.segments { - match segment { - MessageSegment::Text(text) => result.push_str(text), - MessageSegment::Thinking { text, .. } => { - result.push_str("\n"); - result.push_str(text); - result.push_str("\n"); - } - MessageSegment::RedactedThinking(_) => {} - } - } - - result - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum MessageSegment { - Text(String), - Thinking { - text: String, - signature: Option, - }, - RedactedThinking(String), -} - -impl MessageSegment { - pub fn should_display(&self) -> bool { - match self { - Self::Text(text) => text.is_empty(), - Self::Thinking { text, .. } => text.is_empty(), - Self::RedactedThinking(_) => false, - } - } - - pub fn text(&self) -> Option<&str> { - match self { - MessageSegment::Text(text) => Some(text), - _ => None, - } - } -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ProjectSnapshot { pub worktree_snapshots: Vec, @@ -345,25 +214,17 @@ pub enum QueueState { /// A thread of conversation with the LLM. pub struct Thread { - id: ThreadId, - updated_at: DateTime, - summary: ThreadSummary, + agent_thread: Arc, + title: ThreadTitle, + pending_send: Option>>, pending_summary: Task>, detailed_summary_task: Task>, detailed_summary_tx: postage::watch::Sender, detailed_summary_rx: postage::watch::Receiver, completion_mode: agent_settings::CompletionMode, messages: Vec, - next_message_id: MessageId, - last_prompt_id: PromptId, - project_context: SharedProjectContext, checkpoints_by_message: HashMap, - completion_count: usize, - pending_completions: Vec, project: Entity, - prompt_builder: Arc, - tools: Entity, - tool_use: ToolUseState, action_log: Entity, last_restore_checkpoint: Option, pending_checkpoint: Option, @@ -372,35 +233,22 @@ pub struct Thread { cumulative_token_usage: TokenUsage, exceeded_window_error: Option, tool_use_limit_reached: bool, + // todo!(keep track of retries from the underlying agent) feedback: Option, - retry_state: Option, message_feedback: HashMap, last_auto_capture_at: Option, last_received_chunk_at: Option, - request_callback: Option< - Box])>, - >, - remaining_turns: u32, - configured_model: Option, - profile: AgentProfile, -} - -#[derive(Clone, Debug)] -struct RetryState { - attempt: u8, - max_attempts: u8, - intent: CompletionIntent, } #[derive(Clone, Debug, PartialEq, Eq)] -pub enum ThreadSummary { +pub enum ThreadTitle { Pending, Generating, Ready(SharedString), Error, } -impl ThreadSummary { +impl ThreadTitle { pub const DEFAULT: SharedString = SharedString::new_static("New Thread"); pub fn or_default(&self) -> SharedString { @@ -413,8 +261,8 @@ impl ThreadSummary { pub fn ready(&self) -> Option { match self { - ThreadSummary::Ready(summary) => Some(summary.clone()), - ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None, + ThreadTitle::Ready(summary) => Some(summary.clone()), + ThreadTitle::Pending | ThreadTitle::Generating | ThreadTitle::Error => None, } } } @@ -428,39 +276,26 @@ pub struct ExceededWindowError { } impl Thread { - pub fn new( + pub fn load( + agent_thread: Arc, project: Entity, - tools: Entity, - prompt_builder: Arc, - system_prompt: SharedProjectContext, cx: &mut Context, ) -> Self { let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); - let configured_model = LanguageModelRegistry::read_global(cx).default_model(); - let profile_id = AgentSettings::get_global(cx).default_profile.clone(); - Self { - id: ThreadId::new(), - updated_at: Utc::now(), - summary: ThreadSummary::Pending, + agent_thread, + title: ThreadTitle::Pending, + pending_send: None, pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, detailed_summary_rx, completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, - messages: Vec::new(), - next_message_id: MessageId(0), - last_prompt_id: PromptId::new(), - project_context: system_prompt, + messages: todo!("read from agent"), checkpoints_by_message: HashMap::default(), - completion_count: 0, - pending_completions: Vec::new(), project: project.clone(), - prompt_builder, - tools: tools.clone(), last_restore_checkpoint: None, pending_checkpoint: None, - tool_use: ToolUseState::new(tools.clone()), action_log: cx.new(|_| ActionLog::new(project.clone())), initial_project_snapshot: { let project_snapshot = Self::project_snapshot(project, cx); @@ -473,221 +308,64 @@ impl Thread { exceeded_window_error: None, tool_use_limit_reached: false, feedback: None, - retry_state: None, - message_feedback: HashMap::default(), - last_auto_capture_at: None, - last_received_chunk_at: None, - request_callback: None, - remaining_turns: u32::MAX, - configured_model, - profile: AgentProfile::new(profile_id, tools), - } - } - - pub fn deserialize( - id: ThreadId, - serialized: SerializedThread, - project: Entity, - tools: Entity, - prompt_builder: Arc, - project_context: SharedProjectContext, - window: Option<&mut Window>, // None in headless mode - cx: &mut Context, - ) -> Self { - let next_message_id = MessageId( - serialized - .messages - .last() - .map(|message| message.id.0 + 1) - .unwrap_or(0), - ); - let tool_use = ToolUseState::from_serialized_messages( - tools.clone(), - &serialized.messages, - project.clone(), - window, - cx, - ); - let (detailed_summary_tx, detailed_summary_rx) = - postage::watch::channel_with(serialized.detailed_summary_state); - - let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - serialized - .model - .and_then(|model| { - let model = SelectedModel { - provider: model.provider.clone().into(), - model: model.model.clone().into(), - }; - registry.select_model(&model, cx) - }) - .or_else(|| registry.default_model()) - }); - - let completion_mode = serialized - .completion_mode - .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode); - let profile_id = serialized - .profile - .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); - - Self { - id, - updated_at: serialized.updated_at, - summary: ThreadSummary::Ready(serialized.summary), - pending_summary: Task::ready(None), - detailed_summary_task: Task::ready(None), - detailed_summary_tx, - detailed_summary_rx, - completion_mode, - retry_state: None, - messages: serialized - .messages - .into_iter() - .map(|message| Message { - id: message.id, - role: message.role, - segments: message - .segments - .into_iter() - .map(|segment| match segment { - SerializedMessageSegment::Text { text } => MessageSegment::Text(text), - SerializedMessageSegment::Thinking { text, signature } => { - MessageSegment::Thinking { text, signature } - } - SerializedMessageSegment::RedactedThinking { data } => { - MessageSegment::RedactedThinking(data) - } - }) - .collect(), - loaded_context: LoadedContext { - contexts: Vec::new(), - text: message.context, - images: Vec::new(), - }, - creases: message - .creases - .into_iter() - .map(|crease| MessageCrease { - range: crease.start..crease.end, - icon_path: crease.icon_path, - label: crease.label, - context: None, - }) - .collect(), - is_hidden: message.is_hidden, - ui_only: false, // UI-only messages are not persisted - }) - .collect(), - next_message_id, - last_prompt_id: PromptId::new(), - project_context, - checkpoints_by_message: HashMap::default(), - completion_count: 0, - pending_completions: Vec::new(), - last_restore_checkpoint: None, - pending_checkpoint: None, - project: project.clone(), - prompt_builder, - tools: tools.clone(), - tool_use, - action_log: cx.new(|_| ActionLog::new(project)), - initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), - request_token_usage: serialized.request_token_usage, - cumulative_token_usage: serialized.cumulative_token_usage, - exceeded_window_error: None, - tool_use_limit_reached: serialized.tool_use_limit_reached, - feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, last_received_chunk_at: None, - request_callback: None, - remaining_turns: u32::MAX, - configured_model, - profile: AgentProfile::new(profile_id, tools), } } - pub fn set_request_callback( - &mut self, - callback: impl 'static - + FnMut(&LanguageModelRequest, &[Result]), - ) { - self.request_callback = Some(Box::new(callback)); - } - - pub fn id(&self) -> &ThreadId { - &self.id + pub fn id(&self) -> ThreadId { + self.agent_thread.id() } pub fn profile(&self) -> &AgentProfile { - &self.profile + todo!() } pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context) { - if &id != self.profile.id() { - self.profile = AgentProfile::new(id, self.tools.clone()); - cx.emit(ThreadEvent::ProfileChanged); - } + todo!() + // if &id != self.profile.id() { + // self.profile = AgentProfile::new(id, self.tools.clone()); + // cx.emit(ThreadEvent::ProfileChanged); + // } } pub fn is_empty(&self) -> bool { self.messages.is_empty() } - pub fn updated_at(&self) -> DateTime { - self.updated_at - } - - pub fn touch_updated_at(&mut self) { - self.updated_at = Utc::now(); - } - - pub fn advance_prompt_id(&mut self) { - self.last_prompt_id = PromptId::new(); - } - pub fn project_context(&self) -> SharedProjectContext { - self.project_context.clone() - } - - pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option { - if self.configured_model.is_none() { - self.configured_model = LanguageModelRegistry::read_global(cx).default_model(); - } - self.configured_model.clone() - } - - pub fn configured_model(&self) -> Option { - self.configured_model.clone() + todo!() + // self.project_context.clone() } - pub fn set_configured_model(&mut self, model: Option, cx: &mut Context) { - self.configured_model = model; - cx.notify(); + pub fn title(&self) -> &ThreadTitle { + &self.title } - pub fn summary(&self) -> &ThreadSummary { - &self.summary - } + pub fn set_title(&mut self, new_summary: impl Into, cx: &mut Context) { + todo!() + // let current_summary = match &self.summary { + // ThreadSummary::Pending | ThreadSummary::Generating => return, + // ThreadSummary::Ready(summary) => summary, + // ThreadSummary::Error => &ThreadSummary::DEFAULT, + // }; - pub fn set_summary(&mut self, new_summary: impl Into, cx: &mut Context) { - let current_summary = match &self.summary { - ThreadSummary::Pending | ThreadSummary::Generating => return, - ThreadSummary::Ready(summary) => summary, - ThreadSummary::Error => &ThreadSummary::DEFAULT, - }; + // let mut new_summary = new_summary.into(); - let mut new_summary = new_summary.into(); + // if new_summary.is_empty() { + // new_summary = ThreadSummary::DEFAULT; + // } - if new_summary.is_empty() { - new_summary = ThreadSummary::DEFAULT; - } + // if current_summary != &new_summary { + // self.summary = ThreadSummary::Ready(new_summary); + // cx.emit(ThreadEvent::SummaryChanged); + // } + } - if current_summary != &new_summary { - self.summary = ThreadSummary::Ready(new_summary); - cx.emit(ThreadEvent::SummaryChanged); - } + pub fn regenerate_summary(&self, cx: &mut Context) { + todo!() + // self.summarize(cx); } pub fn completion_mode(&self) -> CompletionMode { @@ -707,12 +385,12 @@ impl Thread { self.messages.get(index) } - pub fn messages(&self) -> impl ExactSizeIterator { - self.messages.iter() + pub fn messages(&self) -> &[Message] { + &self.messages } pub fn is_generating(&self) -> bool { - !self.pending_completions.is_empty() || !self.all_tools_finished() + self.pending_send.is_some() } /// Indicates whether streaming of language model events is stale. @@ -728,34 +406,6 @@ impl Thread { self.last_received_chunk_at = Some(Instant::now()); } - pub fn queue_state(&self) -> Option { - self.pending_completions - .first() - .map(|pending_completion| pending_completion.queue_state) - } - - pub fn tools(&self) -> &Entity { - &self.tools - } - - pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> { - self.tool_use - .pending_tool_uses() - .into_iter() - .find(|tool_use| &tool_use.id == id) - } - - pub fn tools_needing_confirmation(&self) -> impl Iterator { - self.tool_use - .pending_tool_uses() - .into_iter() - .filter(|tool_use| tool_use.status.needs_confirmation()) - } - - pub fn has_pending_tool_uses(&self) -> bool { - !self.tool_use.pending_tool_uses().is_empty() - } - pub fn checkpoint_for_message(&self, id: MessageId) -> Option { self.checkpoints_by_message.get(&id).cloned() } @@ -855,6 +505,7 @@ impl Thread { } pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context) { + todo!("call truncate on the agent"); let Some(message_ix) = self .messages .iter() @@ -868,248 +519,203 @@ impl Thread { cx.notify(); } - pub fn context_for_message(&self, id: MessageId) -> impl Iterator { - self.messages - .iter() - .find(|message| message.id == id) - .into_iter() - .flat_map(|message| message.loaded_context.contexts.iter()) - } - pub fn is_turn_end(&self, ix: usize) -> bool { - if self.messages.is_empty() { - return false; - } + todo!() + // if self.messages.is_empty() { + // return false; + // } - if !self.is_generating() && ix == self.messages.len() - 1 { - return true; - } + // if !self.is_generating() && ix == self.messages.len() - 1 { + // return true; + // } - let Some(message) = self.messages.get(ix) else { - return false; - }; + // let Some(message) = self.messages.get(ix) else { + // return false; + // }; - if message.role != Role::Assistant { - return false; - } + // if message.role != Role::Assistant { + // return false; + // } - self.messages - .get(ix + 1) - .and_then(|message| { - self.message(message.id) - .map(|next_message| next_message.role == Role::User && !next_message.is_hidden) - }) - .unwrap_or(false) + // self.messages + // .get(ix + 1) + // .and_then(|message| { + // self.message(message.id) + // .map(|next_message| next_message.role == Role::User && !next_message.is_hidden) + // }) + // .unwrap_or(false) } pub fn tool_use_limit_reached(&self) -> bool { self.tool_use_limit_reached } - /// Returns whether all of the tool uses have finished running. - pub fn all_tools_finished(&self) -> bool { - // If the only pending tool uses left are the ones with errors, then - // that means that we've finished running all of the pending tools. - self.tool_use - .pending_tool_uses() - .iter() - .all(|pending_tool_use| pending_tool_use.status.is_error()) - } - /// Returns whether any pending tool uses may perform edits pub fn has_pending_edit_tool_uses(&self) -> bool { - self.tool_use - .pending_tool_uses() - .iter() - .filter(|pending_tool_use| !pending_tool_use.status.is_error()) - .any(|pending_tool_use| pending_tool_use.may_perform_edits) - } - - pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { - self.tool_use.tool_uses_for_message(id, cx) - } - - pub fn tool_results_for_message( - &self, - assistant_message_id: MessageId, - ) -> Vec<&LanguageModelToolResult> { - self.tool_use.tool_results_for_message(assistant_message_id) - } - - pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> { - self.tool_use.tool_result(id) - } - - pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc> { - match &self.tool_use.tool_result(id)?.content { - LanguageModelToolResultContent::Text(text) => Some(text), - LanguageModelToolResultContent::Image(_) => { - // TODO: We should display image - None - } - } - } - - pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option { - self.tool_use.tool_result_card(id).cloned() - } - - /// Return tools that are both enabled and supported by the model - pub fn available_tools( - &self, - cx: &App, - model: Arc, - ) -> Vec { - if model.supports_tools() { - resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice()) - .into_iter() - .filter_map(|(name, tool)| { - // Skip tools that cannot be supported - let input_schema = tool.input_schema(model.tool_input_format()).ok()?; - Some(LanguageModelRequestTool { - name, - description: tool.description(), - input_schema, - }) - }) - .collect() - } else { - Vec::default() - } - } - - pub fn insert_user_message( + todo!() + } + + // pub fn insert_user_message( + // &mut self, + // text: impl Into, + // loaded_context: ContextLoadResult, + // git_checkpoint: Option, + // creases: Vec, + // cx: &mut Context, + // ) -> AgentThreadMessageId { + // todo!("move this logic into send") + // if !loaded_context.referenced_buffers.is_empty() { + // self.action_log.update(cx, |log, cx| { + // for buffer in loaded_context.referenced_buffers { + // log.buffer_read(buffer, cx); + // } + // }); + // } + + // let message_id = self.insert_message( + // Role::User, + // vec![MessageSegment::Text(text.into())], + // loaded_context.loaded_context, + // creases, + // false, + // cx, + // ); + + // if let Some(git_checkpoint) = git_checkpoint { + // self.pending_checkpoint = Some(ThreadCheckpoint { + // message_id, + // git_checkpoint, + // }); + // } + + // self.auto_capture_telemetry(cx); + + // message_id + // } + + pub fn set_model(&mut self, model: Option, cx: &mut Context) { + todo!() + } + + pub fn model(&self) -> Option { + todo!() + } + + pub fn send( &mut self, - text: impl Into, - loaded_context: ContextLoadResult, - git_checkpoint: Option, - creases: Vec, + message: Vec, + window: &mut Window, cx: &mut Context, - ) -> MessageId { - if !loaded_context.referenced_buffers.is_empty() { - self.action_log.update(cx, |log, cx| { - for buffer in loaded_context.referenced_buffers { - log.buffer_read(buffer, cx); - } - }); - } - - let message_id = self.insert_message( - Role::User, - vec![MessageSegment::Text(text.into())], - loaded_context.loaded_context, - creases, - false, - cx, - ); - - if let Some(git_checkpoint) = git_checkpoint { - self.pending_checkpoint = Some(ThreadCheckpoint { - message_id, - git_checkpoint, - }); - } - - self.auto_capture_telemetry(cx); - - message_id - } - - pub fn insert_invisible_continue_message(&mut self, cx: &mut Context) -> MessageId { - let id = self.insert_message( - Role::User, - vec![MessageSegment::Text("Continue where you left off".into())], - LoadedContext::default(), - vec![], - true, - cx, - ); - self.pending_checkpoint = None; - - id - } - - pub fn insert_assistant_message( - &mut self, - segments: Vec, - cx: &mut Context, - ) -> MessageId { - self.insert_message( - Role::Assistant, - segments, - LoadedContext::default(), - Vec::new(), - false, - cx, - ) + ) { + todo!() } - pub fn insert_message( - &mut self, - role: Role, - segments: Vec, - loaded_context: LoadedContext, - creases: Vec, - is_hidden: bool, - cx: &mut Context, - ) -> MessageId { - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role, - segments, - loaded_context, - creases, - is_hidden, - ui_only: false, - }); - self.touch_updated_at(); - cx.emit(ThreadEvent::MessageAdded(id)); - id + pub fn resume(&mut self, window: &mut Window, cx: &mut Context) { + todo!() } - pub fn edit_message( + pub fn edit( &mut self, - id: MessageId, - new_role: Role, - new_segments: Vec, - creases: Vec, - loaded_context: Option, - checkpoint: Option, + message_id: MessageId, + message: Vec, + window: &mut Window, cx: &mut Context, - ) -> bool { - let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { - return false; - }; - message.role = new_role; - message.segments = new_segments; - message.creases = creases; - if let Some(context) = loaded_context { - message.loaded_context = context; - } - if let Some(git_checkpoint) = checkpoint { - self.checkpoints_by_message.insert( - id, - ThreadCheckpoint { - message_id: id, - git_checkpoint, - }, - ); - } - self.touch_updated_at(); - cx.emit(ThreadEvent::MessageEdited(id)); - true - } - - pub fn delete_message(&mut self, id: MessageId, cx: &mut Context) -> bool { - let Some(index) = self.messages.iter().position(|message| message.id == id) else { - return false; - }; - self.messages.remove(index); - self.touch_updated_at(); - cx.emit(ThreadEvent::MessageDeleted(id)); - true - } + ) { + todo!() + } + + pub fn cancel(&mut self, window: &mut Window, cx: &mut Context) -> bool { + todo!() + } + + // pub fn insert_invisible_continue_message( + // &mut self, + // cx: &mut Context, + // ) -> AgentThreadMessageId { + // let id = self.insert_message( + // Role::User, + // vec![MessageSegment::Text("Continue where you left off".into())], + // LoadedContext::default(), + // vec![], + // true, + // cx, + // ); + // self.pending_checkpoint = None; + + // id + // } + + // pub fn insert_assistant_message( + // &mut self, + // segments: Vec, + // cx: &mut Context, + // ) -> AgentThreadMessageId { + // self.insert_message( + // Role::Assistant, + // segments, + // LoadedContext::default(), + // Vec::new(), + // false, + // cx, + // ) + // } + + // pub fn insert_message( + // &mut self, + // role: Role, + // segments: Vec, + // loaded_context: LoadedContext, + // creases: Vec, + // is_hidden: bool, + // cx: &mut Context, + // ) -> AgentThreadMessageId { + // let id = self.next_message_id.post_inc(); + // self.messages.push(Message { + // id, + // role, + // segments, + // loaded_context, + // creases, + // is_hidden, + // ui_only: false, + // }); + // self.touch_updated_at(); + // cx.emit(ThreadEvent::MessageAdded(id)); + // id + // } + + // pub fn edit_message( + // &mut self, + // id: AgentThreadMessageId, + // new_role: Role, + // new_segments: Vec, + // creases: Vec, + // loaded_context: Option, + // checkpoint: Option, + // cx: &mut Context, + // ) -> bool { + // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { + // return false; + // }; + // message.role = new_role; + // message.segments = new_segments; + // message.creases = creases; + // if let Some(context) = loaded_context { + // message.loaded_context = context; + // } + // if let Some(git_checkpoint) = checkpoint { + // self.checkpoints_by_message.insert( + // id, + // ThreadCheckpoint { + // message_id: id, + // git_checkpoint, + // }, + // ); + // } + // self.touch_updated_at(); + // cx.emit(ThreadEvent::MessageEdited(id)); + // true + // } /// Returns the representation of this [`Thread`] in a textual form. /// @@ -1125,1140 +731,53 @@ impl Thread { }); text.push('\n'); - for segment in &message.segments { - match segment { - MessageSegment::Text(content) => text.push_str(content), - MessageSegment::Thinking { text: content, .. } => { - text.push_str(&format!("{}", content)) - } - MessageSegment::RedactedThinking(_) => {} - } - } + text.push_str(""); + text.push_str(&message.thinking); + text.push_str(""); + text.push_str(&message.text); + + // todo!('what about tools?'); + text.push('\n'); } text } - /// Serializes this thread into a format for storage or telemetry. - pub fn serialize(&self, cx: &mut Context) -> Task> { - let initial_project_snapshot = self.initial_project_snapshot.clone(); - cx.spawn(async move |this, cx| { - let initial_project_snapshot = initial_project_snapshot.await; - this.read_with(cx, |this, cx| SerializedThread { - version: SerializedThread::VERSION.to_string(), - summary: this.summary().or_default(), - updated_at: this.updated_at(), - messages: this - .messages() - .filter(|message| !message.ui_only) - .map(|message| SerializedMessage { - id: message.id, - role: message.role, - segments: message - .segments - .iter() - .map(|segment| match segment { - MessageSegment::Text(text) => { - SerializedMessageSegment::Text { text: text.clone() } - } - MessageSegment::Thinking { text, signature } => { - SerializedMessageSegment::Thinking { - text: text.clone(), - signature: signature.clone(), - } - } - MessageSegment::RedactedThinking(data) => { - SerializedMessageSegment::RedactedThinking { - data: data.clone(), - } - } - }) - .collect(), - tool_uses: this - .tool_uses_for_message(message.id, cx) - .into_iter() - .map(|tool_use| SerializedToolUse { - id: tool_use.id, - name: tool_use.name, - input: tool_use.input, - }) - .collect(), - tool_results: this - .tool_results_for_message(message.id) - .into_iter() - .map(|tool_result| SerializedToolResult { - tool_use_id: tool_result.tool_use_id.clone(), - is_error: tool_result.is_error, - content: tool_result.content.clone(), - output: tool_result.output.clone(), - }) - .collect(), - context: message.loaded_context.text.clone(), - creases: message - .creases - .iter() - .map(|crease| SerializedCrease { - start: crease.range.start, - end: crease.range.end, - icon_path: crease.icon_path.clone(), - label: crease.label.clone(), - }) - .collect(), - is_hidden: message.is_hidden, - }) - .collect(), - initial_project_snapshot, - cumulative_token_usage: this.cumulative_token_usage, - request_token_usage: this.request_token_usage.clone(), - detailed_summary_state: this.detailed_summary_rx.borrow().clone(), - exceeded_window_error: this.exceeded_window_error.clone(), - model: this - .configured_model - .as_ref() - .map(|model| SerializedLanguageModel { - provider: model.provider.id().0.to_string(), - model: model.model.id().0.to_string(), - }), - completion_mode: Some(this.completion_mode), - tool_use_limit_reached: this.tool_use_limit_reached, - profile: Some(this.profile.id().clone()), - }) - }) - } - - pub fn remaining_turns(&self) -> u32 { - self.remaining_turns - } + pub fn used_tools_since_last_user_message(&self) -> bool { + todo!() + // for message in self.messages.iter().rev() { + // if self.tool_use.message_has_tool_results(message.id) { + // return true; + // } else if message.role == Role::User { + // return false; + // } + // } - pub fn set_remaining_turns(&mut self, remaining_turns: u32) { - self.remaining_turns = remaining_turns; + // false } - pub fn send_to_model( + pub fn start_generating_detailed_summary_if_needed( &mut self, - model: Arc, - intent: CompletionIntent, - window: Option, + thread_store: WeakEntity, cx: &mut Context, ) { - if self.remaining_turns == 0 { + let Some(last_message_id) = self.messages.last().map(|message| message.id) else { return; - } - - self.remaining_turns -= 1; - - let request = self.to_completion_request(model.clone(), intent, cx); - - self.stream_completion(request, model, intent, window, cx); - } - - pub fn used_tools_since_last_user_message(&self) -> bool { - for message in self.messages.iter().rev() { - if self.tool_use.message_has_tool_results(message.id) { - return true; - } else if message.role == Role::User { - return false; - } - } - - false - } - - pub fn to_completion_request( - &self, - model: Arc, - intent: CompletionIntent, - cx: &mut Context, - ) -> LanguageModelRequest { - let mut request = LanguageModelRequest { - thread_id: Some(self.id.to_string()), - prompt_id: Some(self.last_prompt_id.to_string()), - intent: Some(intent), - mode: None, - messages: vec![], - tools: Vec::new(), - tool_choice: None, - stop: Vec::new(), - temperature: AgentSettings::temperature_for_model(&model, cx), - }; - - let available_tools = self.available_tools(cx, model.clone()); - let available_tool_names = available_tools - .iter() - .map(|tool| tool.name.clone()) - .collect(); - - let model_context = &ModelContext { - available_tools: available_tool_names, }; - if let Some(project_context) = self.project_context.borrow().as_ref() { - match self - .prompt_builder - .generate_assistant_system_prompt(project_context, model_context) + match &*self.detailed_summary_rx.borrow() { + DetailedSummaryState::Generating { message_id, .. } + | DetailedSummaryState::Generated { message_id, .. } + if *message_id == last_message_id => { - Err(err) => { - let message = format!("{err:?}").into(); - log::error!("{message}"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error generating system prompt".into(), - message, - })); - } - Ok(system_prompt) => { - request.messages.push(LanguageModelRequestMessage { - role: Role::System, - content: vec![MessageContent::Text(system_prompt)], - cache: true, - }); - } + // Already up-to-date + return; } - } else { - let message = "Context for system prompt unexpectedly not ready.".into(); - log::error!("{message}"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error generating system prompt".into(), - message, - })); + _ => {} } - let mut message_ix_to_cache = None; - for message in &self.messages { - // ui_only messages are for the UI only, not for the model - if message.ui_only { - continue; - } - - let mut request_message = LanguageModelRequestMessage { - role: message.role, - content: Vec::new(), - cache: false, - }; - - message - .loaded_context - .add_to_request_message(&mut request_message); - - for segment in &message.segments { - match segment { - MessageSegment::Text(text) => { - if !text.is_empty() { - request_message - .content - .push(MessageContent::Text(text.into())); - } - } - MessageSegment::Thinking { text, signature } => { - if !text.is_empty() { - request_message.content.push(MessageContent::Thinking { - text: text.into(), - signature: signature.clone(), - }); - } - } - MessageSegment::RedactedThinking(data) => { - request_message - .content - .push(MessageContent::RedactedThinking(data.clone())); - } - }; - } - - let mut cache_message = true; - let mut tool_results_message = LanguageModelRequestMessage { - role: Role::User, - content: Vec::new(), - cache: false, - }; - for (tool_use, tool_result) in self.tool_use.tool_results(message.id) { - if let Some(tool_result) = tool_result { - request_message - .content - .push(MessageContent::ToolUse(tool_use.clone())); - tool_results_message - .content - .push(MessageContent::ToolResult(LanguageModelToolResult { - tool_use_id: tool_use.id.clone(), - tool_name: tool_result.tool_name.clone(), - is_error: tool_result.is_error, - content: if tool_result.content.is_empty() { - // Surprisingly, the API fails if we return an empty string here. - // It thinks we are sending a tool use without a tool result. - "".into() - } else { - tool_result.content.clone() - }, - output: None, - })); - } else { - cache_message = false; - log::debug!( - "skipped tool use {:?} because it is still pending", - tool_use - ); - } - } - - if cache_message { - message_ix_to_cache = Some(request.messages.len()); - } - request.messages.push(request_message); - - if !tool_results_message.content.is_empty() { - if cache_message { - message_ix_to_cache = Some(request.messages.len()); - } - request.messages.push(tool_results_message); - } - } - - // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching - if let Some(message_ix_to_cache) = message_ix_to_cache { - request.messages[message_ix_to_cache].cache = true; - } - - request.tools = available_tools; - request.mode = if model.supports_burn_mode() { - Some(self.completion_mode.into()) - } else { - Some(CompletionMode::Normal.into()) - }; - - request - } - - fn to_summarize_request( - &self, - model: &Arc, - intent: CompletionIntent, - added_user_message: String, - cx: &App, - ) -> LanguageModelRequest { - let mut request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - intent: Some(intent), - mode: None, - messages: vec![], - tools: Vec::new(), - tool_choice: None, - stop: Vec::new(), - temperature: AgentSettings::temperature_for_model(model, cx), - }; - - for message in &self.messages { - let mut request_message = LanguageModelRequestMessage { - role: message.role, - content: Vec::new(), - cache: false, - }; - - for segment in &message.segments { - match segment { - MessageSegment::Text(text) => request_message - .content - .push(MessageContent::Text(text.clone())), - MessageSegment::Thinking { .. } => {} - MessageSegment::RedactedThinking(_) => {} - } - } - - if request_message.content.is_empty() { - continue; - } - - request.messages.push(request_message); - } - - request.messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text(added_user_message)], - cache: false, - }); - - request - } - - pub fn stream_completion( - &mut self, - request: LanguageModelRequest, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) { - self.tool_use_limit_reached = false; - - let pending_completion_id = post_inc(&mut self.completion_count); - let mut request_callback_parameters = if self.request_callback.is_some() { - Some((request.clone(), Vec::new())) - } else { - None - }; - let prompt_id = self.last_prompt_id.clone(); - let tool_use_metadata = ToolUseMetadata { - model: model.clone(), - thread_id: self.id.clone(), - prompt_id: prompt_id.clone(), - }; - - self.last_received_chunk_at = Some(Instant::now()); - - let task = cx.spawn(async move |thread, cx| { - let stream_completion_future = model.stream_completion(request, &cx); - let initial_token_usage = - thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); - let stream_completion = async { - let mut events = stream_completion_future.await?; - - let mut stop_reason = StopReason::EndTurn; - let mut current_token_usage = TokenUsage::default(); - - thread - .update(cx, |_thread, cx| { - cx.emit(ThreadEvent::NewRequest); - }) - .ok(); - - let mut request_assistant_message_id = None; - - while let Some(event) = events.next().await { - if let Some((_, response_events)) = request_callback_parameters.as_mut() { - response_events - .push(event.as_ref().map_err(|error| error.to_string()).cloned()); - } - - thread.update(cx, |thread, cx| { - let event = match event { - Ok(event) => event, - Err(error) => { - match error { - LanguageModelCompletionError::RateLimitExceeded { retry_after } => { - anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after }); - } - LanguageModelCompletionError::Overloaded => { - anyhow::bail!(LanguageModelKnownError::Overloaded); - } - LanguageModelCompletionError::ApiInternalServerError =>{ - anyhow::bail!(LanguageModelKnownError::ApiInternalServerError); - } - LanguageModelCompletionError::PromptTooLarge { tokens } => { - let tokens = tokens.unwrap_or_else(|| { - // We didn't get an exact token count from the API, so fall back on our estimate. - thread.total_token_usage() - .map(|usage| usage.total) - .unwrap_or(0) - // We know the context window was exceeded in practice, so if our estimate was - // lower than max tokens, the estimate was wrong; return that we exceeded by 1. - .max(model.max_token_count().saturating_add(1)) - }); - - anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens }) - } - LanguageModelCompletionError::ApiReadResponseError(io_error) => { - anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error)); - } - LanguageModelCompletionError::UnknownResponseFormat(error) => { - anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error)); - } - LanguageModelCompletionError::HttpResponseError { status, ref body } => { - if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) { - anyhow::bail!(known_error); - } else { - return Err(error.into()); - } - } - LanguageModelCompletionError::DeserializeResponse(error) => { - anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error)); - } - LanguageModelCompletionError::BadInputJson { - id, - tool_name, - raw_input: invalid_input_json, - json_parse_error, - } => { - thread.receive_invalid_tool_json( - id, - tool_name, - invalid_input_json, - json_parse_error, - window, - cx, - ); - return Ok(()); - } - // These are all errors we can't automatically attempt to recover from (e.g. by retrying) - err @ LanguageModelCompletionError::BadRequestFormat | - err @ LanguageModelCompletionError::AuthenticationError | - err @ LanguageModelCompletionError::PermissionError | - err @ LanguageModelCompletionError::ApiEndpointNotFound | - err @ LanguageModelCompletionError::SerializeRequest(_) | - err @ LanguageModelCompletionError::BuildRequestBody(_) | - err @ LanguageModelCompletionError::HttpSend(_) => { - anyhow::bail!(err); - } - LanguageModelCompletionError::Other(error) => { - return Err(error); - } - } - } - }; - - match event { - LanguageModelCompletionEvent::StartMessage { .. } => { - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::Text(String::new())], - cx, - )); - } - LanguageModelCompletionEvent::Stop(reason) => { - stop_reason = reason; - } - LanguageModelCompletionEvent::UsageUpdate(token_usage) => { - thread.update_token_usage_at_last_message(token_usage); - thread.cumulative_token_usage = thread.cumulative_token_usage - + token_usage - - current_token_usage; - current_token_usage = token_usage; - } - LanguageModelCompletionEvent::Text(chunk) => { - thread.received_chunk(); - - cx.emit(ThreadEvent::ReceivedTextChunk); - if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant - && !thread.tool_use.has_tool_results(last_message.id) - { - last_message.push_text(&chunk); - cx.emit(ThreadEvent::StreamedAssistantText( - last_message.id, - chunk, - )); - } else { - // If we won't have an Assistant message yet, assume this chunk marks the beginning - // of a new Assistant response. - // - // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it - // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::Text(chunk.to_string())], - cx, - )); - }; - } - } - LanguageModelCompletionEvent::Thinking { - text: chunk, - signature, - } => { - thread.received_chunk(); - - if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant - && !thread.tool_use.has_tool_results(last_message.id) - { - last_message.push_thinking(&chunk, signature); - cx.emit(ThreadEvent::StreamedAssistantThinking( - last_message.id, - chunk, - )); - } else { - // If we won't have an Assistant message yet, assume this chunk marks the beginning - // of a new Assistant response. - // - // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it - // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::Thinking { - text: chunk.to_string(), - signature, - }], - cx, - )); - }; - } - } - LanguageModelCompletionEvent::RedactedThinking { - data - } => { - thread.received_chunk(); - - if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant - && !thread.tool_use.has_tool_results(last_message.id) - { - last_message.push_redacted_thinking(data); - } else { - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::RedactedThinking(data)], - cx, - )); - }; - } - } - LanguageModelCompletionEvent::ToolUse(tool_use) => { - let last_assistant_message_id = request_assistant_message_id - .unwrap_or_else(|| { - let new_assistant_message_id = - thread.insert_assistant_message(vec![], cx); - request_assistant_message_id = - Some(new_assistant_message_id); - new_assistant_message_id - }); - - let tool_use_id = tool_use.id.clone(); - let streamed_input = if tool_use.is_input_complete { - None - } else { - Some((&tool_use.input).clone()) - }; - - let ui_text = thread.tool_use.request_tool_use( - last_assistant_message_id, - tool_use, - tool_use_metadata.clone(), - cx, - ); - - if let Some(input) = streamed_input { - cx.emit(ThreadEvent::StreamedToolUse { - tool_use_id, - ui_text, - input, - }); - } - } - LanguageModelCompletionEvent::StatusUpdate(status_update) => { - if let Some(completion) = thread - .pending_completions - .iter_mut() - .find(|completion| completion.id == pending_completion_id) - { - match status_update { - CompletionRequestStatus::Queued { - position, - } => { - completion.queue_state = QueueState::Queued { position }; - } - CompletionRequestStatus::Started => { - completion.queue_state = QueueState::Started; - } - CompletionRequestStatus::Failed { - code, message, request_id - } => { - anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}"); - } - CompletionRequestStatus::UsageUpdated { - amount, limit - } => { - thread.update_model_request_usage(amount as u32, limit, cx); - } - CompletionRequestStatus::ToolUseLimitReached => { - thread.tool_use_limit_reached = true; - cx.emit(ThreadEvent::ToolUseLimitReached); - } - } - } - } - } - - thread.touch_updated_at(); - cx.emit(ThreadEvent::StreamedCompletion); - cx.notify(); - - thread.auto_capture_telemetry(cx); - Ok(()) - })??; - - smol::future::yield_now().await; - } - - thread.update(cx, |thread, cx| { - thread.last_received_chunk_at = None; - thread - .pending_completions - .retain(|completion| completion.id != pending_completion_id); - - // If there is a response without tool use, summarize the message. Otherwise, - // allow two tool uses before summarizing. - if matches!(thread.summary, ThreadSummary::Pending) - && thread.messages.len() >= 2 - && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6) - { - thread.summarize(cx); - } - })?; - - anyhow::Ok(stop_reason) - }; - - let result = stream_completion.await; - let mut retry_scheduled = false; - - thread - .update(cx, |thread, cx| { - thread.finalize_pending_checkpoint(cx); - match result.as_ref() { - Ok(stop_reason) => { - match stop_reason { - StopReason::ToolUse => { - let tool_uses = thread.use_pending_tools(window, model.clone(), cx); - cx.emit(ThreadEvent::UsePendingTools { tool_uses }); - } - StopReason::EndTurn | StopReason::MaxTokens => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - } - StopReason::Refusal => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - - // Remove the turn that was refused. - // - // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal - { - let mut messages_to_remove = Vec::new(); - - for (ix, message) in thread.messages.iter().enumerate().rev() { - messages_to_remove.push(message.id); - - if message.role == Role::User { - if ix == 0 { - break; - } - - if let Some(prev_message) = thread.messages.get(ix - 1) { - if prev_message.role == Role::Assistant { - break; - } - } - } - } - - for message_id in messages_to_remove { - thread.delete_message(message_id, cx); - } - } - - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Language model refusal".into(), - message: "Model refused to generate content for safety reasons.".into(), - })); - } - } - - // We successfully completed, so cancel any remaining retries. - thread.retry_state = None; - }, - Err(error) => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - - fn emit_generic_error(error: &anyhow::Error, cx: &mut Context) { - let error_message = error - .chain() - .map(|err| err.to_string()) - .collect::>() - .join("\n"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error interacting with language model".into(), - message: SharedString::from(error_message.clone()), - })); - } - - if error.is::() { - cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); - } else if let Some(error) = - error.downcast_ref::() - { - cx.emit(ThreadEvent::ShowError( - ThreadError::ModelRequestLimitReached { plan: error.plan }, - )); - } else if let Some(known_error) = - error.downcast_ref::() - { - match known_error { - LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => { - thread.exceeded_window_error = Some(ExceededWindowError { - model_id: model.id(), - token_count: *tokens, - }); - cx.notify(); - } - LanguageModelKnownError::RateLimitExceeded { retry_after } => { - let provider_name = model.provider_name(); - let error_message = format!( - "{}'s API rate limit exceeded", - provider_name.0.as_ref() - ); - - thread.handle_rate_limit_error( - &error_message, - *retry_after, - model.clone(), - intent, - window, - cx, - ); - retry_scheduled = true; - } - LanguageModelKnownError::Overloaded => { - let provider_name = model.provider_name(); - let error_message = format!( - "{}'s API servers are overloaded right now", - provider_name.0.as_ref() - ); - - retry_scheduled = thread.handle_retryable_error( - &error_message, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); - } - } - LanguageModelKnownError::ApiInternalServerError => { - let provider_name = model.provider_name(); - let error_message = format!( - "{}'s API server reported an internal server error", - provider_name.0.as_ref() - ); - - retry_scheduled = thread.handle_retryable_error( - &error_message, - model.clone(), - intent, - window, - cx, - ); - if !retry_scheduled { - emit_generic_error(error, cx); - } - } - LanguageModelKnownError::ReadResponseError(_) | - LanguageModelKnownError::DeserializeResponse(_) | - LanguageModelKnownError::UnknownResponseFormat(_) => { - // In the future we will attempt to re-roll response, but only once - emit_generic_error(error, cx); - } - } - } else { - emit_generic_error(error, cx); - } - - if !retry_scheduled { - thread.cancel_last_completion(window, cx); - } - } - } - - if !retry_scheduled { - cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); - } - - if let Some((request_callback, (request, response_events))) = thread - .request_callback - .as_mut() - .zip(request_callback_parameters.as_ref()) - { - request_callback(request, response_events); - } - - thread.auto_capture_telemetry(cx); - - if let Ok(initial_usage) = initial_token_usage { - let usage = thread.cumulative_token_usage - initial_usage; - - telemetry::event!( - "Assistant Thread Completion", - thread_id = thread.id().to_string(), - prompt_id = prompt_id, - model = model.telemetry_id(), - model_provider = model.provider_id().to_string(), - input_tokens = usage.input_tokens, - output_tokens = usage.output_tokens, - cache_creation_input_tokens = usage.cache_creation_input_tokens, - cache_read_input_tokens = usage.cache_read_input_tokens, - ); - } - }) - .ok(); - }); - - self.pending_completions.push(PendingCompletion { - id: pending_completion_id, - queue_state: QueueState::Sending, - _task: task, - }); - } - - pub fn summarize(&mut self, cx: &mut Context) { - let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { - println!("No thread summary model"); - return; - }; - - if !model.provider.is_authenticated(cx) { - return; - } - - let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt"); - - let request = self.to_summarize_request( - &model.model, - CompletionIntent::ThreadSummarization, - added_user_message.into(), - cx, - ); - - self.summary = ThreadSummary::Generating; - - self.pending_summary = cx.spawn(async move |this, cx| { - let result = async { - let mut messages = model.model.stream_completion(request, &cx).await?; - - let mut new_summary = String::new(); - while let Some(event) = messages.next().await { - let Ok(event) = event else { - continue; - }; - let text = match event { - LanguageModelCompletionEvent::Text(text) => text, - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::UsageUpdated { amount, limit }, - ) => { - this.update(cx, |thread, cx| { - thread.update_model_request_usage(amount as u32, limit, cx); - })?; - continue; - } - _ => continue, - }; - - let mut lines = text.lines(); - new_summary.extend(lines.next()); - - // Stop if the LLM generated multiple lines. - if lines.next().is_some() { - break; - } - } - - anyhow::Ok(new_summary) - } - .await; - - this.update(cx, |this, cx| { - match result { - Ok(new_summary) => { - if new_summary.is_empty() { - this.summary = ThreadSummary::Error; - } else { - this.summary = ThreadSummary::Ready(new_summary.into()); - } - } - Err(err) => { - this.summary = ThreadSummary::Error; - log::error!("Failed to generate thread summary: {}", err); - } - } - cx.emit(ThreadEvent::SummaryGenerated); - }) - .log_err()?; - - Some(()) - }); - } - - fn handle_rate_limit_error( - &mut self, - error_message: &str, - retry_after: Duration, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) { - // For rate limit errors, we only retry once with the specified duration - let retry_message = format!( - "{error_message}. Retrying in {} seconds…", - retry_after.as_secs() - ); - - // Add a UI-only message instead of a regular message - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text(retry_message)], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: false, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - // Schedule the retry - let thread_handle = cx.entity().downgrade(); - - cx.spawn(async move |_thread, cx| { - cx.background_executor().timer(retry_after).await; - - thread_handle - .update(cx, |thread, cx| { - // Retry the completion - thread.send_to_model(model, intent, window, cx); - }) - .log_err(); - }) - .detach(); - } - - fn handle_retryable_error( - &mut self, - error_message: &str, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) -> bool { - self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx) - } - - fn handle_retryable_error_with_delay( - &mut self, - error_message: &str, - custom_delay: Option, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) -> bool { - let retry_state = self.retry_state.get_or_insert(RetryState { - attempt: 0, - max_attempts: MAX_RETRY_ATTEMPTS, - intent, - }); - - retry_state.attempt += 1; - let attempt = retry_state.attempt; - let max_attempts = retry_state.max_attempts; - let intent = retry_state.intent; - - if attempt <= max_attempts { - // Use custom delay if provided (e.g., from rate limit), otherwise exponential backoff - let delay = if let Some(custom_delay) = custom_delay { - custom_delay - } else { - let delay_secs = BASE_RETRY_DELAY_SECS * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) - }; - - // Add a transient message to inform the user - let delay_secs = delay.as_secs(); - let retry_message = format!( - "{}. Retrying (attempt {} of {}) in {} seconds...", - error_message, attempt, max_attempts, delay_secs - ); - - // Add a UI-only message instead of a regular message - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text(retry_message)], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: false, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - - // Schedule the retry - let thread_handle = cx.entity().downgrade(); - - cx.spawn(async move |_thread, cx| { - cx.background_executor().timer(delay).await; - - thread_handle - .update(cx, |thread, cx| { - // Retry the completion - thread.send_to_model(model, intent, window, cx); - }) - .log_err(); - }) - .detach(); - - true - } else { - // Max retries exceeded - self.retry_state = None; - - let notification_text = if max_attempts == 1 { - "Failed after retrying.".into() - } else { - format!("Failed after retrying {} times.", max_attempts).into() - }; - - // Stop generating since we're giving up on retrying. - self.pending_completions.clear(); - - cx.emit(ThreadEvent::RetriesFailed { - message: notification_text, - }); - - false - } - } - - pub fn start_generating_detailed_summary_if_needed( - &mut self, - thread_store: WeakEntity, - cx: &mut Context, - ) { - let Some(last_message_id) = self.messages.last().map(|message| message.id) else { - return; - }; - - match &*self.detailed_summary_rx.borrow() { - DetailedSummaryState::Generating { message_id, .. } - | DetailedSummaryState::Generated { message_id, .. } - if *message_id == last_message_id => - { - // Already up-to-date - return; - } - _ => {} - } - - let Some(ConfiguredModel { model, provider }) = - LanguageModelRegistry::read_global(cx).thread_summary_model() - else { - return; - }; - - if !provider.is_authenticated(cx) { - return; - } - - let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt"); - - let request = self.to_summarize_request( - &model, - CompletionIntent::ThreadContextSummarization, - added_user_message.into(), - cx, - ); + let summary = self.agent_thread.summary(); *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating { message_id: last_message_id, @@ -2269,8 +788,7 @@ impl Thread { // which result to prefer (the old task could complete after the new one, resulting in a // stale summary). self.detailed_summary_task = cx.spawn(async move |thread, cx| { - let stream = model.stream_completion_text(request, &cx); - let Some(mut messages) = stream.await.log_err() else { + let Some(summary) = summary.await.log_err() else { thread .update(cx, |thread, _cx| { *thread.detailed_summary_tx.borrow_mut() = @@ -2280,33 +798,15 @@ impl Thread { return None; }; - let mut new_detailed_summary = String::new(); - - while let Some(chunk) = messages.stream.next().await { - if let Some(chunk) = chunk.log_err() { - new_detailed_summary.push_str(&chunk); - } - } - thread .update(cx, |thread, _cx| { *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated { - text: new_detailed_summary.into(), + text: summary.into(), message_id: last_message_id, }; }) .ok()?; - // Save thread so its summary can be reused later - if let Some(thread) = thread.upgrade() { - if let Ok(Ok(save_task)) = cx.update(|cx| { - thread_store - .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) - }) { - save_task.await.log_err(); - } - } - Some(()) }); } @@ -2343,384 +843,117 @@ impl Thread { ) } - pub fn use_pending_tools( - &mut self, - window: Option, - model: Arc, - cx: &mut Context, - ) -> Vec { - self.auto_capture_telemetry(cx); - let request = - Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx)); - let pending_tool_uses = self - .tool_use - .pending_tool_uses() - .into_iter() - .filter(|tool_use| tool_use.status.is_idle()) - .cloned() - .collect::>(); - - for tool_use in pending_tool_uses.iter() { - self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx); - } + pub fn feedback(&self) -> Option { + self.feedback + } - pending_tool_uses + pub fn message_feedback(&self, message_id: MessageId) -> Option { + self.message_feedback.get(&message_id).copied() } - fn use_pending_tool( + pub fn report_message_feedback( &mut self, - tool_use: PendingToolUse, - request: Arc, - model: Arc, - window: Option, + message_id: MessageId, + feedback: ThreadFeedback, cx: &mut Context, - ) { - let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else { - return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx); - }; - - if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) { - return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx); - } - - if tool.needs_confirmation(&tool_use.input, cx) - && !AgentSettings::get_global(cx).always_allow_tool_actions - { - self.tool_use.confirm_tool_use( - tool_use.id, - tool_use.ui_text, - tool_use.input, - request, - tool, - ); - cx.emit(ThreadEvent::ToolConfirmationNeeded); - } else { - self.run_tool( - tool_use.id, - tool_use.ui_text, - tool_use.input, - request, - tool, - model, - window, - cx, - ); - } + ) -> Task> { + todo!() + // if self.message_feedback.get(&message_id) == Some(&feedback) { + // return Task::ready(Ok(())); + // } + + // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); + // let serialized_thread = self.serialize(cx); + // let thread_id = self.id().clone(); + // let client = self.project.read(cx).client(); + + // let enabled_tool_names: Vec = self + // .profile + // .enabled_tools(cx) + // .iter() + // .map(|tool| tool.name()) + // .collect(); + + // self.message_feedback.insert(message_id, feedback); + + // cx.notify(); + + // let message_content = self + // .message(message_id) + // .map(|msg| msg.to_string()) + // .unwrap_or_default(); + + // cx.background_spawn(async move { + // let final_project_snapshot = final_project_snapshot.await; + // let serialized_thread = serialized_thread.await?; + // let thread_data = + // serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null); + + // let rating = match feedback { + // ThreadFeedback::Positive => "positive", + // ThreadFeedback::Negative => "negative", + // }; + // telemetry::event!( + // "Assistant Thread Rated", + // rating, + // thread_id, + // enabled_tool_names, + // message_id = message_id, + // message_content, + // thread_data, + // final_project_snapshot + // ); + // client.telemetry().flush_events().await; + + // Ok(()) + // }) } - pub fn handle_hallucinated_tool_use( + pub fn report_feedback( &mut self, - tool_use_id: LanguageModelToolUseId, - hallucinated_tool_name: Arc, - window: Option, - cx: &mut Context, - ) { - let available_tools = self.profile.enabled_tools(cx); - - let tool_list = available_tools - .iter() - .map(|tool| format!("- {}: {}", tool.name(), tool.description())) - .collect::>() - .join("\n"); - - let error_message = format!( - "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}", - hallucinated_tool_name, tool_list - ); - - let pending_tool_use = self.tool_use.insert_tool_output( - tool_use_id.clone(), - hallucinated_tool_name, - Err(anyhow!("Missing tool call: {error_message}")), - self.configured_model.as_ref(), - ); - - cx.emit(ThreadEvent::MissingToolUse { - tool_use_id: tool_use_id.clone(), - ui_text: error_message.into(), - }); - - self.tool_finished(tool_use_id, pending_tool_use, false, window, cx); - } - - pub fn receive_invalid_tool_json( - &mut self, - tool_use_id: LanguageModelToolUseId, - tool_name: Arc, - invalid_json: Arc, - error: String, - window: Option, - cx: &mut Context, - ) { - log::error!("The model returned invalid input JSON: {invalid_json}"); - - let pending_tool_use = self.tool_use.insert_tool_output( - tool_use_id.clone(), - tool_name, - Err(anyhow!("Error parsing input JSON: {error}")), - self.configured_model.as_ref(), - ); - let ui_text = if let Some(pending_tool_use) = &pending_tool_use { - pending_tool_use.ui_text.clone() - } else { - log::error!( - "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)." - ); - format!("Unknown tool {}", tool_use_id).into() - }; - - cx.emit(ThreadEvent::InvalidToolInput { - tool_use_id: tool_use_id.clone(), - ui_text, - invalid_input_json: invalid_json, - }); - - self.tool_finished(tool_use_id, pending_tool_use, false, window, cx); - } - - pub fn run_tool( - &mut self, - tool_use_id: LanguageModelToolUseId, - ui_text: impl Into, - input: serde_json::Value, - request: Arc, - tool: Arc, - model: Arc, - window: Option, - cx: &mut Context, - ) { - let task = - self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx); - self.tool_use - .run_pending_tool(tool_use_id, ui_text.into(), task); - } - - fn spawn_tool_use( - &mut self, - tool_use_id: LanguageModelToolUseId, - request: Arc, - input: serde_json::Value, - tool: Arc, - model: Arc, - window: Option, - cx: &mut Context, - ) -> Task<()> { - let tool_name: Arc = tool.name().into(); - - let tool_result = tool.run( - input, - request, - self.project.clone(), - self.action_log.clone(), - model, - window, - cx, - ); - - // Store the card separately if it exists - if let Some(card) = tool_result.card.clone() { - self.tool_use - .insert_tool_result_card(tool_use_id.clone(), card); - } - - cx.spawn({ - async move |thread: WeakEntity, cx| { - let output = tool_result.output.await; - - thread - .update(cx, |thread, cx| { - let pending_tool_use = thread.tool_use.insert_tool_output( - tool_use_id.clone(), - tool_name, - output, - thread.configured_model.as_ref(), - ); - thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx); - }) - .ok(); - } - }) - } - - fn tool_finished( - &mut self, - tool_use_id: LanguageModelToolUseId, - pending_tool_use: Option, - canceled: bool, - window: Option, - cx: &mut Context, - ) { - if self.all_tools_finished() { - if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() { - if !canceled { - self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx); - } - self.auto_capture_telemetry(cx); - } - } - - cx.emit(ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - }); - } - - /// Cancels the last pending completion, if there are any pending. - /// - /// Returns whether a completion was canceled. - pub fn cancel_last_completion( - &mut self, - window: Option, - cx: &mut Context, - ) -> bool { - let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some(); - - self.retry_state = None; - - for pending_tool_use in self.tool_use.cancel_pending() { - canceled = true; - self.tool_finished( - pending_tool_use.id.clone(), - Some(pending_tool_use), - true, - window, - cx, - ); - } - - if canceled { - cx.emit(ThreadEvent::CompletionCanceled); - - // When canceled, we always want to insert the checkpoint. - // (We skip over finalize_pending_checkpoint, because it - // would conclude we didn't have anything to insert here.) - if let Some(checkpoint) = self.pending_checkpoint.take() { - self.insert_checkpoint(checkpoint, cx); - } - } else { - self.finalize_pending_checkpoint(cx); - } - - canceled - } - - /// Signals that any in-progress editing should be canceled. - /// - /// This method is used to notify listeners (like ActiveThread) that - /// they should cancel any editing operations. - pub fn cancel_editing(&mut self, cx: &mut Context) { - cx.emit(ThreadEvent::CancelEditing); - } - - pub fn feedback(&self) -> Option { - self.feedback - } - - pub fn message_feedback(&self, message_id: MessageId) -> Option { - self.message_feedback.get(&message_id).copied() - } - - pub fn report_message_feedback( - &mut self, - message_id: MessageId, - feedback: ThreadFeedback, - cx: &mut Context, - ) -> Task> { - if self.message_feedback.get(&message_id) == Some(&feedback) { - return Task::ready(Ok(())); - } - - let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); - let serialized_thread = self.serialize(cx); - let thread_id = self.id().clone(); - let client = self.project.read(cx).client(); - - let enabled_tool_names: Vec = self - .profile - .enabled_tools(cx) - .iter() - .map(|tool| tool.name()) - .collect(); - - self.message_feedback.insert(message_id, feedback); - - cx.notify(); - - let message_content = self - .message(message_id) - .map(|msg| msg.to_string()) - .unwrap_or_default(); - - cx.background_spawn(async move { - let final_project_snapshot = final_project_snapshot.await; - let serialized_thread = serialized_thread.await?; - let thread_data = - serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null); - - let rating = match feedback { - ThreadFeedback::Positive => "positive", - ThreadFeedback::Negative => "negative", - }; - telemetry::event!( - "Assistant Thread Rated", - rating, - thread_id, - enabled_tool_names, - message_id = message_id.0, - message_content, - thread_data, - final_project_snapshot - ); - client.telemetry().flush_events().await; - - Ok(()) - }) - } - - pub fn report_feedback( - &mut self, - feedback: ThreadFeedback, - cx: &mut Context, - ) -> Task> { - let last_assistant_message_id = self - .messages - .iter() - .rev() - .find(|msg| msg.role == Role::Assistant) - .map(|msg| msg.id); - - if let Some(message_id) = last_assistant_message_id { - self.report_message_feedback(message_id, feedback, cx) - } else { - let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); - let serialized_thread = self.serialize(cx); - let thread_id = self.id().clone(); - let client = self.project.read(cx).client(); - self.feedback = Some(feedback); - cx.notify(); - - cx.background_spawn(async move { - let final_project_snapshot = final_project_snapshot.await; - let serialized_thread = serialized_thread.await?; - let thread_data = serde_json::to_value(serialized_thread) - .unwrap_or_else(|_| serde_json::Value::Null); - - let rating = match feedback { - ThreadFeedback::Positive => "positive", - ThreadFeedback::Negative => "negative", - }; - telemetry::event!( - "Assistant Thread Rated", - rating, - thread_id, - thread_data, - final_project_snapshot - ); - client.telemetry().flush_events().await; - - Ok(()) - }) - } + feedback: ThreadFeedback, + cx: &mut Context, + ) -> Task> { + todo!() + // let last_assistant_message_id = self + // .messages + // .iter() + // .rev() + // .find(|msg| msg.role == Role::Assistant) + // .map(|msg| msg.id); + + // if let Some(message_id) = last_assistant_message_id { + // self.report_message_feedback(message_id, feedback, cx) + // } else { + // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); + // let serialized_thread = self.serialize(cx); + // let thread_id = self.id().clone(); + // let client = self.project.read(cx).client(); + // self.feedback = Some(feedback); + // cx.notify(); + + // cx.background_spawn(async move { + // let final_project_snapshot = final_project_snapshot.await; + // let serialized_thread = serialized_thread.await?; + // let thread_data = serde_json::to_value(serialized_thread) + // .unwrap_or_else(|_| serde_json::Value::Null); + + // let rating = match feedback { + // ThreadFeedback::Positive => "positive", + // ThreadFeedback::Negative => "negative", + // }; + // telemetry::event!( + // "Assistant Thread Rated", + // rating, + // thread_id, + // thread_data, + // final_project_snapshot + // ); + // client.telemetry().flush_events().await; + + // Ok(()) + // }) + // } } /// Create a snapshot of the current project state including git information and unsaved buffers. @@ -2840,86 +1073,87 @@ impl Thread { } pub fn to_markdown(&self, cx: &App) -> Result { - let mut markdown = Vec::new(); - - let summary = self.summary().or_default(); - writeln!(markdown, "# {summary}\n")?; - - for message in self.messages() { - writeln!( - markdown, - "## {role}\n", - role = match message.role { - Role::User => "User", - Role::Assistant => "Agent", - Role::System => "System", - } - )?; - - if !message.loaded_context.text.is_empty() { - writeln!(markdown, "{}", message.loaded_context.text)?; - } - - if !message.loaded_context.images.is_empty() { - writeln!( - markdown, - "\n{} images attached as context.\n", - message.loaded_context.images.len() - )?; - } - - for segment in &message.segments { - match segment { - MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, - MessageSegment::Thinking { text, .. } => { - writeln!(markdown, "\n{}\n\n", text)? - } - MessageSegment::RedactedThinking(_) => {} - } - } - - for tool_use in self.tool_uses_for_message(message.id, cx) { - writeln!( - markdown, - "**Use Tool: {} ({})**", - tool_use.name, tool_use.id - )?; - writeln!(markdown, "```json")?; - writeln!( - markdown, - "{}", - serde_json::to_string_pretty(&tool_use.input)? - )?; - writeln!(markdown, "```")?; - } - - for tool_result in self.tool_results_for_message(message.id) { - write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?; - if tool_result.is_error { - write!(markdown, " (Error)")?; - } - - writeln!(markdown, "**\n")?; - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - writeln!(markdown, "{text}")?; - } - LanguageModelToolResultContent::Image(image) => { - writeln!(markdown, "![Image](data:base64,{})", image.source)?; - } - } - - if let Some(output) = tool_result.output.as_ref() { - writeln!( - markdown, - "\n\nDebug Output:\n\n```json\n{}\n```\n", - serde_json::to_string_pretty(output)? - )?; - } - } - } - - Ok(String::from_utf8_lossy(&markdown).to_string()) + todo!() + // let mut markdown = Vec::new(); + + // let summary = self.summary().or_default(); + // writeln!(markdown, "# {summary}\n")?; + + // for message in self.messages() { + // writeln!( + // markdown, + // "## {role}\n", + // role = match message.role { + // Role::User => "User", + // Role::Assistant => "Agent", + // Role::System => "System", + // } + // )?; + + // if !message.loaded_context.text.is_empty() { + // writeln!(markdown, "{}", message.loaded_context.text)?; + // } + + // if !message.loaded_context.images.is_empty() { + // writeln!( + // markdown, + // "\n{} images attached as context.\n", + // message.loaded_context.images.len() + // )?; + // } + + // for segment in &message.segments { + // match segment { + // MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, + // MessageSegment::Thinking { text, .. } => { + // writeln!(markdown, "\n{}\n\n", text)? + // } + // MessageSegment::RedactedThinking(_) => {} + // } + // } + + // for tool_use in self.tool_uses_for_message(message.id, cx) { + // writeln!( + // markdown, + // "**Use Tool: {} ({})**", + // tool_use.name, tool_use.id + // )?; + // writeln!(markdown, "```json")?; + // writeln!( + // markdown, + // "{}", + // serde_json::to_string_pretty(&tool_use.input)? + // )?; + // writeln!(markdown, "```")?; + // } + + // for tool_result in self.tool_results_for_message(message.id) { + // write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?; + // if tool_result.is_error { + // write!(markdown, " (Error)")?; + // } + + // writeln!(markdown, "**\n")?; + // match &tool_result.content { + // LanguageModelToolResultContent::Text(text) => { + // writeln!(markdown, "{text}")?; + // } + // LanguageModelToolResultContent::Image(image) => { + // writeln!(markdown, "![Image](data:base64,{})", image.source)?; + // } + // } + + // if let Some(output) = tool_result.output.as_ref() { + // writeln!( + // markdown, + // "\n\nDebug Output:\n\n```json\n{}\n```\n", + // serde_json::to_string_pretty(output)? + // )?; + // } + // } + // } + + // Ok(String::from_utf8_lossy(&markdown).to_string()) } pub fn keep_edits_in_range( @@ -2958,47 +1192,48 @@ impl Thread { } pub fn auto_capture_telemetry(&mut self, cx: &mut Context) { - if !cx.has_flag::() { - return; - } - - let now = Instant::now(); - if let Some(last) = self.last_auto_capture_at { - if now.duration_since(last).as_secs() < 10 { - return; - } - } - - self.last_auto_capture_at = Some(now); - - let thread_id = self.id().clone(); - let github_login = self - .project - .read(cx) - .user_store() - .read(cx) - .current_user() - .map(|user| user.github_login.clone()); - let client = self.project.read(cx).client(); - let serialize_task = self.serialize(cx); - - cx.background_executor() - .spawn(async move { - if let Ok(serialized_thread) = serialize_task.await { - if let Ok(thread_data) = serde_json::to_value(serialized_thread) { - telemetry::event!( - "Agent Thread Auto-Captured", - thread_id = thread_id.to_string(), - thread_data = thread_data, - auto_capture_reason = "tracked_user", - github_login = github_login - ); - - client.telemetry().flush_events().await; - } - } - }) - .detach(); + todo!() + // if !cx.has_flag::() { + // return; + // } + + // let now = Instant::now(); + // if let Some(last) = self.last_auto_capture_at { + // if now.duration_since(last).as_secs() < 10 { + // return; + // } + // } + + // self.last_auto_capture_at = Some(now); + + // let thread_id = self.id().clone(); + // let github_login = self + // .project + // .read(cx) + // .user_store() + // .read(cx) + // .current_user() + // .map(|user| user.github_login.clone()); + // let client = self.project.read(cx).client(); + // let serialize_task = self.serialize(cx); + + // cx.background_executor() + // .spawn(async move { + // if let Ok(serialized_thread) = serialize_task.await { + // if let Ok(thread_data) = serde_json::to_value(serialized_thread) { + // telemetry::event!( + // "Agent Thread Auto-Captured", + // thread_id = thread_id.to_string(), + // thread_data = thread_data, + // auto_capture_reason = "tracked_user", + // github_login = github_login + // ); + + // client.telemetry().flush_events().await; + // } + // } + // }) + // .detach(); } pub fn cumulative_token_usage(&self) -> TokenUsage { @@ -3006,54 +1241,56 @@ impl Thread { } pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage { - let Some(model) = self.configured_model.as_ref() else { - return TotalTokenUsage::default(); - }; + todo!() + // let Some(model) = self.configured_model.as_ref() else { + // return TotalTokenUsage::default(); + // }; - let max = model.model.max_token_count(); + // let max = model.model.max_token_count(); - let index = self - .messages - .iter() - .position(|msg| msg.id == message_id) - .unwrap_or(0); + // let index = self + // .messages + // .iter() + // .position(|msg| msg.id == message_id) + // .unwrap_or(0); - if index == 0 { - return TotalTokenUsage { total: 0, max }; - } + // if index == 0 { + // return TotalTokenUsage { total: 0, max }; + // } - let token_usage = &self - .request_token_usage - .get(index - 1) - .cloned() - .unwrap_or_default(); + // let token_usage = &self + // .request_token_usage + // .get(index - 1) + // .cloned() + // .unwrap_or_default(); - TotalTokenUsage { - total: token_usage.total_tokens(), - max, - } + // TotalTokenUsage { + // total: token_usage.total_tokens(), + // max, + // } } pub fn total_token_usage(&self) -> Option { - let model = self.configured_model.as_ref()?; + todo!() + // let model = self.configured_model.as_ref()?; - let max = model.model.max_token_count(); + // let max = model.model.max_token_count(); - if let Some(exceeded_error) = &self.exceeded_window_error { - if model.model.id() == exceeded_error.model_id { - return Some(TotalTokenUsage { - total: exceeded_error.token_count, - max, - }); - } - } + // if let Some(exceeded_error) = &self.exceeded_window_error { + // if model.model.id() == exceeded_error.model_id { + // return Some(TotalTokenUsage { + // total: exceeded_error.token_count, + // max, + // }); + // } + // } - let total = self - .token_usage_at_last_message() - .unwrap_or_default() - .total_tokens(); + // let total = self + // .token_usage_at_last_message() + // .unwrap_or_default() + // .total_tokens(); - Some(TotalTokenUsage { total, max }) + // Some(TotalTokenUsage { total, max }) } fn token_usage_at_last_message(&self) -> Option { @@ -3086,26 +1323,6 @@ impl Thread { }) }); } - - pub fn deny_tool_use( - &mut self, - tool_use_id: LanguageModelToolUseId, - tool_name: Arc, - window: Option, - cx: &mut Context, - ) { - let err = Err(anyhow::anyhow!( - "Permission to run tool action denied by user" - )); - - self.tool_use.insert_tool_output( - tool_use_id.clone(), - tool_name, - err, - self.configured_model.as_ref(), - ); - self.tool_finished(tool_use_id.clone(), None, true, window, cx); - } } #[derive(Debug, Clone, Error)] @@ -3149,15 +1366,6 @@ pub enum ThreadEvent { MessageDeleted(MessageId), SummaryGenerated, SummaryChanged, - UsePendingTools { - tool_uses: Vec, - }, - ToolFinished { - #[allow(unused)] - tool_use_id: LanguageModelToolUseId, - /// The pending tool use that corresponds to this tool. - pending_tool_use: Option, - }, CheckpointChanged, ToolConfirmationNeeded, ToolUseLimitReached, @@ -3256,2129 +1464,2129 @@ fn resolve_tool_name_conflicts(tools: &[Arc]) -> Vec<(String, Arc -The following items were attached by the user. They are up-to-date and don't need to be re-read. - - -```rs {path_part} -fn main() {{ - println!("Hello, world!"); -}} -``` - - -"# - ); - - assert_eq!(message.role, Role::User); - assert_eq!(message.segments.len(), 1); - assert_eq!( - message.segments[0], - MessageSegment::Text("Please explain this code".to_string()) - ); - assert_eq!(message.loaded_context.text, expected_context); - - // Check message in request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - assert_eq!(request.messages.len(), 2); - let expected_full_message = format!("{}Please explain this code", expected_context); - assert_eq!(request.messages[1].string_contents(), expected_full_message); - } - - #[gpui::test] - async fn test_only_include_new_contexts(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project( - cx, - json!({ - "file1.rs": "fn function1() {}\n", - "file2.rs": "fn function2() {}\n", - "file3.rs": "fn function3() {}\n", - "file4.rs": "fn function4() {}\n", - }), - ) - .await; - - let (_, _thread_store, thread, context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // First message with context 1 - add_file_to_context(&project, &context_store, "test/file1.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), None) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await; - let message1_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx) - }); - - // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included) - add_file_to_context(&project, &context_store, "test/file2.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), None) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await; - let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx) - }); - - // Third message with all three contexts (contexts 1 and 2 should be skipped) - // - add_file_to_context(&project, &context_store, "test/file3.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), None) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await; - let message3_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx) - }); - - // Check what contexts are included in each message - let (message1, message2, message3) = thread.read_with(cx, |thread, _| { - ( - thread.message(message1_id).unwrap().clone(), - thread.message(message2_id).unwrap().clone(), - thread.message(message3_id).unwrap().clone(), - ) - }); - - // First message should include context 1 - assert!(message1.loaded_context.text.contains("file1.rs")); - - // Second message should include only context 2 (not 1) - assert!(!message2.loaded_context.text.contains("file1.rs")); - assert!(message2.loaded_context.text.contains("file2.rs")); - - // Third message should include only context 3 (not 1 or 2) - assert!(!message3.loaded_context.text.contains("file1.rs")); - assert!(!message3.loaded_context.text.contains("file2.rs")); - assert!(message3.loaded_context.text.contains("file3.rs")); - - // Check entire request to make sure all contexts are properly included - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - // The request should contain all 3 messages - assert_eq!(request.messages.len(), 4); - - // Check that the contexts are properly formatted in each message - assert!(request.messages[1].string_contents().contains("file1.rs")); - assert!(!request.messages[1].string_contents().contains("file2.rs")); - assert!(!request.messages[1].string_contents().contains("file3.rs")); - - assert!(!request.messages[2].string_contents().contains("file1.rs")); - assert!(request.messages[2].string_contents().contains("file2.rs")); - assert!(!request.messages[2].string_contents().contains("file3.rs")); - - assert!(!request.messages[3].string_contents().contains("file1.rs")); - assert!(!request.messages[3].string_contents().contains("file2.rs")); - assert!(request.messages[3].string_contents().contains("file3.rs")); - - add_file_to_context(&project, &context_store, "test/file4.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), Some(message2_id)) - }); - assert_eq!(new_contexts.len(), 3); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await - .loaded_context; - - assert!(!loaded_context.text.contains("file1.rs")); - assert!(loaded_context.text.contains("file2.rs")); - assert!(loaded_context.text.contains("file3.rs")); - assert!(loaded_context.text.contains("file4.rs")); - - let new_contexts = context_store.update(cx, |store, cx| { - // Remove file4.rs - store.remove_context(&loaded_context.contexts[2].handle(), cx); - store.new_context_for_thread(thread.read(cx), Some(message2_id)) - }); - assert_eq!(new_contexts.len(), 2); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await - .loaded_context; - - assert!(!loaded_context.text.contains("file1.rs")); - assert!(loaded_context.text.contains("file2.rs")); - assert!(loaded_context.text.contains("file3.rs")); - assert!(!loaded_context.text.contains("file4.rs")); - - let new_contexts = context_store.update(cx, |store, cx| { - // Remove file3.rs - store.remove_context(&loaded_context.contexts[1].handle(), cx); - store.new_context_for_thread(thread.read(cx), Some(message2_id)) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await - .loaded_context; - - assert!(!loaded_context.text.contains("file1.rs")); - assert!(loaded_context.text.contains("file2.rs")); - assert!(!loaded_context.text.contains("file3.rs")); - assert!(!loaded_context.text.contains("file4.rs")); - } - - #[gpui::test] - async fn test_message_without_files(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project( - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // Insert user message without any context (empty context vector) - let message_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "What is the best way to learn Rust?", - ContextLoadResult::default(), - None, - Vec::new(), - cx, - ) - }); - - // Check content and context in message object - let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone()); - - // Context should be empty when no files are included - assert_eq!(message.role, Role::User); - assert_eq!(message.segments.len(), 1); - assert_eq!( - message.segments[0], - MessageSegment::Text("What is the best way to learn Rust?".to_string()) - ); - assert_eq!(message.loaded_context.text, ""); - - // Check message in request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - assert_eq!(request.messages.len(), 2); - assert_eq!( - request.messages[1].string_contents(), - "What is the best way to learn Rust?" - ); - - // Add second message, also without context - let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Are there any good books?", - ContextLoadResult::default(), - None, - Vec::new(), - cx, - ) - }); - - let message2 = - thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone()); - assert_eq!(message2.loaded_context.text, ""); - - // Check that both messages appear in the request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - assert_eq!(request.messages.len(), 3); - assert_eq!( - request.messages[1].string_contents(), - "What is the best way to learn Rust?" - ); - assert_eq!( - request.messages[2].string_contents(), - "Are there any good books?" - ); - } - - #[gpui::test] - async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project( - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, thread_store, thread, _context_store, _model) = - setup_test_environment(cx, project.clone()).await; - - // Check that we are starting with the default profile - let profile = cx.read(|cx| thread.read(cx).profile.clone()); - let tool_set = cx.read(|cx| thread_store.read(cx).tools()); - assert_eq!( - profile, - AgentProfile::new(AgentProfileId::default(), tool_set) - ); - } - - #[gpui::test] - async fn test_serializing_thread_profile(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project( - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, thread_store, thread, _context_store, _model) = - setup_test_environment(cx, project.clone()).await; - - // Profile gets serialized with default values - let serialized = thread - .update(cx, |thread, cx| thread.serialize(cx)) - .await - .unwrap(); - - assert_eq!(serialized.profile, Some(AgentProfileId::default())); - - let deserialized = cx.update(|cx| { - thread.update(cx, |thread, cx| { - Thread::deserialize( - thread.id.clone(), - serialized, - thread.project.clone(), - thread.tools.clone(), - thread.prompt_builder.clone(), - thread.project_context.clone(), - None, - cx, - ) - }) - }); - let tool_set = cx.read(|cx| thread_store.read(cx).tools()); - - assert_eq!( - deserialized.profile, - AgentProfile::new(AgentProfileId::default(), tool_set) - ); - } - - #[gpui::test] - async fn test_temperature_setting(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project( - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // Both model and provider - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: Some(model.provider_id().0.to_string().into()), - model: Some(model.id().0.clone()), - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, Some(0.66)); - - // Only model - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: None, - model: Some(model.id().0.clone()), - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, Some(0.66)); - - // Only provider - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: Some(model.provider_id().0.to_string().into()), - model: None, - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, Some(0.66)); - - // Same model name, different provider - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: Some("anthropic".into()), - model: Some(model.id().0.clone()), - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, None); - } - - #[gpui::test] - async fn test_thread_summary(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // Initial state should be pending - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Pending)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - - // Manually setting the summary should not be allowed in this state - thread.update(cx, |thread, cx| { - thread.set_summary("This should not work", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Pending)); - }); - - // Send a message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); - thread.send_to_model( - model.clone(), - CompletionIntent::ThreadSummarization, - None, - cx, - ); - }); - - let fake_model = model.as_fake(); - simulate_successful_response(&fake_model, cx); - - // Should start generating summary when there are >= 2 messages - thread.read_with(cx, |thread, _| { - assert_eq!(*thread.summary(), ThreadSummary::Generating); - }); - - // Should not be able to set the summary while generating - thread.update(cx, |thread, cx| { - thread.set_summary("This should not work either", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Generating)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - - cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief"); - fake_model.stream_last_completion_response(" Introduction"); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // Summary should be set - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), "Brief Introduction"); - }); - - // Now we should be able to set a summary - thread.update(cx, |thread, cx| { - thread.set_summary("Brief Intro", cx); - }); - - thread.read_with(cx, |thread, _| { - assert_eq!(thread.summary().or_default(), "Brief Intro"); - }); - - // Test setting an empty summary (should default to DEFAULT) - thread.update(cx, |thread, cx| { - thread.set_summary("", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - } - - #[gpui::test] - async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - test_summarize_error(&model, &thread, cx); - - // Now we should be able to set a summary - thread.update(cx, |thread, cx| { - thread.set_summary("Brief Intro", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), "Brief Intro"); - }); - } - - #[gpui::test] - async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - test_summarize_error(&model, &thread, cx); - - // Sending another message should not trigger another summarize request - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "How are you?", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - let fake_model = model.as_fake(); - simulate_successful_response(&fake_model, cx); - - thread.read_with(cx, |thread, _| { - // State is still Error, not Generating - assert!(matches!(thread.summary(), ThreadSummary::Error)); - }); - - // But the summarize request can be invoked manually - thread.update(cx, |thread, cx| { - thread.summarize(cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Generating)); - }); - - cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary"); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), "A successful summary"); - }); - } - - #[gpui::test] - fn test_resolve_tool_name_conflicts() { - use assistant_tool::{Tool, ToolSource}; - - assert_resolve_tool_name_conflicts( - vec![ - TestTool::new("tool1", ToolSource::Native), - TestTool::new("tool2", ToolSource::Native), - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), - ], - vec!["tool1", "tool2", "tool3"], - ); - - assert_resolve_tool_name_conflicts( - vec![ - TestTool::new("tool1", ToolSource::Native), - TestTool::new("tool2", ToolSource::Native), - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), - ], - vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"], - ); - - assert_resolve_tool_name_conflicts( - vec![ - TestTool::new("tool1", ToolSource::Native), - TestTool::new("tool2", ToolSource::Native), - TestTool::new("tool3", ToolSource::Native), - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), - ], - vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"], - ); - - // Test that tool with very long name is always truncated - assert_resolve_tool_name_conflicts( - vec![TestTool::new( - "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah", - ToolSource::Native, - )], - vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"], - ); - - // Test deduplication of tools with very long names, in this case the mcp server name should be truncated - assert_resolve_tool_name_conflicts( - vec![ - TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native), - TestTool::new( - "tool-with-very-very-very-long-name", - ToolSource::ContextServer { - id: "mcp-with-very-very-very-long-name".into(), - }, - ), - ], - vec![ - "tool-with-very-very-very-long-name", - "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name", - ], - ); - - fn assert_resolve_tool_name_conflicts( - tools: Vec, - expected: Vec>, - ) { - let tools: Vec> = tools - .into_iter() - .map(|t| Arc::new(t) as Arc) - .collect(); - let tools = resolve_tool_name_conflicts(&tools); - assert_eq!(tools.len(), expected.len()); - for (i, expected_name) in expected.into_iter().enumerate() { - let expected_name = expected_name.into(); - let actual_name = &tools[i].0; - assert_eq!( - actual_name, &expected_name, - "Expected '{}' got '{}' at index {}", - expected_name, actual_name, i - ); - } - } - - struct TestTool { - name: String, - source: ToolSource, - } - - impl TestTool { - fn new(name: impl Into, source: ToolSource) -> Self { - Self { - name: name.into(), - source, - } - } - } - - impl Tool for TestTool { - fn name(&self) -> String { - self.name.clone() - } - - fn icon(&self) -> IconName { - IconName::Ai - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { - true - } - - fn source(&self) -> ToolSource { - self.source.clone() - } - - fn description(&self) -> String { - "Test tool".to_string() - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - "Test tool".to_string() - } - - fn run( - self: Arc, - _input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - _cx: &mut App, - ) -> assistant_tool::ToolResult { - assistant_tool::ToolResult { - output: Task::ready(Err(anyhow::anyhow!("No content"))), - card: None, - } - } - } - } - - // Helper to create a model that returns errors - enum TestError { - Overloaded, - InternalServerError, - } - - struct ErrorInjector { - inner: Arc, - error_type: TestError, - } - - impl ErrorInjector { - fn new(error_type: TestError) -> Self { - Self { - inner: Arc::new(FakeLanguageModel::default()), - error_type, - } - } - } - - impl LanguageModel for ErrorInjector { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - _request: LanguageModelRequest, - _cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - let error = match self.error_type { - TestError::Overloaded => LanguageModelCompletionError::Overloaded, - TestError::InternalServerError => { - LanguageModelCompletionError::ApiInternalServerError - } - }; - async move { - let stream = futures::stream::once(async move { Err(error) }); - Ok(stream.boxed()) - } - .boxed() - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner - } - } - - #[gpui::test] - async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have default max attempts" - ); - }); - - // Check that a retry message was added - thread.read_with(cx, |thread, _| { - let mut messages = thread.messages(); - assert!( - messages.any(|msg| { - msg.role == Role::System - && msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("overloaded") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) - } else { - false - } - }) - }), - "Should have added a system retry message" - ); - }); - - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - assert_eq!(retry_count, 1, "Should have one retry message"); - } - - #[gpui::test] - async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Create model that returns internal server error - let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - // Check retry state on thread - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); - - // Check that a retry message was added with provider name - thread.read_with(cx, |thread, _| { - let mut messages = thread.messages(); - assert!( - messages.any(|msg| { - msg.role == Role::System - && msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("internal") - && text.contains("Fake") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) - } else { - false - } - }) - }), - "Should have added a system retry message with provider name" - ); - }); - - // Count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - assert_eq!(retry_count, 1, "Should have one retry message"); - } - - #[gpui::test] - async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track retry events and completion count - // Track completion events - let completion_count = Arc::new(Mutex::new(0)); - let completion_count_clone = completion_count.clone(); - - let _subscription = thread.update(cx, |_, cx| { - cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::NewRequest = event { - *completion_count_clone.lock() += 1; - } - }) - }); - - // First attempt - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Should have scheduled first retry - count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 1, "Should have scheduled first retry"); - - // Check retry state - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - }); - - // Advance clock for first retry - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); - cx.run_until_parked(); - - // Should have scheduled second retry - count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 2, "Should have scheduled second retry"); - - // Check retry state updated - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); - - // Advance clock for second retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2)); - cx.run_until_parked(); - - // Should have scheduled third retry - // Count all retry messages now - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have scheduled third retry" - ); - - // Check retry state updated - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!( - retry_state.attempt, MAX_RETRY_ATTEMPTS, - "Should be at max retry attempt" - ); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); - - // Advance clock for third retry (exponential backoff) - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4)); - cx.run_until_parked(); - - // No more retries should be scheduled after clock was advanced. - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should not exceed max retries" - ); - - // Final completion count should be initial + max retries - assert_eq!( - *completion_count.lock(), - (MAX_RETRY_ATTEMPTS + 1) as usize, - "Should have made initial + max retry attempts" - ); - } - - #[gpui::test] - async fn test_max_retries_exceeded(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track events - let retries_failed = Arc::new(Mutex::new(false)); - let retries_failed_clone = retries_failed.clone(); - - let _subscription = thread.update(cx, |_, cx| { - cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::RetriesFailed { .. } = event { - *retries_failed_clone.lock() = true; - } - }) - }); - - // Start initial completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Advance through all retries - for i in 0..MAX_RETRY_ATTEMPTS { - let delay = if i == 0 { - BASE_RETRY_DELAY_SECS - } else { - BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1) - }; - cx.executor().advance_clock(Duration::from_secs(delay)); - cx.run_until_parked(); - } - - // After the 3rd retry is scheduled, we need to wait for it to execute and fail - // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) - let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32); - cx.executor() - .advance_clock(Duration::from_secs(final_delay)); - cx.run_until_parked(); - - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - // After max retries, should emit RetriesFailed event - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have attempted max retries" - ); - assert!( - *retries_failed.lock(), - "Should emit RetriesFailed event after max retries exceeded" - ); - - // Retry state should be cleared - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_none(), - "Retry state should be cleared after max retries" - ); - - // Verify we have the expected number of retry messages - let retry_messages = thread - .messages - .iter() - .filter(|msg| msg.ui_only && msg.role == Role::System) - .count(); - assert_eq!( - retry_messages, MAX_RETRY_ATTEMPTS as usize, - "Should have one retry message per attempt" - ); - }); - } - - #[gpui::test] - async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // We'll use a wrapper to switch behavior after first failure - struct RetryTestModel { - inner: Arc, - failed_once: Arc>, - } - - impl LanguageModel for RetryTestModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - if !*self.failed_once.lock() { - *self.failed_once.lock() = true; - // Return error on first attempt - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::Overloaded) - }); - async move { Ok(stream.boxed()) }.boxed() - } else { - // Succeed on retry - self.inner.stream_completion(request, cx) - } - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner - } - } - - let model = Arc::new(RetryTestModel { - inner: Arc::new(FakeLanguageModel::default()), - failed_once: Arc::new(Mutex::new(false)), - }); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track message deletions - // Track when retry completes successfully - let retry_completed = Arc::new(Mutex::new(false)); - let retry_completed_clone = retry_completed.clone(); - - let _subscription = thread.update(cx, |_, cx| { - cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::StreamedCompletion = event { - *retry_completed_clone.lock() = true; - } - }) - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Get the retry message ID - let retry_message_id = thread.read_with(cx, |thread, _| { - thread - .messages() - .find(|msg| msg.role == Role::System && msg.ui_only) - .map(|msg| msg.id) - .expect("Should have a retry message") - }); - - // Wait for retry - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); - cx.run_until_parked(); - - // Stream some successful content - let fake_model = model.as_fake(); - // After the retry, there should be a new pending completion - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "Should have a pending completion after retry" - ); - fake_model.stream_completion_response(&pending[0], "Success!"); - fake_model.end_completion_stream(&pending[0]); - cx.run_until_parked(); - - // Check that the retry completed successfully - assert!( - *retry_completed.lock(), - "Retry should have completed successfully" - ); - - // Retry message should still exist but be marked as ui_only - thread.read_with(cx, |thread, _| { - let retry_msg = thread - .message(retry_message_id) - .expect("Retry message should still exist"); - assert!(retry_msg.ui_only, "Retry message should be ui_only"); - assert_eq!( - retry_msg.role, - Role::System, - "Retry message should have System role" - ); - }); - } - - #[gpui::test] - async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Create a model that fails once then succeeds - struct FailOnceModel { - inner: Arc, - failed_once: Arc>, - } - - impl LanguageModel for FailOnceModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - if !*self.failed_once.lock() { - *self.failed_once.lock() = true; - // Return error on first attempt - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::Overloaded) - }); - async move { Ok(stream.boxed()) }.boxed() - } else { - // Succeed on retry - self.inner.stream_completion(request, cx) - } - } - } - - let fail_once_model = Arc::new(FailOnceModel { - inner: Arc::new(FakeLanguageModel::default()), - failed_once: Arc::new(Mutex::new(false)), - }); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Test message", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - }); - - // Start completion with fail-once model - thread.update(cx, |thread, cx| { - thread.send_to_model( - fail_once_model.clone(), - CompletionIntent::UserPrompt, - None, - cx, - ); - }); - - cx.run_until_parked(); - - // Verify retry state exists after first failure - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_some(), - "Should have retry state after failure" - ); - }); - - // Wait for retry delay - cx.executor() - .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); - cx.run_until_parked(); - - // The retry should now use our FailOnceModel which should succeed - // We need to help the FakeLanguageModel complete the stream - let inner_fake = fail_once_model.inner.clone(); - - // Wait a bit for the retry to start - cx.run_until_parked(); - - // Check for pending completions and complete them - if let Some(pending) = inner_fake.pending_completions().first() { - inner_fake.stream_completion_response(pending, "Success!"); - inner_fake.end_completion_stream(pending); - } - cx.run_until_parked(); - - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_none(), - "Retry state should be cleared after successful completion" - ); - - let has_assistant_message = thread - .messages - .iter() - .any(|msg| msg.role == Role::Assistant && !msg.ui_only); - assert!( - has_assistant_message, - "Should have an assistant message after successful retry" - ); - }); - } - - #[gpui::test] - async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Create a model that returns rate limit error with retry_after - struct RateLimitModel { - inner: Arc, - } - - impl LanguageModel for RateLimitModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - _request: LanguageModelRequest, - _cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - async move { - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::RateLimitExceeded { - retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS), - }) - }); - Ok(stream.boxed()) - } - .boxed() - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner - } - } - - let model = Arc::new(RateLimitModel { - inner: Arc::new(FakeLanguageModel::default()), - }); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("rate limit exceeded") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 1, "Should have scheduled one retry"); - - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_none(), - "Rate limit errors should not set retry_state" - ); - }); - - // Verify we have one retry message - thread.read_with(cx, |thread, _| { - let retry_messages = thread - .messages - .iter() - .filter(|msg| { - msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("rate limit exceeded") - } else { - false - } - }) - }) - .count(); - assert_eq!( - retry_messages, 1, - "Should have one rate limit retry message" - ); - }); - - // Check that retry message doesn't include attempt count - thread.read_with(cx, |thread, _| { - let retry_message = thread - .messages - .iter() - .find(|msg| msg.role == Role::System && msg.ui_only) - .expect("Should have a retry message"); - - // Check that the message doesn't contain attempt count - if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { - assert!( - !text.contains("attempt"), - "Rate limit retry message should not contain attempt count" - ); - assert!( - text.contains(&format!( - "Retrying in {} seconds", - TEST_RATE_LIMIT_RETRY_SECS - )), - "Rate limit retry message should contain retry delay" - ); - } - }); - } - - #[gpui::test] - async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await; - - // Insert a regular user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Insert a UI-only message (like our retry notifications) - thread.update(cx, |thread, cx| { - let id = thread.next_message_id.post_inc(); - thread.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text( - "This is a UI-only message that should not be sent to the model".to_string(), - )], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: true, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - }); - - // Insert another regular message - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "How are you?", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - }); - - // Generate the completion request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - // Verify that the request only contains non-UI-only messages - // Should have system prompt + 2 user messages, but not the UI-only message - let user_messages: Vec<_> = request - .messages - .iter() - .filter(|msg| msg.role == Role::User) - .collect(); - assert_eq!( - user_messages.len(), - 2, - "Should have exactly 2 user messages" - ); - - // Verify the UI-only content is not present anywhere in the request - let request_text = request - .messages - .iter() - .flat_map(|msg| &msg.content) - .filter_map(|content| match content { - MessageContent::Text(text) => Some(text.as_str()), - _ => None, - }) - .collect::(); - - assert!( - !request_text.contains("UI-only message"), - "UI-only message content should not be in the request" - ); - - // Verify the thread still has all 3 messages (including UI-only) - thread.read_with(cx, |thread, _| { - assert_eq!( - thread.messages().count(), - 3, - "Thread should have 3 messages" - ); - assert_eq!( - thread.messages().filter(|m| m.ui_only).count(), - 1, - "Thread should have 1 UI-only message" - ); - }); - - // Verify that UI-only messages are not serialized - let serialized = thread - .update(cx, |thread, cx| thread.serialize(cx)) - .await - .unwrap(); - assert_eq!( - serialized.messages.len(), - 2, - "Serialized thread should only have 2 messages (no UI-only)" - ); - } - - #[gpui::test] - async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - // Verify retry was scheduled by checking for retry message - let has_retry_message = thread.read_with(cx, |thread, _| { - thread.messages.iter().any(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - }); - assert!(has_retry_message, "Should have scheduled a retry"); - - // Cancel the completion before the retry happens - thread.update(cx, |thread, cx| { - thread.cancel_last_completion(None, cx); - }); - - cx.run_until_parked(); - - // The retry should not have happened - no pending completions - let fake_model = model.as_fake(); - assert_eq!( - fake_model.pending_completions().len(), - 0, - "Should have no pending completions after cancellation" - ); - - // Verify the retry was cancelled by checking retry state - thread.read_with(cx, |thread, _| { - if let Some(retry_state) = &thread.retry_state { - panic!( - "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}", - retry_state.attempt, retry_state.max_attempts, retry_state.intent - ); - } - }); - } - - fn test_summarize_error( - model: &Arc, - thread: &Entity, - cx: &mut TestAppContext, - ) { - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); - thread.send_to_model( - model.clone(), - CompletionIntent::ThreadSummarization, - None, - cx, - ); - }); - - let fake_model = model.as_fake(); - simulate_successful_response(&fake_model, cx); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Generating)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - - // Simulate summary request ending - cx.run_until_parked(); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // State is set to Error and default message - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Error)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - } - - fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { - cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response"); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - } - - fn init_test_settings(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - AgentSettings::register(cx); - prompt_store::init(cx); - thread_store::init(cx); - workspace::init_settings(cx); - language_model::init_settings(cx); - ThemeSettings::register(cx); - ToolRegistry::default_global(cx); - }); - } - - // Helper to create a test project with test files - async fn create_test_project( - cx: &mut TestAppContext, - files: serde_json::Value, - ) -> Entity { - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), files).await; - Project::test(fs, [path!("/test").as_ref()], cx).await - } - - async fn setup_test_environment( - cx: &mut TestAppContext, - project: Entity, - ) -> ( - Entity, - Entity, - Entity, - Entity, - Arc, - ) { - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - - let thread_store = cx - .update(|_, cx| { - ThreadStore::load( - project.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .await - .unwrap(); - - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); - let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - - let provider = Arc::new(FakeLanguageModelProvider); - let model = provider.test_model(); - let model: Arc = Arc::new(model); - - cx.update(|_, cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model( - Some(ConfiguredModel { - provider: provider.clone(), - model: model.clone(), - }), - cx, - ); - registry.set_thread_summary_model( - Some(ConfiguredModel { - provider, - model: model.clone(), - }), - cx, - ); - }) - }); - - (workspace, thread_store, thread, context_store, model) - } - - async fn add_file_to_context( - project: &Entity, - context_store: &Entity, - path: &str, - cx: &mut TestAppContext, - ) -> Result> { - let buffer_path = project - .read_with(cx, |project, cx| project.find_project_path(path, cx)) - .unwrap(); - - let buffer = project - .update(cx, |project, cx| { - project.open_buffer(buffer_path.clone(), cx) - }) - .await - .unwrap(); - - context_store.update(cx, |context_store, cx| { - context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx); - }); - - Ok(buffer) - } -} +// #[cfg(test)] +// mod tests { +// use super::*; +// use crate::{ +// context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore, +// }; + +// // Test-specific constants +// const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30; +// use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; +// use assistant_tool::ToolRegistry; +// use futures::StreamExt; +// use futures::future::BoxFuture; +// use futures::stream::BoxStream; +// use gpui::TestAppContext; +// use icons::IconName; +// use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; +// use language_model::{ +// LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId, +// LanguageModelProviderName, LanguageModelToolChoice, +// }; +// use parking_lot::Mutex; +// use project::{FakeFs, Project}; +// use prompt_store::PromptBuilder; +// use serde_json::json; +// use settings::{Settings, SettingsStore}; +// use std::sync::Arc; +// use std::time::Duration; +// use theme::ThemeSettings; +// use util::path; +// use workspace::Workspace; + +// #[gpui::test] +// async fn test_message_with_context(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project( +// cx, +// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), +// ) +// .await; + +// let (_workspace, _thread_store, thread, context_store, model) = +// setup_test_environment(cx, project.clone()).await; + +// add_file_to_context(&project, &context_store, "test/code.rs", cx) +// .await +// .unwrap(); + +// let context = +// context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap()); +// let loaded_context = cx +// .update(|cx| load_context(vec![context], &project, &None, cx)) +// .await; + +// // Insert user message with context +// let message_id = thread.update(cx, |thread, cx| { +// thread.insert_user_message( +// "Please explain this code", +// loaded_context, +// None, +// Vec::new(), +// cx, +// ) +// }); + +// // Check content and context in message object +// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone()); + +// // Use different path format strings based on platform for the test +// #[cfg(windows)] +// let path_part = r"test\code.rs"; +// #[cfg(not(windows))] +// let path_part = "test/code.rs"; + +// let expected_context = format!( +// r#" +// +// The following items were attached by the user. They are up-to-date and don't need to be re-read. + +// +// ```rs {path_part} +// fn main() {{ +// println!("Hello, world!"); +// }} +// ``` +// +// +// "# +// ); + +// assert_eq!(message.role, Role::User); +// assert_eq!(message.segments.len(), 1); +// assert_eq!( +// message.segments[0], +// MessageSegment::Text("Please explain this code".to_string()) +// ); +// assert_eq!(message.loaded_context.text, expected_context); + +// // Check message in request +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); + +// assert_eq!(request.messages.len(), 2); +// let expected_full_message = format!("{}Please explain this code", expected_context); +// assert_eq!(request.messages[1].string_contents(), expected_full_message); +// } + +// #[gpui::test] +// async fn test_only_include_new_contexts(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project( +// cx, +// json!({ +// "file1.rs": "fn function1() {}\n", +// "file2.rs": "fn function2() {}\n", +// "file3.rs": "fn function3() {}\n", +// "file4.rs": "fn function4() {}\n", +// }), +// ) +// .await; + +// let (_, _thread_store, thread, context_store, model) = +// setup_test_environment(cx, project.clone()).await; + +// // First message with context 1 +// add_file_to_context(&project, &context_store, "test/file1.rs", cx) +// .await +// .unwrap(); +// let new_contexts = context_store.update(cx, |store, cx| { +// store.new_context_for_thread(thread.read(cx), None) +// }); +// assert_eq!(new_contexts.len(), 1); +// let loaded_context = cx +// .update(|cx| load_context(new_contexts, &project, &None, cx)) +// .await; +// let message1_id = thread.update(cx, |thread, cx| { +// thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx) +// }); + +// // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included) +// add_file_to_context(&project, &context_store, "test/file2.rs", cx) +// .await +// .unwrap(); +// let new_contexts = context_store.update(cx, |store, cx| { +// store.new_context_for_thread(thread.read(cx), None) +// }); +// assert_eq!(new_contexts.len(), 1); +// let loaded_context = cx +// .update(|cx| load_context(new_contexts, &project, &None, cx)) +// .await; +// let message2_id = thread.update(cx, |thread, cx| { +// thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx) +// }); + +// // Third message with all three contexts (contexts 1 and 2 should be skipped) +// // +// add_file_to_context(&project, &context_store, "test/file3.rs", cx) +// .await +// .unwrap(); +// let new_contexts = context_store.update(cx, |store, cx| { +// store.new_context_for_thread(thread.read(cx), None) +// }); +// assert_eq!(new_contexts.len(), 1); +// let loaded_context = cx +// .update(|cx| load_context(new_contexts, &project, &None, cx)) +// .await; +// let message3_id = thread.update(cx, |thread, cx| { +// thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx) +// }); + +// // Check what contexts are included in each message +// let (message1, message2, message3) = thread.read_with(cx, |thread, _| { +// ( +// thread.message(message1_id).unwrap().clone(), +// thread.message(message2_id).unwrap().clone(), +// thread.message(message3_id).unwrap().clone(), +// ) +// }); + +// // First message should include context 1 +// assert!(message1.loaded_context.text.contains("file1.rs")); + +// // Second message should include only context 2 (not 1) +// assert!(!message2.loaded_context.text.contains("file1.rs")); +// assert!(message2.loaded_context.text.contains("file2.rs")); + +// // Third message should include only context 3 (not 1 or 2) +// assert!(!message3.loaded_context.text.contains("file1.rs")); +// assert!(!message3.loaded_context.text.contains("file2.rs")); +// assert!(message3.loaded_context.text.contains("file3.rs")); + +// // Check entire request to make sure all contexts are properly included +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); + +// // The request should contain all 3 messages +// assert_eq!(request.messages.len(), 4); + +// // Check that the contexts are properly formatted in each message +// assert!(request.messages[1].string_contents().contains("file1.rs")); +// assert!(!request.messages[1].string_contents().contains("file2.rs")); +// assert!(!request.messages[1].string_contents().contains("file3.rs")); + +// assert!(!request.messages[2].string_contents().contains("file1.rs")); +// assert!(request.messages[2].string_contents().contains("file2.rs")); +// assert!(!request.messages[2].string_contents().contains("file3.rs")); + +// assert!(!request.messages[3].string_contents().contains("file1.rs")); +// assert!(!request.messages[3].string_contents().contains("file2.rs")); +// assert!(request.messages[3].string_contents().contains("file3.rs")); + +// add_file_to_context(&project, &context_store, "test/file4.rs", cx) +// .await +// .unwrap(); +// let new_contexts = context_store.update(cx, |store, cx| { +// store.new_context_for_thread(thread.read(cx), Some(message2_id)) +// }); +// assert_eq!(new_contexts.len(), 3); +// let loaded_context = cx +// .update(|cx| load_context(new_contexts, &project, &None, cx)) +// .await +// .loaded_context; + +// assert!(!loaded_context.text.contains("file1.rs")); +// assert!(loaded_context.text.contains("file2.rs")); +// assert!(loaded_context.text.contains("file3.rs")); +// assert!(loaded_context.text.contains("file4.rs")); + +// let new_contexts = context_store.update(cx, |store, cx| { +// // Remove file4.rs +// store.remove_context(&loaded_context.contexts[2].handle(), cx); +// store.new_context_for_thread(thread.read(cx), Some(message2_id)) +// }); +// assert_eq!(new_contexts.len(), 2); +// let loaded_context = cx +// .update(|cx| load_context(new_contexts, &project, &None, cx)) +// .await +// .loaded_context; + +// assert!(!loaded_context.text.contains("file1.rs")); +// assert!(loaded_context.text.contains("file2.rs")); +// assert!(loaded_context.text.contains("file3.rs")); +// assert!(!loaded_context.text.contains("file4.rs")); + +// let new_contexts = context_store.update(cx, |store, cx| { +// // Remove file3.rs +// store.remove_context(&loaded_context.contexts[1].handle(), cx); +// store.new_context_for_thread(thread.read(cx), Some(message2_id)) +// }); +// assert_eq!(new_contexts.len(), 1); +// let loaded_context = cx +// .update(|cx| load_context(new_contexts, &project, &None, cx)) +// .await +// .loaded_context; + +// assert!(!loaded_context.text.contains("file1.rs")); +// assert!(loaded_context.text.contains("file2.rs")); +// assert!(!loaded_context.text.contains("file3.rs")); +// assert!(!loaded_context.text.contains("file4.rs")); +// } + +// #[gpui::test] +// async fn test_message_without_files(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project( +// cx, +// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), +// ) +// .await; + +// let (_, _thread_store, thread, _context_store, model) = +// setup_test_environment(cx, project.clone()).await; + +// // Insert user message without any context (empty context vector) +// let message_id = thread.update(cx, |thread, cx| { +// thread.insert_user_message( +// "What is the best way to learn Rust?", +// ContextLoadResult::default(), +// None, +// Vec::new(), +// cx, +// ) +// }); + +// // Check content and context in message object +// let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone()); + +// // Context should be empty when no files are included +// assert_eq!(message.role, Role::User); +// assert_eq!(message.segments.len(), 1); +// assert_eq!( +// message.segments[0], +// MessageSegment::Text("What is the best way to learn Rust?".to_string()) +// ); +// assert_eq!(message.loaded_context.text, ""); + +// // Check message in request +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); + +// assert_eq!(request.messages.len(), 2); +// assert_eq!( +// request.messages[1].string_contents(), +// "What is the best way to learn Rust?" +// ); + +// // Add second message, also without context +// let message2_id = thread.update(cx, |thread, cx| { +// thread.insert_user_message( +// "Are there any good books?", +// ContextLoadResult::default(), +// None, +// Vec::new(), +// cx, +// ) +// }); + +// let message2 = +// thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone()); +// assert_eq!(message2.loaded_context.text, ""); + +// // Check that both messages appear in the request +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); + +// assert_eq!(request.messages.len(), 3); +// assert_eq!( +// request.messages[1].string_contents(), +// "What is the best way to learn Rust?" +// ); +// assert_eq!( +// request.messages[2].string_contents(), +// "Are there any good books?" +// ); +// } + +// #[gpui::test] +// async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project( +// cx, +// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), +// ) +// .await; + +// let (_workspace, thread_store, thread, _context_store, _model) = +// setup_test_environment(cx, project.clone()).await; + +// // Check that we are starting with the default profile +// let profile = cx.read(|cx| thread.read(cx).profile.clone()); +// let tool_set = cx.read(|cx| thread_store.read(cx).tools()); +// assert_eq!( +// profile, +// AgentProfile::new(AgentProfileId::default(), tool_set) +// ); +// } + +// #[gpui::test] +// async fn test_serializing_thread_profile(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project( +// cx, +// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), +// ) +// .await; + +// let (_workspace, thread_store, thread, _context_store, _model) = +// setup_test_environment(cx, project.clone()).await; + +// // Profile gets serialized with default values +// let serialized = thread +// .update(cx, |thread, cx| thread.serialize(cx)) +// .await +// .unwrap(); + +// assert_eq!(serialized.profile, Some(AgentProfileId::default())); + +// let deserialized = cx.update(|cx| { +// thread.update(cx, |thread, cx| { +// Thread::deserialize( +// thread.id.clone(), +// serialized, +// thread.project.clone(), +// thread.tools.clone(), +// thread.prompt_builder.clone(), +// thread.project_context.clone(), +// None, +// cx, +// ) +// }) +// }); +// let tool_set = cx.read(|cx| thread_store.read(cx).tools()); + +// assert_eq!( +// deserialized.profile, +// AgentProfile::new(AgentProfileId::default(), tool_set) +// ); +// } + +// #[gpui::test] +// async fn test_temperature_setting(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project( +// cx, +// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), +// ) +// .await; + +// let (_workspace, _thread_store, thread, _context_store, model) = +// setup_test_environment(cx, project.clone()).await; + +// // Both model and provider +// cx.update(|cx| { +// AgentSettings::override_global( +// AgentSettings { +// model_parameters: vec![LanguageModelParameters { +// provider: Some(model.provider_id().0.to_string().into()), +// model: Some(model.id().0.clone()), +// temperature: Some(0.66), +// }], +// ..AgentSettings::get_global(cx).clone() +// }, +// cx, +// ); +// }); + +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); +// assert_eq!(request.temperature, Some(0.66)); + +// // Only model +// cx.update(|cx| { +// AgentSettings::override_global( +// AgentSettings { +// model_parameters: vec![LanguageModelParameters { +// provider: None, +// model: Some(model.id().0.clone()), +// temperature: Some(0.66), +// }], +// ..AgentSettings::get_global(cx).clone() +// }, +// cx, +// ); +// }); + +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); +// assert_eq!(request.temperature, Some(0.66)); + +// // Only provider +// cx.update(|cx| { +// AgentSettings::override_global( +// AgentSettings { +// model_parameters: vec![LanguageModelParameters { +// provider: Some(model.provider_id().0.to_string().into()), +// model: None, +// temperature: Some(0.66), +// }], +// ..AgentSettings::get_global(cx).clone() +// }, +// cx, +// ); +// }); + +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); +// assert_eq!(request.temperature, Some(0.66)); + +// // Same model name, different provider +// cx.update(|cx| { +// AgentSettings::override_global( +// AgentSettings { +// model_parameters: vec![LanguageModelParameters { +// provider: Some("anthropic".into()), +// model: Some(model.id().0.clone()), +// temperature: Some(0.66), +// }], +// ..AgentSettings::get_global(cx).clone() +// }, +// cx, +// ); +// }); + +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); +// assert_eq!(request.temperature, None); +// } + +// #[gpui::test] +// async fn test_thread_summary(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; + +// let (_, _thread_store, thread, _context_store, model) = +// setup_test_environment(cx, project.clone()).await; + +// // Initial state should be pending +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Pending)); +// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); +// }); + +// // Manually setting the summary should not be allowed in this state +// thread.update(cx, |thread, cx| { +// thread.set_summary("This should not work", cx); +// }); + +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Pending)); +// }); + +// // Send a message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); +// thread.send_to_model( +// model.clone(), +// CompletionIntent::ThreadSummarization, +// None, +// cx, +// ); +// }); + +// let fake_model = model.as_fake(); +// simulate_successful_response(&fake_model, cx); + +// // Should start generating summary when there are >= 2 messages +// thread.read_with(cx, |thread, _| { +// assert_eq!(*thread.summary(), ThreadSummary::Generating); +// }); + +// // Should not be able to set the summary while generating +// thread.update(cx, |thread, cx| { +// thread.set_summary("This should not work either", cx); +// }); + +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Generating)); +// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); +// }); + +// cx.run_until_parked(); +// fake_model.stream_last_completion_response("Brief"); +// fake_model.stream_last_completion_response(" Introduction"); +// fake_model.end_last_completion_stream(); +// cx.run_until_parked(); + +// // Summary should be set +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); +// assert_eq!(thread.summary().or_default(), "Brief Introduction"); +// }); + +// // Now we should be able to set a summary +// thread.update(cx, |thread, cx| { +// thread.set_summary("Brief Intro", cx); +// }); + +// thread.read_with(cx, |thread, _| { +// assert_eq!(thread.summary().or_default(), "Brief Intro"); +// }); + +// // Test setting an empty summary (should default to DEFAULT) +// thread.update(cx, |thread, cx| { +// thread.set_summary("", cx); +// }); + +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); +// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); +// }); +// } + +// #[gpui::test] +// async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; + +// let (_, _thread_store, thread, _context_store, model) = +// setup_test_environment(cx, project.clone()).await; + +// test_summarize_error(&model, &thread, cx); + +// // Now we should be able to set a summary +// thread.update(cx, |thread, cx| { +// thread.set_summary("Brief Intro", cx); +// }); + +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); +// assert_eq!(thread.summary().or_default(), "Brief Intro"); +// }); +// } + +// #[gpui::test] +// async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; + +// let (_, _thread_store, thread, _context_store, model) = +// setup_test_environment(cx, project.clone()).await; + +// test_summarize_error(&model, &thread, cx); + +// // Sending another message should not trigger another summarize request +// thread.update(cx, |thread, cx| { +// thread.insert_user_message( +// "How are you?", +// ContextLoadResult::default(), +// None, +// vec![], +// cx, +// ); +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); + +// let fake_model = model.as_fake(); +// simulate_successful_response(&fake_model, cx); + +// thread.read_with(cx, |thread, _| { +// // State is still Error, not Generating +// assert!(matches!(thread.summary(), ThreadSummary::Error)); +// }); + +// // But the summarize request can be invoked manually +// thread.update(cx, |thread, cx| { +// thread.summarize(cx); +// }); + +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Generating)); +// }); + +// cx.run_until_parked(); +// fake_model.stream_last_completion_response("A successful summary"); +// fake_model.end_last_completion_stream(); +// cx.run_until_parked(); + +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); +// assert_eq!(thread.summary().or_default(), "A successful summary"); +// }); +// } + +// #[gpui::test] +// fn test_resolve_tool_name_conflicts() { +// use assistant_tool::{Tool, ToolSource}; + +// assert_resolve_tool_name_conflicts( +// vec![ +// TestTool::new("tool1", ToolSource::Native), +// TestTool::new("tool2", ToolSource::Native), +// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), +// ], +// vec!["tool1", "tool2", "tool3"], +// ); + +// assert_resolve_tool_name_conflicts( +// vec![ +// TestTool::new("tool1", ToolSource::Native), +// TestTool::new("tool2", ToolSource::Native), +// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), +// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), +// ], +// vec!["tool1", "tool2", "mcp-1_tool3", "mcp-2_tool3"], +// ); + +// assert_resolve_tool_name_conflicts( +// vec![ +// TestTool::new("tool1", ToolSource::Native), +// TestTool::new("tool2", ToolSource::Native), +// TestTool::new("tool3", ToolSource::Native), +// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), +// TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), +// ], +// vec!["tool1", "tool2", "tool3", "mcp-1_tool3", "mcp-2_tool3"], +// ); + +// // Test that tool with very long name is always truncated +// assert_resolve_tool_name_conflicts( +// vec![TestTool::new( +// "tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-blah-blah", +// ToolSource::Native, +// )], +// vec!["tool-with-more-then-64-characters-blah-blah-blah-blah-blah-blah-"], +// ); + +// // Test deduplication of tools with very long names, in this case the mcp server name should be truncated +// assert_resolve_tool_name_conflicts( +// vec![ +// TestTool::new("tool-with-very-very-very-long-name", ToolSource::Native), +// TestTool::new( +// "tool-with-very-very-very-long-name", +// ToolSource::ContextServer { +// id: "mcp-with-very-very-very-long-name".into(), +// }, +// ), +// ], +// vec![ +// "tool-with-very-very-very-long-name", +// "mcp-with-very-very-very-long-_tool-with-very-very-very-long-name", +// ], +// ); + +// fn assert_resolve_tool_name_conflicts( +// tools: Vec, +// expected: Vec>, +// ) { +// let tools: Vec> = tools +// .into_iter() +// .map(|t| Arc::new(t) as Arc) +// .collect(); +// let tools = resolve_tool_name_conflicts(&tools); +// assert_eq!(tools.len(), expected.len()); +// for (i, expected_name) in expected.into_iter().enumerate() { +// let expected_name = expected_name.into(); +// let actual_name = &tools[i].0; +// assert_eq!( +// actual_name, &expected_name, +// "Expected '{}' got '{}' at index {}", +// expected_name, actual_name, i +// ); +// } +// } + +// struct TestTool { +// name: String, +// source: ToolSource, +// } + +// impl TestTool { +// fn new(name: impl Into, source: ToolSource) -> Self { +// Self { +// name: name.into(), +// source, +// } +// } +// } + +// impl Tool for TestTool { +// fn name(&self) -> String { +// self.name.clone() +// } + +// fn icon(&self) -> IconName { +// IconName::Ai +// } + +// fn may_perform_edits(&self) -> bool { +// false +// } + +// fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { +// true +// } + +// fn source(&self) -> ToolSource { +// self.source.clone() +// } + +// fn description(&self) -> String { +// "Test tool".to_string() +// } + +// fn ui_text(&self, _input: &serde_json::Value) -> String { +// "Test tool".to_string() +// } + +// fn run( +// self: Arc, +// _input: serde_json::Value, +// _request: Arc, +// _project: Entity, +// _action_log: Entity, +// _model: Arc, +// _window: Option, +// _cx: &mut App, +// ) -> assistant_tool::ToolResult { +// assistant_tool::ToolResult { +// output: Task::ready(Err(anyhow::anyhow!("No content"))), +// card: None, +// } +// } +// } +// } + +// // Helper to create a model that returns errors +// enum TestError { +// Overloaded, +// InternalServerError, +// } + +// struct ErrorInjector { +// inner: Arc, +// error_type: TestError, +// } + +// impl ErrorInjector { +// fn new(error_type: TestError) -> Self { +// Self { +// inner: Arc::new(FakeLanguageModel::default()), +// error_type, +// } +// } +// } + +// impl LanguageModel for ErrorInjector { +// fn id(&self) -> LanguageModelId { +// self.inner.id() +// } + +// fn name(&self) -> LanguageModelName { +// self.inner.name() +// } + +// fn provider_id(&self) -> LanguageModelProviderId { +// self.inner.provider_id() +// } + +// fn provider_name(&self) -> LanguageModelProviderName { +// self.inner.provider_name() +// } + +// fn supports_tools(&self) -> bool { +// self.inner.supports_tools() +// } + +// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { +// self.inner.supports_tool_choice(choice) +// } + +// fn supports_images(&self) -> bool { +// self.inner.supports_images() +// } + +// fn telemetry_id(&self) -> String { +// self.inner.telemetry_id() +// } + +// fn max_token_count(&self) -> u64 { +// self.inner.max_token_count() +// } + +// fn count_tokens( +// &self, +// request: LanguageModelRequest, +// cx: &App, +// ) -> BoxFuture<'static, Result> { +// self.inner.count_tokens(request, cx) +// } + +// fn stream_completion( +// &self, +// _request: LanguageModelRequest, +// _cx: &AsyncApp, +// ) -> BoxFuture< +// 'static, +// Result< +// BoxStream< +// 'static, +// Result, +// >, +// LanguageModelCompletionError, +// >, +// > { +// let error = match self.error_type { +// TestError::Overloaded => LanguageModelCompletionError::Overloaded, +// TestError::InternalServerError => { +// LanguageModelCompletionError::ApiInternalServerError +// } +// }; +// async move { +// let stream = futures::stream::once(async move { Err(error) }); +// Ok(stream.boxed()) +// } +// .boxed() +// } + +// fn as_fake(&self) -> &FakeLanguageModel { +// &self.inner +// } +// } + +// #[gpui::test] +// async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // Create model that returns overloaded error +// let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Start completion +// thread.update(cx, |thread, cx| { +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); + +// cx.run_until_parked(); + +// thread.read_with(cx, |thread, _| { +// assert!(thread.retry_state.is_some(), "Should have retry state"); +// let retry_state = thread.retry_state.as_ref().unwrap(); +// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); +// assert_eq!( +// retry_state.max_attempts, MAX_RETRY_ATTEMPTS, +// "Should have default max attempts" +// ); +// }); + +// // Check that a retry message was added +// thread.read_with(cx, |thread, _| { +// let mut messages = thread.messages(); +// assert!( +// messages.any(|msg| { +// msg.role == Role::System +// && msg.ui_only +// && msg.segments.iter().any(|seg| { +// if let MessageSegment::Text(text) = seg { +// text.contains("overloaded") +// && text +// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) +// } else { +// false +// } +// }) +// }), +// "Should have added a system retry message" +// ); +// }); + +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// .count() +// }); + +// assert_eq!(retry_count, 1, "Should have one retry message"); +// } + +// #[gpui::test] +// async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // Create model that returns internal server error +// let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Start completion +// thread.update(cx, |thread, cx| { +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); + +// cx.run_until_parked(); + +// // Check retry state on thread +// thread.read_with(cx, |thread, _| { +// assert!(thread.retry_state.is_some(), "Should have retry state"); +// let retry_state = thread.retry_state.as_ref().unwrap(); +// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); +// assert_eq!( +// retry_state.max_attempts, MAX_RETRY_ATTEMPTS, +// "Should have correct max attempts" +// ); +// }); + +// // Check that a retry message was added with provider name +// thread.read_with(cx, |thread, _| { +// let mut messages = thread.messages(); +// assert!( +// messages.any(|msg| { +// msg.role == Role::System +// && msg.ui_only +// && msg.segments.iter().any(|seg| { +// if let MessageSegment::Text(text) = seg { +// text.contains("internal") +// && text.contains("Fake") +// && text +// .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) +// } else { +// false +// } +// }) +// }), +// "Should have added a system retry message with provider name" +// ); +// }); + +// // Count retry messages +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// .count() +// }); + +// assert_eq!(retry_count, 1, "Should have one retry message"); +// } + +// #[gpui::test] +// async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // Create model that returns overloaded error +// let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Track retry events and completion count +// // Track completion events +// let completion_count = Arc::new(Mutex::new(0)); +// let completion_count_clone = completion_count.clone(); + +// let _subscription = thread.update(cx, |_, cx| { +// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { +// if let ThreadEvent::NewRequest = event { +// *completion_count_clone.lock() += 1; +// } +// }) +// }); + +// // First attempt +// thread.update(cx, |thread, cx| { +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); +// cx.run_until_parked(); + +// // Should have scheduled first retry - count retry messages +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// .count() +// }); +// assert_eq!(retry_count, 1, "Should have scheduled first retry"); + +// // Check retry state +// thread.read_with(cx, |thread, _| { +// assert!(thread.retry_state.is_some(), "Should have retry state"); +// let retry_state = thread.retry_state.as_ref().unwrap(); +// assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); +// }); + +// // Advance clock for first retry +// cx.executor() +// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); +// cx.run_until_parked(); + +// // Should have scheduled second retry - count retry messages +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// .count() +// }); +// assert_eq!(retry_count, 2, "Should have scheduled second retry"); + +// // Check retry state updated +// thread.read_with(cx, |thread, _| { +// assert!(thread.retry_state.is_some(), "Should have retry state"); +// let retry_state = thread.retry_state.as_ref().unwrap(); +// assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); +// assert_eq!( +// retry_state.max_attempts, MAX_RETRY_ATTEMPTS, +// "Should have correct max attempts" +// ); +// }); + +// // Advance clock for second retry (exponential backoff) +// cx.executor() +// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 2)); +// cx.run_until_parked(); + +// // Should have scheduled third retry +// // Count all retry messages now +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// .count() +// }); +// assert_eq!( +// retry_count, MAX_RETRY_ATTEMPTS as usize, +// "Should have scheduled third retry" +// ); + +// // Check retry state updated +// thread.read_with(cx, |thread, _| { +// assert!(thread.retry_state.is_some(), "Should have retry state"); +// let retry_state = thread.retry_state.as_ref().unwrap(); +// assert_eq!( +// retry_state.attempt, MAX_RETRY_ATTEMPTS, +// "Should be at max retry attempt" +// ); +// assert_eq!( +// retry_state.max_attempts, MAX_RETRY_ATTEMPTS, +// "Should have correct max attempts" +// ); +// }); + +// // Advance clock for third retry (exponential backoff) +// cx.executor() +// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS * 4)); +// cx.run_until_parked(); + +// // No more retries should be scheduled after clock was advanced. +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// .count() +// }); +// assert_eq!( +// retry_count, MAX_RETRY_ATTEMPTS as usize, +// "Should not exceed max retries" +// ); + +// // Final completion count should be initial + max retries +// assert_eq!( +// *completion_count.lock(), +// (MAX_RETRY_ATTEMPTS + 1) as usize, +// "Should have made initial + max retry attempts" +// ); +// } + +// #[gpui::test] +// async fn test_max_retries_exceeded(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // Create model that returns overloaded error +// let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Track events +// let retries_failed = Arc::new(Mutex::new(false)); +// let retries_failed_clone = retries_failed.clone(); + +// let _subscription = thread.update(cx, |_, cx| { +// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { +// if let ThreadEvent::RetriesFailed { .. } = event { +// *retries_failed_clone.lock() = true; +// } +// }) +// }); + +// // Start initial completion +// thread.update(cx, |thread, cx| { +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); +// cx.run_until_parked(); + +// // Advance through all retries +// for i in 0..MAX_RETRY_ATTEMPTS { +// let delay = if i == 0 { +// BASE_RETRY_DELAY_SECS +// } else { +// BASE_RETRY_DELAY_SECS * 2u64.pow(i as u32 - 1) +// }; +// cx.executor().advance_clock(Duration::from_secs(delay)); +// cx.run_until_parked(); +// } + +// // After the 3rd retry is scheduled, we need to wait for it to execute and fail +// // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) +// let final_delay = BASE_RETRY_DELAY_SECS * 2u64.pow((MAX_RETRY_ATTEMPTS - 1) as u32); +// cx.executor() +// .advance_clock(Duration::from_secs(final_delay)); +// cx.run_until_parked(); + +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// .count() +// }); + +// // After max retries, should emit RetriesFailed event +// assert_eq!( +// retry_count, MAX_RETRY_ATTEMPTS as usize, +// "Should have attempted max retries" +// ); +// assert!( +// *retries_failed.lock(), +// "Should emit RetriesFailed event after max retries exceeded" +// ); + +// // Retry state should be cleared +// thread.read_with(cx, |thread, _| { +// assert!( +// thread.retry_state.is_none(), +// "Retry state should be cleared after max retries" +// ); + +// // Verify we have the expected number of retry messages +// let retry_messages = thread +// .messages +// .iter() +// .filter(|msg| msg.ui_only && msg.role == Role::System) +// .count(); +// assert_eq!( +// retry_messages, MAX_RETRY_ATTEMPTS as usize, +// "Should have one retry message per attempt" +// ); +// }); +// } + +// #[gpui::test] +// async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // We'll use a wrapper to switch behavior after first failure +// struct RetryTestModel { +// inner: Arc, +// failed_once: Arc>, +// } + +// impl LanguageModel for RetryTestModel { +// fn id(&self) -> LanguageModelId { +// self.inner.id() +// } + +// fn name(&self) -> LanguageModelName { +// self.inner.name() +// } + +// fn provider_id(&self) -> LanguageModelProviderId { +// self.inner.provider_id() +// } + +// fn provider_name(&self) -> LanguageModelProviderName { +// self.inner.provider_name() +// } + +// fn supports_tools(&self) -> bool { +// self.inner.supports_tools() +// } + +// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { +// self.inner.supports_tool_choice(choice) +// } + +// fn supports_images(&self) -> bool { +// self.inner.supports_images() +// } + +// fn telemetry_id(&self) -> String { +// self.inner.telemetry_id() +// } + +// fn max_token_count(&self) -> u64 { +// self.inner.max_token_count() +// } + +// fn count_tokens( +// &self, +// request: LanguageModelRequest, +// cx: &App, +// ) -> BoxFuture<'static, Result> { +// self.inner.count_tokens(request, cx) +// } + +// fn stream_completion( +// &self, +// request: LanguageModelRequest, +// cx: &AsyncApp, +// ) -> BoxFuture< +// 'static, +// Result< +// BoxStream< +// 'static, +// Result, +// >, +// LanguageModelCompletionError, +// >, +// > { +// if !*self.failed_once.lock() { +// *self.failed_once.lock() = true; +// // Return error on first attempt +// let stream = futures::stream::once(async move { +// Err(LanguageModelCompletionError::Overloaded) +// }); +// async move { Ok(stream.boxed()) }.boxed() +// } else { +// // Succeed on retry +// self.inner.stream_completion(request, cx) +// } +// } + +// fn as_fake(&self) -> &FakeLanguageModel { +// &self.inner +// } +// } + +// let model = Arc::new(RetryTestModel { +// inner: Arc::new(FakeLanguageModel::default()), +// failed_once: Arc::new(Mutex::new(false)), +// }); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Track message deletions +// // Track when retry completes successfully +// let retry_completed = Arc::new(Mutex::new(false)); +// let retry_completed_clone = retry_completed.clone(); + +// let _subscription = thread.update(cx, |_, cx| { +// cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { +// if let ThreadEvent::StreamedCompletion = event { +// *retry_completed_clone.lock() = true; +// } +// }) +// }); + +// // Start completion +// thread.update(cx, |thread, cx| { +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); +// cx.run_until_parked(); + +// // Get the retry message ID +// let retry_message_id = thread.read_with(cx, |thread, _| { +// thread +// .messages() +// .find(|msg| msg.role == Role::System && msg.ui_only) +// .map(|msg| msg.id) +// .expect("Should have a retry message") +// }); + +// // Wait for retry +// cx.executor() +// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); +// cx.run_until_parked(); + +// // Stream some successful content +// let fake_model = model.as_fake(); +// // After the retry, there should be a new pending completion +// let pending = fake_model.pending_completions(); +// assert!( +// !pending.is_empty(), +// "Should have a pending completion after retry" +// ); +// fake_model.stream_completion_response(&pending[0], "Success!"); +// fake_model.end_completion_stream(&pending[0]); +// cx.run_until_parked(); + +// // Check that the retry completed successfully +// assert!( +// *retry_completed.lock(), +// "Retry should have completed successfully" +// ); + +// // Retry message should still exist but be marked as ui_only +// thread.read_with(cx, |thread, _| { +// let retry_msg = thread +// .message(retry_message_id) +// .expect("Retry message should still exist"); +// assert!(retry_msg.ui_only, "Retry message should be ui_only"); +// assert_eq!( +// retry_msg.role, +// Role::System, +// "Retry message should have System role" +// ); +// }); +// } + +// #[gpui::test] +// async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // Create a model that fails once then succeeds +// struct FailOnceModel { +// inner: Arc, +// failed_once: Arc>, +// } + +// impl LanguageModel for FailOnceModel { +// fn id(&self) -> LanguageModelId { +// self.inner.id() +// } + +// fn name(&self) -> LanguageModelName { +// self.inner.name() +// } + +// fn provider_id(&self) -> LanguageModelProviderId { +// self.inner.provider_id() +// } + +// fn provider_name(&self) -> LanguageModelProviderName { +// self.inner.provider_name() +// } + +// fn supports_tools(&self) -> bool { +// self.inner.supports_tools() +// } + +// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { +// self.inner.supports_tool_choice(choice) +// } + +// fn supports_images(&self) -> bool { +// self.inner.supports_images() +// } + +// fn telemetry_id(&self) -> String { +// self.inner.telemetry_id() +// } + +// fn max_token_count(&self) -> u64 { +// self.inner.max_token_count() +// } + +// fn count_tokens( +// &self, +// request: LanguageModelRequest, +// cx: &App, +// ) -> BoxFuture<'static, Result> { +// self.inner.count_tokens(request, cx) +// } + +// fn stream_completion( +// &self, +// request: LanguageModelRequest, +// cx: &AsyncApp, +// ) -> BoxFuture< +// 'static, +// Result< +// BoxStream< +// 'static, +// Result, +// >, +// LanguageModelCompletionError, +// >, +// > { +// if !*self.failed_once.lock() { +// *self.failed_once.lock() = true; +// // Return error on first attempt +// let stream = futures::stream::once(async move { +// Err(LanguageModelCompletionError::Overloaded) +// }); +// async move { Ok(stream.boxed()) }.boxed() +// } else { +// // Succeed on retry +// self.inner.stream_completion(request, cx) +// } +// } +// } + +// let fail_once_model = Arc::new(FailOnceModel { +// inner: Arc::new(FakeLanguageModel::default()), +// failed_once: Arc::new(Mutex::new(false)), +// }); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message( +// "Test message", +// ContextLoadResult::default(), +// None, +// vec![], +// cx, +// ); +// }); + +// // Start completion with fail-once model +// thread.update(cx, |thread, cx| { +// thread.send_to_model( +// fail_once_model.clone(), +// CompletionIntent::UserPrompt, +// None, +// cx, +// ); +// }); + +// cx.run_until_parked(); + +// // Verify retry state exists after first failure +// thread.read_with(cx, |thread, _| { +// assert!( +// thread.retry_state.is_some(), +// "Should have retry state after failure" +// ); +// }); + +// // Wait for retry delay +// cx.executor() +// .advance_clock(Duration::from_secs(BASE_RETRY_DELAY_SECS)); +// cx.run_until_parked(); + +// // The retry should now use our FailOnceModel which should succeed +// // We need to help the FakeLanguageModel complete the stream +// let inner_fake = fail_once_model.inner.clone(); + +// // Wait a bit for the retry to start +// cx.run_until_parked(); + +// // Check for pending completions and complete them +// if let Some(pending) = inner_fake.pending_completions().first() { +// inner_fake.stream_completion_response(pending, "Success!"); +// inner_fake.end_completion_stream(pending); +// } +// cx.run_until_parked(); + +// thread.read_with(cx, |thread, _| { +// assert!( +// thread.retry_state.is_none(), +// "Retry state should be cleared after successful completion" +// ); + +// let has_assistant_message = thread +// .messages +// .iter() +// .any(|msg| msg.role == Role::Assistant && !msg.ui_only); +// assert!( +// has_assistant_message, +// "Should have an assistant message after successful retry" +// ); +// }); +// } + +// #[gpui::test] +// async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // Create a model that returns rate limit error with retry_after +// struct RateLimitModel { +// inner: Arc, +// } + +// impl LanguageModel for RateLimitModel { +// fn id(&self) -> LanguageModelId { +// self.inner.id() +// } + +// fn name(&self) -> LanguageModelName { +// self.inner.name() +// } + +// fn provider_id(&self) -> LanguageModelProviderId { +// self.inner.provider_id() +// } + +// fn provider_name(&self) -> LanguageModelProviderName { +// self.inner.provider_name() +// } + +// fn supports_tools(&self) -> bool { +// self.inner.supports_tools() +// } + +// fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { +// self.inner.supports_tool_choice(choice) +// } + +// fn supports_images(&self) -> bool { +// self.inner.supports_images() +// } + +// fn telemetry_id(&self) -> String { +// self.inner.telemetry_id() +// } + +// fn max_token_count(&self) -> u64 { +// self.inner.max_token_count() +// } + +// fn count_tokens( +// &self, +// request: LanguageModelRequest, +// cx: &App, +// ) -> BoxFuture<'static, Result> { +// self.inner.count_tokens(request, cx) +// } + +// fn stream_completion( +// &self, +// _request: LanguageModelRequest, +// _cx: &AsyncApp, +// ) -> BoxFuture< +// 'static, +// Result< +// BoxStream< +// 'static, +// Result, +// >, +// LanguageModelCompletionError, +// >, +// > { +// async move { +// let stream = futures::stream::once(async move { +// Err(LanguageModelCompletionError::RateLimitExceeded { +// retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS), +// }) +// }); +// Ok(stream.boxed()) +// } +// .boxed() +// } + +// fn as_fake(&self) -> &FakeLanguageModel { +// &self.inner +// } +// } + +// let model = Arc::new(RateLimitModel { +// inner: Arc::new(FakeLanguageModel::default()), +// }); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Start completion +// thread.update(cx, |thread, cx| { +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); + +// cx.run_until_parked(); + +// let retry_count = thread.update(cx, |thread, _| { +// thread +// .messages +// .iter() +// .filter(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("rate limit exceeded") +// } else { +// false +// } +// }) +// }) +// .count() +// }); +// assert_eq!(retry_count, 1, "Should have scheduled one retry"); + +// thread.read_with(cx, |thread, _| { +// assert!( +// thread.retry_state.is_none(), +// "Rate limit errors should not set retry_state" +// ); +// }); + +// // Verify we have one retry message +// thread.read_with(cx, |thread, _| { +// let retry_messages = thread +// .messages +// .iter() +// .filter(|msg| { +// msg.ui_only +// && msg.segments.iter().any(|seg| { +// if let MessageSegment::Text(text) = seg { +// text.contains("rate limit exceeded") +// } else { +// false +// } +// }) +// }) +// .count(); +// assert_eq!( +// retry_messages, 1, +// "Should have one rate limit retry message" +// ); +// }); + +// // Check that retry message doesn't include attempt count +// thread.read_with(cx, |thread, _| { +// let retry_message = thread +// .messages +// .iter() +// .find(|msg| msg.role == Role::System && msg.ui_only) +// .expect("Should have a retry message"); + +// // Check that the message doesn't contain attempt count +// if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { +// assert!( +// !text.contains("attempt"), +// "Rate limit retry message should not contain attempt count" +// ); +// assert!( +// text.contains(&format!( +// "Retrying in {} seconds", +// TEST_RATE_LIMIT_RETRY_SECS +// )), +// "Rate limit retry message should contain retry delay" +// ); +// } +// }); +// } + +// #[gpui::test] +// async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await; + +// // Insert a regular user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Insert a UI-only message (like our retry notifications) +// thread.update(cx, |thread, cx| { +// let id = thread.next_message_id.post_inc(); +// thread.messages.push(Message { +// id, +// role: Role::System, +// segments: vec![MessageSegment::Text( +// "This is a UI-only message that should not be sent to the model".to_string(), +// )], +// loaded_context: LoadedContext::default(), +// creases: Vec::new(), +// is_hidden: true, +// ui_only: true, +// }); +// cx.emit(ThreadEvent::MessageAdded(id)); +// }); + +// // Insert another regular message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message( +// "How are you?", +// ContextLoadResult::default(), +// None, +// vec![], +// cx, +// ); +// }); + +// // Generate the completion request +// let request = thread.update(cx, |thread, cx| { +// thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) +// }); + +// // Verify that the request only contains non-UI-only messages +// // Should have system prompt + 2 user messages, but not the UI-only message +// let user_messages: Vec<_> = request +// .messages +// .iter() +// .filter(|msg| msg.role == Role::User) +// .collect(); +// assert_eq!( +// user_messages.len(), +// 2, +// "Should have exactly 2 user messages" +// ); + +// // Verify the UI-only content is not present anywhere in the request +// let request_text = request +// .messages +// .iter() +// .flat_map(|msg| &msg.content) +// .filter_map(|content| match content { +// MessageContent::Text(text) => Some(text.as_str()), +// _ => None, +// }) +// .collect::(); + +// assert!( +// !request_text.contains("UI-only message"), +// "UI-only message content should not be in the request" +// ); + +// // Verify the thread still has all 3 messages (including UI-only) +// thread.read_with(cx, |thread, _| { +// assert_eq!( +// thread.messages().count(), +// 3, +// "Thread should have 3 messages" +// ); +// assert_eq!( +// thread.messages().filter(|m| m.ui_only).count(), +// 1, +// "Thread should have 1 UI-only message" +// ); +// }); + +// // Verify that UI-only messages are not serialized +// let serialized = thread +// .update(cx, |thread, cx| thread.serialize(cx)) +// .await +// .unwrap(); +// assert_eq!( +// serialized.messages.len(), +// 2, +// "Serialized thread should only have 2 messages (no UI-only)" +// ); +// } + +// #[gpui::test] +// async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; +// let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; + +// // Create model that returns overloaded error +// let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); + +// // Insert a user message +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); +// }); + +// // Start completion +// thread.update(cx, |thread, cx| { +// thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); +// }); + +// cx.run_until_parked(); + +// // Verify retry was scheduled by checking for retry message +// let has_retry_message = thread.read_with(cx, |thread, _| { +// thread.messages.iter().any(|m| { +// m.ui_only +// && m.segments.iter().any(|s| { +// if let MessageSegment::Text(text) = s { +// text.contains("Retrying") && text.contains("seconds") +// } else { +// false +// } +// }) +// }) +// }); +// assert!(has_retry_message, "Should have scheduled a retry"); + +// // Cancel the completion before the retry happens +// thread.update(cx, |thread, cx| { +// thread.cancel_last_completion(None, cx); +// }); + +// cx.run_until_parked(); + +// // The retry should not have happened - no pending completions +// let fake_model = model.as_fake(); +// assert_eq!( +// fake_model.pending_completions().len(), +// 0, +// "Should have no pending completions after cancellation" +// ); + +// // Verify the retry was cancelled by checking retry state +// thread.read_with(cx, |thread, _| { +// if let Some(retry_state) = &thread.retry_state { +// panic!( +// "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}", +// retry_state.attempt, retry_state.max_attempts, retry_state.intent +// ); +// } +// }); +// } + +// fn test_summarize_error( +// model: &Arc, +// thread: &Entity, +// cx: &mut TestAppContext, +// ) { +// thread.update(cx, |thread, cx| { +// thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); +// thread.send_to_model( +// model.clone(), +// CompletionIntent::ThreadSummarization, +// None, +// cx, +// ); +// }); + +// let fake_model = model.as_fake(); +// simulate_successful_response(&fake_model, cx); + +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Generating)); +// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); +// }); + +// // Simulate summary request ending +// cx.run_until_parked(); +// fake_model.end_last_completion_stream(); +// cx.run_until_parked(); + +// // State is set to Error and default message +// thread.read_with(cx, |thread, _| { +// assert!(matches!(thread.summary(), ThreadSummary::Error)); +// assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); +// }); +// } + +// fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { +// cx.run_until_parked(); +// fake_model.stream_last_completion_response("Assistant response"); +// fake_model.end_last_completion_stream(); +// cx.run_until_parked(); +// } + +// fn init_test_settings(cx: &mut TestAppContext) { +// cx.update(|cx| { +// let settings_store = SettingsStore::test(cx); +// cx.set_global(settings_store); +// language::init(cx); +// Project::init_settings(cx); +// AgentSettings::register(cx); +// prompt_store::init(cx); +// thread_store::init(cx); +// workspace::init_settings(cx); +// language_model::init_settings(cx); +// ThemeSettings::register(cx); +// ToolRegistry::default_global(cx); +// }); +// } + +// // Helper to create a test project with test files +// async fn create_test_project( +// cx: &mut TestAppContext, +// files: serde_json::Value, +// ) -> Entity { +// let fs = FakeFs::new(cx.executor()); +// fs.insert_tree(path!("/test"), files).await; +// Project::test(fs, [path!("/test").as_ref()], cx).await +// } + +// async fn setup_test_environment( +// cx: &mut TestAppContext, +// project: Entity, +// ) -> ( +// Entity, +// Entity, +// Entity, +// Entity, +// Arc, +// ) { +// let (workspace, cx) = +// cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + +// let thread_store = cx +// .update(|_, cx| { +// ThreadStore::load( +// project.clone(), +// cx.new(|_| ToolWorkingSet::default()), +// None, +// Arc::new(PromptBuilder::new(None).unwrap()), +// cx, +// ) +// }) +// .await +// .unwrap(); + +// let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); +// let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); + +// let provider = Arc::new(FakeLanguageModelProvider); +// let model = provider.test_model(); +// let model: Arc = Arc::new(model); + +// cx.update(|_, cx| { +// LanguageModelRegistry::global(cx).update(cx, |registry, cx| { +// registry.set_default_model( +// Some(ConfiguredModel { +// provider: provider.clone(), +// model: model.clone(), +// }), +// cx, +// ); +// registry.set_thread_summary_model( +// Some(ConfiguredModel { +// provider, +// model: model.clone(), +// }), +// cx, +// ); +// }) +// }); + +// (workspace, thread_store, thread, context_store, model) +// } + +// async fn add_file_to_context( +// project: &Entity, +// context_store: &Entity, +// path: &str, +// cx: &mut TestAppContext, +// ) -> Result> { +// let buffer_path = project +// .read_with(cx, |project, cx| project.find_project_path(path, cx)) +// .unwrap(); + +// let buffer = project +// .update(cx, |project, cx| { +// project.open_buffer(buffer_path.clone(), cx) +// }) +// .await +// .unwrap(); + +// context_store.update(cx, |context_store, cx| { +// context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx); +// }); + +// Ok(buffer) +// } +// } diff --git a/crates/agent/src/thread2.rs b/crates/agent/src/thread2.rs deleted file mode 100644 index b76f02a59a468b5b92bbf9671739059029893329..0000000000000000000000000000000000000000 --- a/crates/agent/src/thread2.rs +++ /dev/null @@ -1,1449 +0,0 @@ -use crate::{ - AgentThread, AgentThreadId, AgentThreadMessageId, AgentThreadUserMessageChunk, - agent_profile::AgentProfile, - context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext}, - thread_store::{SharedProjectContext, ThreadStore}, -}; -use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; -use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; -use chrono::{DateTime, Utc}; -use client::{ModelRequestUsage, RequestUsage}; -use collections::{HashMap, HashSet}; -use feature_flags::{self, FeatureFlagAppExt}; -use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared}; -use git::repository::DiffType; -use gpui::{ - AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, - WeakEntity, -}; -use language_model::{ - ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage, -}; -use postage::stream::Stream as _; -use project::{ - Project, - git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, -}; -use prompt_store::{ModelContext, PromptBuilder}; -use proto::Plan; -use serde::{Deserialize, Serialize}; -use settings::Settings; -use std::{ - io::Write, - ops::Range, - sync::Arc, - time::{Duration, Instant}, -}; -use thiserror::Error; -use util::{ResultExt as _, post_inc}; -use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; - -/// Stored information that can be used to resurrect a context crease when creating an editor for a past message. -#[derive(Clone, Debug)] -pub struct MessageCrease { - pub range: Range, - pub icon_path: SharedString, - pub label: SharedString, - /// None for a deserialized message, Some otherwise. - pub context: Option, -} - -pub enum MessageTool { - Pending { - tool: Arc, - input: serde_json::Value, - }, - NeedsConfirmation { - tool: Arc, - input_json: serde_json::Value, - confirm_tx: oneshot::Sender, - }, - Confirmed { - card: AnyToolCard, - }, - Declined { - tool: Arc, - input_json: serde_json::Value, - }, -} - -/// A message in a [`Thread`]. -pub struct Message { - pub id: AgentThreadMessageId, - pub role: Role, - pub thinking: String, - pub text: String, - pub tools: Vec, - pub loaded_context: LoadedContext, - pub creases: Vec, - pub is_hidden: bool, - pub ui_only: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ProjectSnapshot { - pub worktree_snapshots: Vec, - pub unsaved_buffer_paths: Vec, - pub timestamp: DateTime, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct WorktreeSnapshot { - pub worktree_path: String, - pub git_state: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct GitState { - pub remote_url: Option, - pub head_sha: Option, - pub current_branch: Option, - pub diff: Option, -} - -#[derive(Clone, Debug)] -pub struct ThreadCheckpoint { - message_id: AgentThreadMessageId, - git_checkpoint: GitStoreCheckpoint, -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ThreadFeedback { - Positive, - Negative, -} - -pub enum LastRestoreCheckpoint { - Pending { - message_id: AgentThreadMessageId, - }, - Error { - message_id: AgentThreadMessageId, - error: String, - }, -} - -impl LastRestoreCheckpoint { - pub fn message_id(&self) -> AgentThreadMessageId { - match self { - LastRestoreCheckpoint::Pending { message_id } => *message_id, - LastRestoreCheckpoint::Error { message_id, .. } => *message_id, - } - } -} - -#[derive(Clone, Debug, Default)] -pub enum DetailedSummaryState { - #[default] - NotGenerated, - Generating { - message_id: AgentThreadMessageId, - }, - Generated { - text: SharedString, - message_id: AgentThreadMessageId, - }, -} - -impl DetailedSummaryState { - fn text(&self) -> Option { - if let Self::Generated { text, .. } = self { - Some(text.clone()) - } else { - None - } - } -} - -#[derive(Default, Debug)] -pub struct TotalTokenUsage { - pub total: u64, - pub max: u64, -} - -impl TotalTokenUsage { - pub fn ratio(&self) -> TokenUsageRatio { - #[cfg(debug_assertions)] - let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") - .unwrap_or("0.8".to_string()) - .parse() - .unwrap(); - #[cfg(not(debug_assertions))] - let warning_threshold: f32 = 0.8; - - // When the maximum is unknown because there is no selected model, - // avoid showing the token limit warning. - if self.max == 0 { - TokenUsageRatio::Normal - } else if self.total >= self.max { - TokenUsageRatio::Exceeded - } else if self.total as f32 / self.max as f32 >= warning_threshold { - TokenUsageRatio::Warning - } else { - TokenUsageRatio::Normal - } - } - - pub fn add(&self, tokens: u64) -> TotalTokenUsage { - TotalTokenUsage { - total: self.total + tokens, - max: self.max, - } - } -} - -#[derive(Debug, Default, PartialEq, Eq)] -pub enum TokenUsageRatio { - #[default] - Normal, - Warning, - Exceeded, -} - -#[derive(Debug, Clone, Copy)] -pub enum QueueState { - Sending, - Queued { position: usize }, - Started, -} - -/// A thread of conversation with the LLM. -pub struct Thread { - agent_thread: Arc, - summary: ThreadSummary, - pending_send: Option>>, - pending_summary: Task>, - detailed_summary_task: Task>, - detailed_summary_tx: postage::watch::Sender, - detailed_summary_rx: postage::watch::Receiver, - completion_mode: agent_settings::CompletionMode, - messages: Vec, - checkpoints_by_message: HashMap, - project: Entity, - action_log: Entity, - last_restore_checkpoint: Option, - pending_checkpoint: Option, - initial_project_snapshot: Shared>>>, - request_token_usage: Vec, - cumulative_token_usage: TokenUsage, - exceeded_window_error: Option, - tool_use_limit_reached: bool, - // todo!(keep track of retries from the underlying agent) - feedback: Option, - message_feedback: HashMap, - last_auto_capture_at: Option, - last_received_chunk_at: Option, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum ThreadSummary { - Pending, - Generating, - Ready(SharedString), - Error, -} - -impl ThreadSummary { - pub const DEFAULT: SharedString = SharedString::new_static("New Thread"); - - pub fn or_default(&self) -> SharedString { - self.unwrap_or(Self::DEFAULT) - } - - pub fn unwrap_or(&self, message: impl Into) -> SharedString { - self.ready().unwrap_or_else(|| message.into()) - } - - pub fn ready(&self) -> Option { - match self { - ThreadSummary::Ready(summary) => Some(summary.clone()), - ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ExceededWindowError { - /// Model used when last message exceeded context window - model_id: LanguageModelId, - /// Token count including last message - token_count: u64, -} - -impl Thread { - pub fn load( - agent_thread: Arc, - project: Entity, - cx: &mut Context, - ) -> Self { - let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); - Self { - agent_thread, - summary: ThreadSummary::Pending, - pending_send: None, - pending_summary: Task::ready(None), - detailed_summary_task: Task::ready(None), - detailed_summary_tx, - detailed_summary_rx, - completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, - messages: todo!("read from agent"), - checkpoints_by_message: HashMap::default(), - project: project.clone(), - last_restore_checkpoint: None, - pending_checkpoint: None, - action_log: cx.new(|_| ActionLog::new(project.clone())), - initial_project_snapshot: { - let project_snapshot = Self::project_snapshot(project, cx); - cx.foreground_executor() - .spawn(async move { Some(project_snapshot.await) }) - .shared() - }, - request_token_usage: Vec::new(), - cumulative_token_usage: TokenUsage::default(), - exceeded_window_error: None, - tool_use_limit_reached: false, - feedback: None, - message_feedback: HashMap::default(), - last_auto_capture_at: None, - last_received_chunk_at: None, - } - } - - pub fn id(&self) -> AgentThreadId { - self.agent_thread.id() - } - - pub fn profile(&self) -> &AgentProfile { - todo!() - } - - pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context) { - todo!() - // if &id != self.profile.id() { - // self.profile = AgentProfile::new(id, self.tools.clone()); - // cx.emit(ThreadEvent::ProfileChanged); - // } - } - - pub fn is_empty(&self) -> bool { - self.messages.is_empty() - } - - pub fn advance_prompt_id(&mut self) { - todo!() - // self.last_prompt_id = PromptId::new(); - } - - pub fn project_context(&self) -> SharedProjectContext { - todo!() - // self.project_context.clone() - } - - pub fn summary(&self) -> &ThreadSummary { - &self.summary - } - - pub fn set_summary(&mut self, new_summary: impl Into, cx: &mut Context) { - todo!() - // let current_summary = match &self.summary { - // ThreadSummary::Pending | ThreadSummary::Generating => return, - // ThreadSummary::Ready(summary) => summary, - // ThreadSummary::Error => &ThreadSummary::DEFAULT, - // }; - - // let mut new_summary = new_summary.into(); - - // if new_summary.is_empty() { - // new_summary = ThreadSummary::DEFAULT; - // } - - // if current_summary != &new_summary { - // self.summary = ThreadSummary::Ready(new_summary); - // cx.emit(ThreadEvent::SummaryChanged); - // } - } - - pub fn completion_mode(&self) -> CompletionMode { - self.completion_mode - } - - pub fn set_completion_mode(&mut self, mode: CompletionMode) { - self.completion_mode = mode; - } - - pub fn message(&self, id: AgentThreadMessageId) -> Option<&Message> { - let index = self - .messages - .binary_search_by(|message| message.id.cmp(&id)) - .ok()?; - - self.messages.get(index) - } - - pub fn messages(&self) -> impl ExactSizeIterator { - self.messages.iter() - } - - pub fn is_generating(&self) -> bool { - self.pending_send.is_some() - } - - /// Indicates whether streaming of language model events is stale. - /// When `is_generating()` is false, this method returns `None`. - pub fn is_generation_stale(&self) -> Option { - const STALE_THRESHOLD: u128 = 250; - - self.last_received_chunk_at - .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD) - } - - fn received_chunk(&mut self) { - self.last_received_chunk_at = Some(Instant::now()); - } - - pub fn checkpoint_for_message(&self, id: AgentThreadMessageId) -> Option { - self.checkpoints_by_message.get(&id).cloned() - } - - pub fn restore_checkpoint( - &mut self, - checkpoint: ThreadCheckpoint, - cx: &mut Context, - ) -> Task> { - self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending { - message_id: checkpoint.message_id, - }); - cx.emit(ThreadEvent::CheckpointChanged); - cx.notify(); - - let git_store = self.project().read(cx).git_store().clone(); - let restore = git_store.update(cx, |git_store, cx| { - git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx) - }); - - cx.spawn(async move |this, cx| { - let result = restore.await; - this.update(cx, |this, cx| { - if let Err(err) = result.as_ref() { - this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error { - message_id: checkpoint.message_id, - error: err.to_string(), - }); - } else { - this.truncate(checkpoint.message_id, cx); - this.last_restore_checkpoint = None; - } - this.pending_checkpoint = None; - cx.emit(ThreadEvent::CheckpointChanged); - cx.notify(); - })?; - result - }) - } - - fn finalize_pending_checkpoint(&mut self, cx: &mut Context) { - let pending_checkpoint = if self.is_generating() { - return; - } else if let Some(checkpoint) = self.pending_checkpoint.take() { - checkpoint - } else { - return; - }; - - self.finalize_checkpoint(pending_checkpoint, cx); - } - - fn finalize_checkpoint( - &mut self, - pending_checkpoint: ThreadCheckpoint, - cx: &mut Context, - ) { - let git_store = self.project.read(cx).git_store().clone(); - let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx)); - cx.spawn(async move |this, cx| match final_checkpoint.await { - Ok(final_checkpoint) => { - let equal = git_store - .update(cx, |store, cx| { - store.compare_checkpoints( - pending_checkpoint.git_checkpoint.clone(), - final_checkpoint.clone(), - cx, - ) - })? - .await - .unwrap_or(false); - - if !equal { - this.update(cx, |this, cx| { - this.insert_checkpoint(pending_checkpoint, cx) - })?; - } - - Ok(()) - } - Err(_) => this.update(cx, |this, cx| { - this.insert_checkpoint(pending_checkpoint, cx) - }), - }) - .detach(); - } - - fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context) { - self.checkpoints_by_message - .insert(checkpoint.message_id, checkpoint); - cx.emit(ThreadEvent::CheckpointChanged); - cx.notify(); - } - - pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> { - self.last_restore_checkpoint.as_ref() - } - - pub fn truncate(&mut self, message_id: AgentThreadMessageId, cx: &mut Context) { - todo!("call truncate on the agent"); - let Some(message_ix) = self - .messages - .iter() - .rposition(|message| message.id == message_id) - else { - return; - }; - for deleted_message in self.messages.drain(message_ix..) { - self.checkpoints_by_message.remove(&deleted_message.id); - } - cx.notify(); - } - - pub fn is_turn_end(&self, ix: usize) -> bool { - todo!() - // if self.messages.is_empty() { - // return false; - // } - - // if !self.is_generating() && ix == self.messages.len() - 1 { - // return true; - // } - - // let Some(message) = self.messages.get(ix) else { - // return false; - // }; - - // if message.role != Role::Assistant { - // return false; - // } - - // self.messages - // .get(ix + 1) - // .and_then(|message| { - // self.message(message.id) - // .map(|next_message| next_message.role == Role::User && !next_message.is_hidden) - // }) - // .unwrap_or(false) - } - - pub fn tool_use_limit_reached(&self) -> bool { - self.tool_use_limit_reached - } - - /// Returns whether any pending tool uses may perform edits - pub fn has_pending_edit_tool_uses(&self) -> bool { - todo!() - } - - // pub fn insert_user_message( - // &mut self, - // text: impl Into, - // loaded_context: ContextLoadResult, - // git_checkpoint: Option, - // creases: Vec, - // cx: &mut Context, - // ) -> AgentThreadMessageId { - // todo!("move this logic into send") - // if !loaded_context.referenced_buffers.is_empty() { - // self.action_log.update(cx, |log, cx| { - // for buffer in loaded_context.referenced_buffers { - // log.buffer_read(buffer, cx); - // } - // }); - // } - - // let message_id = self.insert_message( - // Role::User, - // vec![MessageSegment::Text(text.into())], - // loaded_context.loaded_context, - // creases, - // false, - // cx, - // ); - - // if let Some(git_checkpoint) = git_checkpoint { - // self.pending_checkpoint = Some(ThreadCheckpoint { - // message_id, - // git_checkpoint, - // }); - // } - - // self.auto_capture_telemetry(cx); - - // message_id - // } - - pub fn send(&mut self, message: Vec, cx: &mut Context) {} - - pub fn resume(&mut self, cx: &mut Context) { - todo!() - } - - pub fn edit( - &mut self, - message_id: AgentThreadMessageId, - message: Vec, - cx: &mut Context, - ) { - todo!() - } - - pub fn cancel(&mut self, cx: &mut Context) { - todo!() - } - - // pub fn insert_invisible_continue_message( - // &mut self, - // cx: &mut Context, - // ) -> AgentThreadMessageId { - // let id = self.insert_message( - // Role::User, - // vec![MessageSegment::Text("Continue where you left off".into())], - // LoadedContext::default(), - // vec![], - // true, - // cx, - // ); - // self.pending_checkpoint = None; - - // id - // } - - // pub fn insert_assistant_message( - // &mut self, - // segments: Vec, - // cx: &mut Context, - // ) -> AgentThreadMessageId { - // self.insert_message( - // Role::Assistant, - // segments, - // LoadedContext::default(), - // Vec::new(), - // false, - // cx, - // ) - // } - - // pub fn insert_message( - // &mut self, - // role: Role, - // segments: Vec, - // loaded_context: LoadedContext, - // creases: Vec, - // is_hidden: bool, - // cx: &mut Context, - // ) -> AgentThreadMessageId { - // let id = self.next_message_id.post_inc(); - // self.messages.push(Message { - // id, - // role, - // segments, - // loaded_context, - // creases, - // is_hidden, - // ui_only: false, - // }); - // self.touch_updated_at(); - // cx.emit(ThreadEvent::MessageAdded(id)); - // id - // } - - // pub fn edit_message( - // &mut self, - // id: AgentThreadMessageId, - // new_role: Role, - // new_segments: Vec, - // creases: Vec, - // loaded_context: Option, - // checkpoint: Option, - // cx: &mut Context, - // ) -> bool { - // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { - // return false; - // }; - // message.role = new_role; - // message.segments = new_segments; - // message.creases = creases; - // if let Some(context) = loaded_context { - // message.loaded_context = context; - // } - // if let Some(git_checkpoint) = checkpoint { - // self.checkpoints_by_message.insert( - // id, - // ThreadCheckpoint { - // message_id: id, - // git_checkpoint, - // }, - // ); - // } - // self.touch_updated_at(); - // cx.emit(ThreadEvent::MessageEdited(id)); - // true - // } - - /// Returns the representation of this [`Thread`] in a textual form. - /// - /// This is the representation we use when attaching a thread as context to another thread. - pub fn text(&self) -> String { - let mut text = String::new(); - - for message in &self.messages { - text.push_str(match message.role { - language_model::Role::User => "User:", - language_model::Role::Assistant => "Agent:", - language_model::Role::System => "System:", - }); - text.push('\n'); - - text.push_str(""); - text.push_str(&message.thinking); - text.push_str(""); - text.push_str(&message.text); - - // todo!('what about tools?'); - - text.push('\n'); - } - - text - } - - pub fn used_tools_since_last_user_message(&self) -> bool { - todo!() - // for message in self.messages.iter().rev() { - // if self.tool_use.message_has_tool_results(message.id) { - // return true; - // } else if message.role == Role::User { - // return false; - // } - // } - - // false - } - - pub fn start_generating_detailed_summary_if_needed( - &mut self, - thread_store: WeakEntity, - cx: &mut Context, - ) { - let Some(last_message_id) = self.messages.last().map(|message| message.id) else { - return; - }; - - match &*self.detailed_summary_rx.borrow() { - DetailedSummaryState::Generating { message_id, .. } - | DetailedSummaryState::Generated { message_id, .. } - if *message_id == last_message_id => - { - // Already up-to-date - return; - } - _ => {} - } - - let summary = self.agent_thread.summary(); - - *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating { - message_id: last_message_id, - }; - - // Replace the detailed summarization task if there is one, cancelling it. It would probably - // be better to allow the old task to complete, but this would require logic for choosing - // which result to prefer (the old task could complete after the new one, resulting in a - // stale summary). - self.detailed_summary_task = cx.spawn(async move |thread, cx| { - let Some(summary) = summary.await.log_err() else { - thread - .update(cx, |thread, _cx| { - *thread.detailed_summary_tx.borrow_mut() = - DetailedSummaryState::NotGenerated; - }) - .ok()?; - return None; - }; - - thread - .update(cx, |thread, _cx| { - *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated { - text: summary.into(), - message_id: last_message_id, - }; - }) - .ok()?; - - Some(()) - }); - } - - pub async fn wait_for_detailed_summary_or_text( - this: &Entity, - cx: &mut AsyncApp, - ) -> Option { - let mut detailed_summary_rx = this - .read_with(cx, |this, _cx| this.detailed_summary_rx.clone()) - .ok()?; - loop { - match detailed_summary_rx.recv().await? { - DetailedSummaryState::Generating { .. } => {} - DetailedSummaryState::NotGenerated => { - return this.read_with(cx, |this, _cx| this.text().into()).ok(); - } - DetailedSummaryState::Generated { text, .. } => return Some(text), - } - } - } - - pub fn latest_detailed_summary_or_text(&self) -> SharedString { - self.detailed_summary_rx - .borrow() - .text() - .unwrap_or_else(|| self.text().into()) - } - - pub fn is_generating_detailed_summary(&self) -> bool { - matches!( - &*self.detailed_summary_rx.borrow(), - DetailedSummaryState::Generating { .. } - ) - } - - pub fn feedback(&self) -> Option { - self.feedback - } - - pub fn message_feedback(&self, message_id: AgentThreadMessageId) -> Option { - self.message_feedback.get(&message_id).copied() - } - - pub fn report_message_feedback( - &mut self, - message_id: AgentThreadMessageId, - feedback: ThreadFeedback, - cx: &mut Context, - ) -> Task> { - todo!() - // if self.message_feedback.get(&message_id) == Some(&feedback) { - // return Task::ready(Ok(())); - // } - - // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); - // let serialized_thread = self.serialize(cx); - // let thread_id = self.id().clone(); - // let client = self.project.read(cx).client(); - - // let enabled_tool_names: Vec = self - // .profile - // .enabled_tools(cx) - // .iter() - // .map(|tool| tool.name()) - // .collect(); - - // self.message_feedback.insert(message_id, feedback); - - // cx.notify(); - - // let message_content = self - // .message(message_id) - // .map(|msg| msg.to_string()) - // .unwrap_or_default(); - - // cx.background_spawn(async move { - // let final_project_snapshot = final_project_snapshot.await; - // let serialized_thread = serialized_thread.await?; - // let thread_data = - // serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null); - - // let rating = match feedback { - // ThreadFeedback::Positive => "positive", - // ThreadFeedback::Negative => "negative", - // }; - // telemetry::event!( - // "Assistant Thread Rated", - // rating, - // thread_id, - // enabled_tool_names, - // message_id = message_id, - // message_content, - // thread_data, - // final_project_snapshot - // ); - // client.telemetry().flush_events().await; - - // Ok(()) - // }) - } - - pub fn report_feedback( - &mut self, - feedback: ThreadFeedback, - cx: &mut Context, - ) -> Task> { - todo!() - // let last_assistant_message_id = self - // .messages - // .iter() - // .rev() - // .find(|msg| msg.role == Role::Assistant) - // .map(|msg| msg.id); - - // if let Some(message_id) = last_assistant_message_id { - // self.report_message_feedback(message_id, feedback, cx) - // } else { - // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); - // let serialized_thread = self.serialize(cx); - // let thread_id = self.id().clone(); - // let client = self.project.read(cx).client(); - // self.feedback = Some(feedback); - // cx.notify(); - - // cx.background_spawn(async move { - // let final_project_snapshot = final_project_snapshot.await; - // let serialized_thread = serialized_thread.await?; - // let thread_data = serde_json::to_value(serialized_thread) - // .unwrap_or_else(|_| serde_json::Value::Null); - - // let rating = match feedback { - // ThreadFeedback::Positive => "positive", - // ThreadFeedback::Negative => "negative", - // }; - // telemetry::event!( - // "Assistant Thread Rated", - // rating, - // thread_id, - // thread_data, - // final_project_snapshot - // ); - // client.telemetry().flush_events().await; - - // Ok(()) - // }) - // } - } - - /// Create a snapshot of the current project state including git information and unsaved buffers. - fn project_snapshot( - project: Entity, - cx: &mut Context, - ) -> Task> { - let git_store = project.read(cx).git_store().clone(); - let worktree_snapshots: Vec<_> = project - .read(cx) - .visible_worktrees(cx) - .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) - .collect(); - - cx.spawn(async move |_, cx| { - let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; - - let mut unsaved_buffers = Vec::new(); - cx.update(|app_cx| { - let buffer_store = project.read(app_cx).buffer_store(); - for buffer_handle in buffer_store.read(app_cx).buffers() { - let buffer = buffer_handle.read(app_cx); - if buffer.is_dirty() { - if let Some(file) = buffer.file() { - let path = file.path().to_string_lossy().to_string(); - unsaved_buffers.push(path); - } - } - } - }) - .ok(); - - Arc::new(ProjectSnapshot { - worktree_snapshots, - unsaved_buffer_paths: unsaved_buffers, - timestamp: Utc::now(), - }) - }) - } - - fn worktree_snapshot( - worktree: Entity, - git_store: Entity, - cx: &App, - ) -> Task { - cx.spawn(async move |cx| { - // Get worktree path and snapshot - let worktree_info = cx.update(|app_cx| { - let worktree = worktree.read(app_cx); - let path = worktree.abs_path().to_string_lossy().to_string(); - let snapshot = worktree.snapshot(); - (path, snapshot) - }); - - let Ok((worktree_path, _snapshot)) = worktree_info else { - return WorktreeSnapshot { - worktree_path: String::new(), - git_state: None, - }; - }; - - let git_state = git_store - .update(cx, |git_store, cx| { - git_store - .repositories() - .values() - .find(|repo| { - repo.read(cx) - .abs_path_to_repo_path(&worktree.read(cx).abs_path()) - .is_some() - }) - .cloned() - }) - .ok() - .flatten() - .map(|repo| { - repo.update(cx, |repo, _| { - let current_branch = - repo.branch.as_ref().map(|branch| branch.name().to_owned()); - repo.send_job(None, |state, _| async move { - let RepositoryState::Local { backend, .. } = state else { - return GitState { - remote_url: None, - head_sha: None, - current_branch, - diff: None, - }; - }; - - let remote_url = backend.remote_url("origin"); - let head_sha = backend.head_sha().await; - let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); - - GitState { - remote_url, - head_sha, - current_branch, - diff, - } - }) - }) - }); - - let git_state = match git_state { - Some(git_state) => match git_state.ok() { - Some(git_state) => git_state.await.ok(), - None => None, - }, - None => None, - }; - - WorktreeSnapshot { - worktree_path, - git_state, - } - }) - } - - pub fn to_markdown(&self, cx: &App) -> Result { - todo!() - // let mut markdown = Vec::new(); - - // let summary = self.summary().or_default(); - // writeln!(markdown, "# {summary}\n")?; - - // for message in self.messages() { - // writeln!( - // markdown, - // "## {role}\n", - // role = match message.role { - // Role::User => "User", - // Role::Assistant => "Agent", - // Role::System => "System", - // } - // )?; - - // if !message.loaded_context.text.is_empty() { - // writeln!(markdown, "{}", message.loaded_context.text)?; - // } - - // if !message.loaded_context.images.is_empty() { - // writeln!( - // markdown, - // "\n{} images attached as context.\n", - // message.loaded_context.images.len() - // )?; - // } - - // for segment in &message.segments { - // match segment { - // MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, - // MessageSegment::Thinking { text, .. } => { - // writeln!(markdown, "\n{}\n\n", text)? - // } - // MessageSegment::RedactedThinking(_) => {} - // } - // } - - // for tool_use in self.tool_uses_for_message(message.id, cx) { - // writeln!( - // markdown, - // "**Use Tool: {} ({})**", - // tool_use.name, tool_use.id - // )?; - // writeln!(markdown, "```json")?; - // writeln!( - // markdown, - // "{}", - // serde_json::to_string_pretty(&tool_use.input)? - // )?; - // writeln!(markdown, "```")?; - // } - - // for tool_result in self.tool_results_for_message(message.id) { - // write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?; - // if tool_result.is_error { - // write!(markdown, " (Error)")?; - // } - - // writeln!(markdown, "**\n")?; - // match &tool_result.content { - // LanguageModelToolResultContent::Text(text) => { - // writeln!(markdown, "{text}")?; - // } - // LanguageModelToolResultContent::Image(image) => { - // writeln!(markdown, "![Image](data:base64,{})", image.source)?; - // } - // } - - // if let Some(output) = tool_result.output.as_ref() { - // writeln!( - // markdown, - // "\n\nDebug Output:\n\n```json\n{}\n```\n", - // serde_json::to_string_pretty(output)? - // )?; - // } - // } - // } - - // Ok(String::from_utf8_lossy(&markdown).to_string()) - } - - pub fn keep_edits_in_range( - &mut self, - buffer: Entity, - buffer_range: Range, - cx: &mut Context, - ) { - self.action_log.update(cx, |action_log, cx| { - action_log.keep_edits_in_range(buffer, buffer_range, cx) - }); - } - - pub fn keep_all_edits(&mut self, cx: &mut Context) { - self.action_log - .update(cx, |action_log, cx| action_log.keep_all_edits(cx)); - } - - pub fn reject_edits_in_ranges( - &mut self, - buffer: Entity, - buffer_ranges: Vec>, - cx: &mut Context, - ) -> Task> { - self.action_log.update(cx, |action_log, cx| { - action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx) - }) - } - - pub fn action_log(&self) -> &Entity { - &self.action_log - } - - pub fn project(&self) -> &Entity { - &self.project - } - - pub fn auto_capture_telemetry(&mut self, cx: &mut Context) { - todo!() - // if !cx.has_flag::() { - // return; - // } - - // let now = Instant::now(); - // if let Some(last) = self.last_auto_capture_at { - // if now.duration_since(last).as_secs() < 10 { - // return; - // } - // } - - // self.last_auto_capture_at = Some(now); - - // let thread_id = self.id().clone(); - // let github_login = self - // .project - // .read(cx) - // .user_store() - // .read(cx) - // .current_user() - // .map(|user| user.github_login.clone()); - // let client = self.project.read(cx).client(); - // let serialize_task = self.serialize(cx); - - // cx.background_executor() - // .spawn(async move { - // if let Ok(serialized_thread) = serialize_task.await { - // if let Ok(thread_data) = serde_json::to_value(serialized_thread) { - // telemetry::event!( - // "Agent Thread Auto-Captured", - // thread_id = thread_id.to_string(), - // thread_data = thread_data, - // auto_capture_reason = "tracked_user", - // github_login = github_login - // ); - - // client.telemetry().flush_events().await; - // } - // } - // }) - // .detach(); - } - - pub fn cumulative_token_usage(&self) -> TokenUsage { - self.cumulative_token_usage - } - - pub fn token_usage_up_to_message(&self, message_id: AgentThreadMessageId) -> TotalTokenUsage { - todo!() - // let Some(model) = self.configured_model.as_ref() else { - // return TotalTokenUsage::default(); - // }; - - // let max = model.model.max_token_count(); - - // let index = self - // .messages - // .iter() - // .position(|msg| msg.id == message_id) - // .unwrap_or(0); - - // if index == 0 { - // return TotalTokenUsage { total: 0, max }; - // } - - // let token_usage = &self - // .request_token_usage - // .get(index - 1) - // .cloned() - // .unwrap_or_default(); - - // TotalTokenUsage { - // total: token_usage.total_tokens(), - // max, - // } - } - - pub fn total_token_usage(&self) -> Option { - todo!() - // let model = self.configured_model.as_ref()?; - - // let max = model.model.max_token_count(); - - // if let Some(exceeded_error) = &self.exceeded_window_error { - // if model.model.id() == exceeded_error.model_id { - // return Some(TotalTokenUsage { - // total: exceeded_error.token_count, - // max, - // }); - // } - // } - - // let total = self - // .token_usage_at_last_message() - // .unwrap_or_default() - // .total_tokens(); - - // Some(TotalTokenUsage { total, max }) - } - - fn token_usage_at_last_message(&self) -> Option { - self.request_token_usage - .get(self.messages.len().saturating_sub(1)) - .or_else(|| self.request_token_usage.last()) - .cloned() - } - - fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) { - let placeholder = self.token_usage_at_last_message().unwrap_or_default(); - self.request_token_usage - .resize(self.messages.len(), placeholder); - - if let Some(last) = self.request_token_usage.last_mut() { - *last = token_usage; - } - } - - fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project.update(cx, |project, cx| { - project.user_store().update(cx, |user_store, cx| { - user_store.update_model_request_usage( - ModelRequestUsage(RequestUsage { - amount: amount as i32, - limit, - }), - cx, - ) - }) - }); - } -} - -#[derive(Debug, Clone, Error)] -pub enum ThreadError { - #[error("Payment required")] - PaymentRequired, - #[error("Model request limit reached")] - ModelRequestLimitReached { plan: Plan }, - #[error("Message {header}: {message}")] - Message { - header: SharedString, - message: SharedString, - }, -} - -#[derive(Debug, Clone)] -pub enum ThreadEvent { - ShowError(ThreadError), - StreamedCompletion, - ReceivedTextChunk, - NewRequest, - StreamedAssistantText(AgentThreadMessageId, String), - StreamedAssistantThinking(AgentThreadMessageId, String), - StreamedToolUse { - tool_use_id: LanguageModelToolUseId, - ui_text: Arc, - input: serde_json::Value, - }, - MissingToolUse { - tool_use_id: LanguageModelToolUseId, - ui_text: Arc, - }, - InvalidToolInput { - tool_use_id: LanguageModelToolUseId, - ui_text: Arc, - invalid_input_json: Arc, - }, - Stopped(Result>), - MessageAdded(AgentThreadMessageId), - MessageEdited(AgentThreadMessageId), - MessageDeleted(AgentThreadMessageId), - SummaryGenerated, - SummaryChanged, - CheckpointChanged, - ToolConfirmationNeeded, - ToolUseLimitReached, - CancelEditing, - CompletionCanceled, - ProfileChanged, - RetriesFailed { - message: SharedString, - }, -} - -impl EventEmitter for Thread {} - -struct PendingCompletion { - id: usize, - queue_state: QueueState, - _task: Task<()>, -} - -/// Resolves tool name conflicts by ensuring all tool names are unique. -/// -/// When multiple tools have the same name, this function applies the following rules: -/// 1. Native tools always keep their original name -/// 2. Context server tools get prefixed with their server ID and an underscore -/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters) -/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out -/// -/// Note: This function assumes that built-in tools occur before MCP tools in the tools list. -fn resolve_tool_name_conflicts(tools: &[Arc]) -> Vec<(String, Arc)> { - fn resolve_tool_name(tool: &Arc) -> String { - let mut tool_name = tool.name(); - tool_name.truncate(MAX_TOOL_NAME_LENGTH); - tool_name - } - - const MAX_TOOL_NAME_LENGTH: usize = 64; - - let mut duplicated_tool_names = HashSet::default(); - let mut seen_tool_names = HashSet::default(); - for tool in tools { - let tool_name = resolve_tool_name(tool); - if seen_tool_names.contains(&tool_name) { - debug_assert!( - tool.source() != assistant_tool::ToolSource::Native, - "There are two built-in tools with the same name: {}", - tool_name - ); - duplicated_tool_names.insert(tool_name); - } else { - seen_tool_names.insert(tool_name); - } - } - - if duplicated_tool_names.is_empty() { - return tools - .into_iter() - .map(|tool| (resolve_tool_name(tool), tool.clone())) - .collect(); - } - - tools - .into_iter() - .filter_map(|tool| { - let mut tool_name = resolve_tool_name(tool); - if !duplicated_tool_names.contains(&tool_name) { - return Some((tool_name, tool.clone())); - } - match tool.source() { - assistant_tool::ToolSource::Native => { - // Built-in tools always keep their original name - Some((tool_name, tool.clone())) - } - assistant_tool::ToolSource::ContextServer { id } => { - // Context server tools are prefixed with the context server ID, and truncated if necessary - tool_name.insert(0, '_'); - if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH { - let len = MAX_TOOL_NAME_LENGTH - tool_name.len(); - let mut id = id.to_string(); - id.truncate(len); - tool_name.insert_str(0, &id); - } else { - tool_name.insert_str(0, &id); - } - - tool_name.truncate(MAX_TOOL_NAME_LENGTH); - - if seen_tool_names.contains(&tool_name) { - log::error!("Cannot resolve tool name conflict for tool {}", tool.name()); - None - } else { - Some((tool_name, tool.clone())) - } - } - } - }) - .collect() -} diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 516151e9ff90dd6dc4a3e4b3dd5eff37522db7f2..efad0e13f06aa90013cf3458d164710ee4b26c64 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -1,8 +1,7 @@ use crate::{ + MessageId, ThreadId, context_server_tool::ContextServerTool, - thread::{ - DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId, - }, + thread::{DetailedSummaryState, ExceededWindowError, ProjectSnapshot, Thread}, }; use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; @@ -400,35 +399,17 @@ impl ThreadStore { self.threads.iter() } - pub fn create_thread(&mut self, cx: &mut Context) -> Entity { - cx.new(|cx| { - Thread::new( - self.project.clone(), - self.tools.clone(), - self.prompt_builder.clone(), - self.project_context.clone(), - cx, - ) - }) - } - - pub fn create_thread_from_serialized( - &mut self, - serialized: SerializedThread, - cx: &mut Context, - ) -> Entity { - cx.new(|cx| { - Thread::deserialize( - ThreadId::new(), - serialized, - self.project.clone(), - self.tools.clone(), - self.prompt_builder.clone(), - self.project_context.clone(), - None, - cx, - ) - }) + pub fn create_thread(&mut self, cx: &mut Context) -> Task>> { + todo!() + // cx.new(|cx| { + // Thread::new( + // self.project.clone(), + // self.tools.clone(), + // self.prompt_builder.clone(), + // self.project_context.clone(), + // cx, + // ) + // }) } pub fn open_thread( @@ -447,51 +428,53 @@ impl ThreadStore { .await? .with_context(|| format!("no thread found with ID: {id:?}"))?; - let thread = this.update_in(cx, |this, window, cx| { - cx.new(|cx| { - Thread::deserialize( - id.clone(), - thread, - this.project.clone(), - this.tools.clone(), - this.prompt_builder.clone(), - this.project_context.clone(), - Some(window), - cx, - ) - }) - })?; - - Ok(thread) + todo!(); + // let thread = this.update_in(cx, |this, window, cx| { + // cx.new(|cx| { + // Thread::deserialize( + // id.clone(), + // thread, + // this.project.clone(), + // this.tools.clone(), + // this.prompt_builder.clone(), + // this.project_context.clone(), + // Some(window), + // cx, + // ) + // }) + // })?; + // Ok(thread) }) } pub fn save_thread(&self, thread: &Entity, cx: &mut Context) -> Task> { - let (metadata, serialized_thread) = - thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx))); - - let database_future = ThreadsDatabase::global_future(cx); - cx.spawn(async move |this, cx| { - let serialized_thread = serialized_thread.await?; - let database = database_future.await.map_err(|err| anyhow!(err))?; - database.save_thread(metadata, serialized_thread).await?; - - this.update(cx, |this, cx| this.reload(cx))?.await - }) + todo!() + // let (metadata, serialized_thread) = + // thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx))); + + // let database_future = ThreadsDatabase::global_future(cx); + // cx.spawn(async move |this, cx| { + // let serialized_thread = serialized_thread.await?; + // let database = database_future.await.map_err(|err| anyhow!(err))?; + // database.save_thread(metadata, serialized_thread).await?; + + // this.update(cx, |this, cx| this.reload(cx))?.await + // }) } pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context) -> Task> { - let id = id.clone(); - let database_future = ThreadsDatabase::global_future(cx); - cx.spawn(async move |this, cx| { - let database = database_future.await.map_err(|err| anyhow!(err))?; - database.delete_thread(id.clone()).await?; - - this.update(cx, |this, cx| { - this.threads.retain(|thread| thread.id != id); - cx.notify(); - }) - }) + todo!() + // let id = id.clone(); + // let database_future = ThreadsDatabase::global_future(cx); + // cx.spawn(async move |this, cx| { + // let database = database_future.await.map_err(|err| anyhow!(err))?; + // database.delete_thread(id.clone()).await?; + + // this.update(cx, |this, cx| { + // this.threads.retain(|thread| thread.id != id); + // cx.notify(); + // }) + // }) } pub fn reload(&self, cx: &mut Context) -> Task> { @@ -1067,7 +1050,7 @@ impl ThreadsDatabase { #[cfg(test)] mod tests { use super::*; - use crate::thread::{DetailedSummaryState, MessageId}; + use crate::{MessageId, thread::DetailedSummaryState}; use chrono::Utc; use language_model::{Role, TokenUsage}; use pretty_assertions::assert_eq; diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs deleted file mode 100644 index 76de3d20223fcd1c22631029d2040c9109d9ac0d..0000000000000000000000000000000000000000 --- a/crates/agent/src/tool_use.rs +++ /dev/null @@ -1,567 +0,0 @@ -use crate::{ - thread::{MessageId, PromptId, ThreadId}, - thread_store::SerializedMessage, -}; -use anyhow::Result; -use assistant_tool::{ - AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet, -}; -use collections::HashMap; -use futures::{FutureExt as _, future::Shared}; -use gpui::{App, Entity, SharedString, Task, Window}; -use icons::IconName; -use language_model::{ - ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role, -}; -use project::Project; -use std::sync::Arc; -use util::truncate_lines_to_byte_limit; - -#[derive(Debug)] -pub struct ToolUse { - pub id: LanguageModelToolUseId, - pub name: SharedString, - pub ui_text: SharedString, - pub status: ToolUseStatus, - pub input: serde_json::Value, - pub icon: icons::IconName, - pub needs_confirmation: bool, -} - -pub struct ToolUseState { - tools: Entity, - tool_uses_by_assistant_message: HashMap>, - tool_results: HashMap, - pending_tool_uses_by_id: HashMap, - tool_result_cards: HashMap, - tool_use_metadata_by_id: HashMap, -} - -impl ToolUseState { - pub fn new(tools: Entity) -> Self { - Self { - tools, - tool_uses_by_assistant_message: HashMap::default(), - tool_results: HashMap::default(), - pending_tool_uses_by_id: HashMap::default(), - tool_result_cards: HashMap::default(), - tool_use_metadata_by_id: HashMap::default(), - } - } - - /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s. - /// - /// Accepts a function to filter the tools that should be used to populate the state. - /// - /// If `window` is `None` (e.g., when in headless mode or when running evals), - /// tool cards won't be deserialized - pub fn from_serialized_messages( - tools: Entity, - messages: &[SerializedMessage], - project: Entity, - window: Option<&mut Window>, // None in headless mode - cx: &mut App, - ) -> Self { - let mut this = Self::new(tools); - let mut tool_names_by_id = HashMap::default(); - let mut window = window; - - for message in messages { - match message.role { - Role::Assistant => { - if !message.tool_uses.is_empty() { - let tool_uses = message - .tool_uses - .iter() - .map(|tool_use| LanguageModelToolUse { - id: tool_use.id.clone(), - name: tool_use.name.clone().into(), - raw_input: tool_use.input.to_string(), - input: tool_use.input.clone(), - is_input_complete: true, - }) - .collect::>(); - - tool_names_by_id.extend( - tool_uses - .iter() - .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())), - ); - - this.tool_uses_by_assistant_message - .insert(message.id, tool_uses); - - for tool_result in &message.tool_results { - let tool_use_id = tool_result.tool_use_id.clone(); - let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else { - log::warn!("no tool name found for tool use: {tool_use_id:?}"); - continue; - }; - - this.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name: tool_use.clone(), - is_error: tool_result.is_error, - content: tool_result.content.clone(), - output: tool_result.output.clone(), - }, - ); - - if let Some(window) = &mut window { - if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) { - if let Some(output) = tool_result.output.clone() { - if let Some(card) = tool.deserialize_card( - output, - project.clone(), - window, - cx, - ) { - this.tool_result_cards.insert(tool_use_id, card); - } - } - } - } - } - } - } - Role::System | Role::User => {} - } - } - - this - } - - pub fn cancel_pending(&mut self) -> Vec { - let mut cancelled_tool_uses = Vec::new(); - self.pending_tool_uses_by_id - .retain(|tool_use_id, tool_use| { - if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) { - return true; - } - - let content = "Tool canceled by user".into(); - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name: tool_use.name.clone(), - content, - output: None, - is_error: true, - }, - ); - cancelled_tool_uses.push(tool_use.clone()); - false - }); - cancelled_tool_uses - } - - pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { - self.pending_tool_uses_by_id.values().collect() - } - - pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { - let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { - return Vec::new(); - }; - - let mut tool_uses = Vec::new(); - - for tool_use in tool_uses_for_message.iter() { - let tool_result = self.tool_results.get(&tool_use.id); - - let status = (|| { - if let Some(tool_result) = tool_result { - let content = tool_result - .content - .to_str() - .map(|str| str.to_owned().into()) - .unwrap_or_default(); - - return if tool_result.is_error { - ToolUseStatus::Error(content) - } else { - ToolUseStatus::Finished(content) - }; - } - - if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) { - match pending_tool_use.status { - PendingToolUseStatus::Idle => ToolUseStatus::Pending, - PendingToolUseStatus::NeedsConfirmation { .. } => { - ToolUseStatus::NeedsConfirmation - } - PendingToolUseStatus::Running { .. } => ToolUseStatus::Running, - PendingToolUseStatus::Error(ref err) => { - ToolUseStatus::Error(err.clone().into()) - } - PendingToolUseStatus::InputStillStreaming => { - ToolUseStatus::InputStillStreaming - } - } - } else { - ToolUseStatus::Pending - } - })(); - - let (icon, needs_confirmation) = - if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { - (tool.icon(), tool.needs_confirmation(&tool_use.input, cx)) - } else { - (IconName::Cog, false) - }; - - tool_uses.push(ToolUse { - id: tool_use.id.clone(), - name: tool_use.name.clone().into(), - ui_text: self.tool_ui_label( - &tool_use.name, - &tool_use.input, - tool_use.is_input_complete, - cx, - ), - input: tool_use.input.clone(), - status, - icon, - needs_confirmation, - }) - } - - tool_uses - } - - pub fn tool_ui_label( - &self, - tool_name: &str, - input: &serde_json::Value, - is_input_complete: bool, - cx: &App, - ) -> SharedString { - if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) { - if is_input_complete { - tool.ui_text(input).into() - } else { - tool.still_streaming_ui_text(input).into() - } - } else { - format!("Unknown tool {tool_name:?}").into() - } - } - - pub fn tool_results_for_message( - &self, - assistant_message_id: MessageId, - ) -> Vec<&LanguageModelToolResult> { - let Some(tool_uses) = self - .tool_uses_by_assistant_message - .get(&assistant_message_id) - else { - return Vec::new(); - }; - - tool_uses - .iter() - .filter_map(|tool_use| self.tool_results.get(&tool_use.id)) - .collect() - } - - pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool { - self.tool_uses_by_assistant_message - .get(&assistant_message_id) - .map_or(false, |results| !results.is_empty()) - } - - pub fn tool_result( - &self, - tool_use_id: &LanguageModelToolUseId, - ) -> Option<&LanguageModelToolResult> { - self.tool_results.get(tool_use_id) - } - - pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> { - self.tool_result_cards.get(tool_use_id) - } - - pub fn insert_tool_result_card( - &mut self, - tool_use_id: LanguageModelToolUseId, - card: AnyToolCard, - ) { - self.tool_result_cards.insert(tool_use_id, card); - } - - pub fn request_tool_use( - &mut self, - assistant_message_id: MessageId, - tool_use: LanguageModelToolUse, - metadata: ToolUseMetadata, - cx: &App, - ) -> Arc { - let tool_uses = self - .tool_uses_by_assistant_message - .entry(assistant_message_id) - .or_default(); - - let mut existing_tool_use_found = false; - - for existing_tool_use in tool_uses.iter_mut() { - if existing_tool_use.id == tool_use.id { - *existing_tool_use = tool_use.clone(); - existing_tool_use_found = true; - } - } - - if !existing_tool_use_found { - tool_uses.push(tool_use.clone()); - } - - let status = if tool_use.is_input_complete { - self.tool_use_metadata_by_id - .insert(tool_use.id.clone(), metadata); - - PendingToolUseStatus::Idle - } else { - PendingToolUseStatus::InputStillStreaming - }; - - let ui_text: Arc = self - .tool_ui_label( - &tool_use.name, - &tool_use.input, - tool_use.is_input_complete, - cx, - ) - .into(); - - let may_perform_edits = self - .tools - .read(cx) - .tool(&tool_use.name, cx) - .is_some_and(|tool| tool.may_perform_edits()); - - self.pending_tool_uses_by_id.insert( - tool_use.id.clone(), - PendingToolUse { - assistant_message_id, - id: tool_use.id, - name: tool_use.name.clone(), - ui_text: ui_text.clone(), - input: tool_use.input, - may_perform_edits, - status, - }, - ); - - ui_text - } - - pub fn run_pending_tool( - &mut self, - tool_use_id: LanguageModelToolUseId, - ui_text: SharedString, - task: Task<()>, - ) { - if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { - tool_use.ui_text = ui_text.into(); - tool_use.status = PendingToolUseStatus::Running { - _task: task.shared(), - }; - } - } - - pub fn confirm_tool_use( - &mut self, - tool_use_id: LanguageModelToolUseId, - ui_text: impl Into>, - input: serde_json::Value, - request: Arc, - tool: Arc, - ) { - if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { - let ui_text = ui_text.into(); - tool_use.ui_text = ui_text.clone(); - let confirmation = Confirmation { - tool_use_id, - input, - request, - tool, - ui_text, - }; - tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation)); - } - } - - pub fn insert_tool_output( - &mut self, - tool_use_id: LanguageModelToolUseId, - tool_name: Arc, - output: Result, - configured_model: Option<&ConfiguredModel>, - ) -> Option { - let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id); - - telemetry::event!( - "Agent Tool Finished", - model = metadata - .as_ref() - .map(|metadata| metadata.model.telemetry_id()), - model_provider = metadata - .as_ref() - .map(|metadata| metadata.model.provider_id().to_string()), - thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()), - prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()), - tool_name, - success = output.is_ok() - ); - - match output { - Ok(output) => { - let tool_result = output.content; - const BYTES_PER_TOKEN_ESTIMATE: usize = 3; - - let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id); - - // Protect from overly large output - let tool_output_limit = configured_model - .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE) - .unwrap_or(usize::MAX); - - let content = match tool_result { - ToolResultContent::Text(text) => { - let text = if text.len() < tool_output_limit { - text - } else { - let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit); - format!( - "Tool result too long. The first {} bytes:\n\n{}", - truncated.len(), - truncated - ) - }; - LanguageModelToolResultContent::Text(text.into()) - } - ToolResultContent::Image(language_model_image) => { - if language_model_image.estimate_tokens() < tool_output_limit { - LanguageModelToolResultContent::Image(language_model_image) - } else { - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name, - content: "Tool responded with an image that would exceeded the remaining tokens".into(), - is_error: true, - output: None, - }, - ); - - return old_use; - } - } - }; - - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name, - content, - is_error: false, - output: output.output, - }, - ); - - old_use - } - Err(err) => { - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name, - content: LanguageModelToolResultContent::Text(err.to_string().into()), - is_error: true, - output: None, - }, - ); - - if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { - tool_use.status = PendingToolUseStatus::Error(err.to_string().into()); - } - - self.pending_tool_uses_by_id.get(&tool_use_id).cloned() - } - } - } - - pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool { - self.tool_uses_by_assistant_message - .contains_key(&assistant_message_id) - } - - pub fn tool_results( - &self, - assistant_message_id: MessageId, - ) -> impl Iterator)> { - self.tool_uses_by_assistant_message - .get(&assistant_message_id) - .into_iter() - .flatten() - .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id))) - } -} - -#[derive(Debug, Clone)] -pub struct PendingToolUse { - pub id: LanguageModelToolUseId, - /// The ID of the Assistant message in which the tool use was requested. - #[allow(unused)] - pub assistant_message_id: MessageId, - pub name: Arc, - pub ui_text: Arc, - pub input: serde_json::Value, - pub status: PendingToolUseStatus, - pub may_perform_edits: bool, -} - -#[derive(Debug, Clone)] -pub struct Confirmation { - pub tool_use_id: LanguageModelToolUseId, - pub input: serde_json::Value, - pub ui_text: Arc, - pub request: Arc, - pub tool: Arc, -} - -#[derive(Debug, Clone)] -pub enum PendingToolUseStatus { - InputStillStreaming, - Idle, - NeedsConfirmation(Arc), - Running { _task: Shared> }, - Error(#[allow(unused)] Arc), -} - -impl PendingToolUseStatus { - pub fn is_idle(&self) -> bool { - matches!(self, PendingToolUseStatus::Idle) - } - - pub fn is_error(&self) -> bool { - matches!(self, PendingToolUseStatus::Error(_)) - } - - pub fn needs_confirmation(&self) -> bool { - matches!(self, PendingToolUseStatus::NeedsConfirmation { .. }) - } -} - -#[derive(Clone)] -pub struct ToolUseMetadata { - pub model: Arc, - pub thread_id: ThreadId, - pub prompt_id: PromptId, -} diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 7ee3b7158b6f9f8db6788c80f93123bd1ad463c6..d25a3fefff59202e6710337888d887e5f31a0beb 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -7,7 +7,7 @@ use crate::ui::{ use crate::{AgentPanel, ModelUsageContext}; use agent::{ ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore, - Thread, ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadSummary, + Thread, ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadTitle, context::{self, AgentContextHandle, RULES_ICON}, thread_store::RulesLoadingError, tool_use::{PendingToolUseStatus, ToolUse}, @@ -816,23 +816,24 @@ impl ActiveThread { _load_edited_message_context_task: None, }; - for message in thread.read(cx).messages().cloned().collect::>() { - let rendered_message = RenderedMessage::from_segments( - &message.segments, - this.language_registry.clone(), - cx, - ); - this.push_rendered_message(message.id, rendered_message); - - for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) { - this.render_tool_use_markdown( - tool_use.id.clone(), - tool_use.ui_text.clone(), - &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(), - tool_use.status.text(), - cx, - ); - } + for message in thread.read(cx).messages() { + todo!() + // let rendered_message = RenderedMessage::from_segments( + // &message.segments, + // this.language_registry.clone(), + // cx, + // ); + // this.push_rendered_message(message.id, rendered_message); + + // for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) { + // this.render_tool_use_markdown( + // tool_use.id.clone(), + // tool_use.ui_text.clone(), + // &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(), + // tool_use.status.text(), + // cx, + // ); + // } } this @@ -846,19 +847,18 @@ impl ActiveThread { self.messages.is_empty() } - pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary { - self.thread.read(cx).summary() + pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadTitle { + self.thread.read(cx).title() } pub fn regenerate_summary(&self, cx: &mut App) { - self.thread.update(cx, |thread, cx| thread.summarize(cx)) + self.thread + .update(cx, |thread, cx| thread.regenerate_summary(cx)) } pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool { self.last_error.take(); - self.thread.update(cx, |thread, cx| { - thread.cancel_last_completion(Some(window.window_handle()), cx) - }) + self.thread.update(cx, |thread, cx| thread.cancel(cx)) } pub fn last_error(&self) -> Option { @@ -1185,7 +1185,7 @@ impl ActiveThread { return; } - let title = self.thread.read(cx).summary().unwrap_or("Agent Panel"); + let title = self.thread.read(cx).title().unwrap_or("Agent Panel"); match AgentSettings::get_global(cx).notify_when_agent_waiting { NotifyWhenAgentWaiting::PrimaryScreen => { @@ -3605,7 +3605,7 @@ pub(crate) fn open_active_thread_as_markdown( workspace.update_in(cx, |workspace, window, cx| { let thread = thread.read(cx); let markdown = thread.to_markdown(cx)?; - let thread_summary = thread.summary().or_default().to_string(); + let thread_summary = thread.title().or_default().to_string(); let project = workspace.project().clone(); @@ -3776,357 +3776,357 @@ fn open_editor_at_position( }) } -#[cfg(test)] -mod tests { - use super::*; - use agent::{MessageSegment, context::ContextLoadResult, thread_store}; - use assistant_tool::{ToolRegistry, ToolWorkingSet}; - use editor::EditorSettings; - use fs::FakeFs; - use gpui::{AppContext, TestAppContext, VisualTestContext}; - use language_model::{ - ConfiguredModel, LanguageModel, LanguageModelRegistry, - fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}, - }; - use project::Project; - use prompt_store::PromptBuilder; - use serde_json::json; - use settings::SettingsStore; - use util::path; - use workspace::CollaboratorId; - - #[gpui::test] - async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project( - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (cx, _active_thread, workspace, thread, model) = - setup_test_environment(cx, project.clone()).await; - - // Insert user message without any context (empty context vector) - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "What is the best way to learn Rust?", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - }); - - // Stream response to user message - thread.update(cx, |thread, cx| { - let intent = CompletionIntent::UserPrompt; - let request = thread.to_completion_request(model.clone(), intent, cx); - thread.stream_completion(request, model, intent, cx.active_window(), cx) - }); - // Follow the agent - cx.update(|window, cx| { - workspace.update(cx, |workspace, cx| { - workspace.follow(CollaboratorId::Agent, window, cx); - }) - }); - assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent))); - - // Cancel the current completion - thread.update(cx, |thread, cx| { - thread.cancel_last_completion(cx.active_window(), cx) - }); - - cx.executor().run_until_parked(); - - // No longer following the agent - assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent))); - } - - #[gpui::test] - async fn test_reinserting_creases_for_edited_message(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - - let (cx, active_thread, _, thread, model) = - setup_test_environment(cx, project.clone()).await; - cx.update(|_, cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model( - Some(ConfiguredModel { - provider: Arc::new(FakeLanguageModelProvider), - model, - }), - cx, - ); - }); - }); - - let creases = vec![MessageCrease { - range: 14..22, - icon_path: "icon".into(), - label: "foo.txt".into(), - context: None, - }]; - - let message = thread.update(cx, |thread, cx| { - let message_id = thread.insert_user_message( - "Tell me about @foo.txt", - ContextLoadResult::default(), - None, - creases, - cx, - ); - thread.message(message_id).cloned().unwrap() - }); - - active_thread.update_in(cx, |active_thread, window, cx| { - if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) { - active_thread.start_editing_message( - message.id, - message_text, - message.creases.as_slice(), - window, - cx, - ); - } - let editor = active_thread - .editing_message - .as_ref() - .unwrap() - .1 - .editor - .clone(); - editor.update(cx, |editor, cx| editor.edit([(0..13, "modified")], cx)); - active_thread.confirm_editing_message(&Default::default(), window, cx); - }); - cx.run_until_parked(); - - let message = thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap()); - active_thread.update_in(cx, |active_thread, window, cx| { - if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) { - active_thread.start_editing_message( - message.id, - message_text, - message.creases.as_slice(), - window, - cx, - ); - } - let editor = active_thread - .editing_message - .as_ref() - .unwrap() - .1 - .editor - .clone(); - let text = editor.update(cx, |editor, cx| editor.text(cx)); - assert_eq!(text, "modified @foo.txt"); - }); - } - - #[gpui::test] - async fn test_editing_message_cancels_previous_completion(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - - let (cx, active_thread, _, thread, model) = - setup_test_environment(cx, project.clone()).await; - - cx.update(|_, cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model( - Some(ConfiguredModel { - provider: Arc::new(FakeLanguageModelProvider), - model: model.clone(), - }), - cx, - ); - }); - }); - - // Track thread events to verify cancellation - let cancellation_events = Arc::new(std::sync::Mutex::new(Vec::new())); - let new_request_events = Arc::new(std::sync::Mutex::new(Vec::new())); - - let _subscription = cx.update(|_, cx| { - let cancellation_events = cancellation_events.clone(); - let new_request_events = new_request_events.clone(); - cx.subscribe( - &thread, - move |_thread, event: &ThreadEvent, _cx| match event { - ThreadEvent::CompletionCanceled => { - cancellation_events.lock().unwrap().push(()); - } - ThreadEvent::NewRequest => { - new_request_events.lock().unwrap().push(()); - } - _ => {} - }, - ) - }); - - // Insert a user message and start streaming a response - let message = thread.update(cx, |thread, cx| { - let message_id = thread.insert_user_message( - "Hello, how are you?", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - thread.advance_prompt_id(); - thread.send_to_model( - model.clone(), - CompletionIntent::UserPrompt, - cx.active_window(), - cx, - ); - thread.message(message_id).cloned().unwrap() - }); - - cx.run_until_parked(); - - // Verify that a completion is in progress - assert!(cx.read(|cx| thread.read(cx).is_generating())); - assert_eq!(new_request_events.lock().unwrap().len(), 1); - - // Edit the message while the completion is still running - active_thread.update_in(cx, |active_thread, window, cx| { - if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) { - active_thread.start_editing_message( - message.id, - message_text, - message.creases.as_slice(), - window, - cx, - ); - } - let editor = active_thread - .editing_message - .as_ref() - .unwrap() - .1 - .editor - .clone(); - editor.update(cx, |editor, cx| { - editor.set_text("What is the weather like?", window, cx); - }); - active_thread.confirm_editing_message(&Default::default(), window, cx); - }); - - cx.run_until_parked(); - - // Verify that the previous completion was cancelled - assert_eq!(cancellation_events.lock().unwrap().len(), 1); - - // Verify that a new request was started after cancellation - assert_eq!(new_request_events.lock().unwrap().len(), 2); - - // Verify that the edited message contains the new text - let edited_message = - thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap()); - match &edited_message.segments[0] { - MessageSegment::Text(text) => { - assert_eq!(text, "What is the weather like?"); - } - _ => panic!("Expected text segment"), - } - } - - fn init_test_settings(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - AgentSettings::register(cx); - prompt_store::init(cx); - thread_store::init(cx); - workspace::init_settings(cx); - language_model::init_settings(cx); - ThemeSettings::register(cx); - EditorSettings::register(cx); - ToolRegistry::default_global(cx); - }); - } - - // Helper to create a test project with test files - async fn create_test_project( - cx: &mut TestAppContext, - files: serde_json::Value, - ) -> Entity { - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/test"), files).await; - Project::test(fs, [path!("/test").as_ref()], cx).await - } - - async fn setup_test_environment( - cx: &mut TestAppContext, - project: Entity, - ) -> ( - &mut VisualTestContext, - Entity, - Entity, - Entity, - Arc, - ) { - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - - let thread_store = cx - .update(|_, cx| { - ThreadStore::load( - project.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .await - .unwrap(); - - let text_thread_store = cx - .update(|_, cx| { - TextThreadStore::new( - project.clone(), - Arc::new(PromptBuilder::new(None).unwrap()), - Default::default(), - cx, - ) - }) - .await - .unwrap(); - - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); - let context_store = - cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade()))); - - let model = FakeLanguageModel::default(); - let model: Arc = Arc::new(model); - - let language_registry = LanguageRegistry::new(cx.executor()); - let language_registry = Arc::new(language_registry); - - let active_thread = cx.update(|window, cx| { - cx.new(|cx| { - ActiveThread::new( - thread.clone(), - thread_store.clone(), - text_thread_store, - context_store.clone(), - language_registry.clone(), - workspace.downgrade(), - window, - cx, - ) - }) - }); - - (cx, active_thread, workspace, thread, model) - } -} +// #[cfg(test)] +// mod tests { +// use super::*; +// use agent::{MessageSegment, context::ContextLoadResult, thread_store}; +// use assistant_tool::{ToolRegistry, ToolWorkingSet}; +// use editor::EditorSettings; +// use fs::FakeFs; +// use gpui::{AppContext, TestAppContext, VisualTestContext}; +// use language_model::{ +// ConfiguredModel, LanguageModel, LanguageModelRegistry, +// fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}, +// }; +// use project::Project; +// use prompt_store::PromptBuilder; +// use serde_json::json; +// use settings::SettingsStore; +// use util::path; +// use workspace::CollaboratorId; + +// #[gpui::test] +// async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project( +// cx, +// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), +// ) +// .await; + +// let (cx, _active_thread, workspace, thread, model) = +// setup_test_environment(cx, project.clone()).await; + +// // Insert user message without any context (empty context vector) +// thread.update(cx, |thread, cx| { +// thread.insert_user_message( +// "What is the best way to learn Rust?", +// ContextLoadResult::default(), +// None, +// vec![], +// cx, +// ); +// }); + +// // Stream response to user message +// thread.update(cx, |thread, cx| { +// let intent = CompletionIntent::UserPrompt; +// let request = thread.to_completion_request(model.clone(), intent, cx); +// thread.stream_completion(request, model, intent, cx.active_window(), cx) +// }); +// // Follow the agent +// cx.update(|window, cx| { +// workspace.update(cx, |workspace, cx| { +// workspace.follow(CollaboratorId::Agent, window, cx); +// }) +// }); +// assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent))); + +// // Cancel the current completion +// thread.update(cx, |thread, cx| { +// thread.cancel_last_completion(cx.active_window(), cx) +// }); + +// cx.executor().run_until_parked(); + +// // No longer following the agent +// assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent))); +// } + +// #[gpui::test] +// async fn test_reinserting_creases_for_edited_message(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; + +// let (cx, active_thread, _, thread, model) = +// setup_test_environment(cx, project.clone()).await; +// cx.update(|_, cx| { +// LanguageModelRegistry::global(cx).update(cx, |registry, cx| { +// registry.set_default_model( +// Some(ConfiguredModel { +// provider: Arc::new(FakeLanguageModelProvider), +// model, +// }), +// cx, +// ); +// }); +// }); + +// let creases = vec![MessageCrease { +// range: 14..22, +// icon_path: "icon".into(), +// label: "foo.txt".into(), +// context: None, +// }]; + +// let message = thread.update(cx, |thread, cx| { +// let message_id = thread.insert_user_message( +// "Tell me about @foo.txt", +// ContextLoadResult::default(), +// None, +// creases, +// cx, +// ); +// thread.message(message_id).cloned().unwrap() +// }); + +// active_thread.update_in(cx, |active_thread, window, cx| { +// if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) { +// active_thread.start_editing_message( +// message.id, +// message_text, +// message.creases.as_slice(), +// window, +// cx, +// ); +// } +// let editor = active_thread +// .editing_message +// .as_ref() +// .unwrap() +// .1 +// .editor +// .clone(); +// editor.update(cx, |editor, cx| editor.edit([(0..13, "modified")], cx)); +// active_thread.confirm_editing_message(&Default::default(), window, cx); +// }); +// cx.run_until_parked(); + +// let message = thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap()); +// active_thread.update_in(cx, |active_thread, window, cx| { +// if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) { +// active_thread.start_editing_message( +// message.id, +// message_text, +// message.creases.as_slice(), +// window, +// cx, +// ); +// } +// let editor = active_thread +// .editing_message +// .as_ref() +// .unwrap() +// .1 +// .editor +// .clone(); +// let text = editor.update(cx, |editor, cx| editor.text(cx)); +// assert_eq!(text, "modified @foo.txt"); +// }); +// } + +// #[gpui::test] +// async fn test_editing_message_cancels_previous_completion(cx: &mut TestAppContext) { +// init_test_settings(cx); + +// let project = create_test_project(cx, json!({})).await; + +// let (cx, active_thread, _, thread, model) = +// setup_test_environment(cx, project.clone()).await; + +// cx.update(|_, cx| { +// LanguageModelRegistry::global(cx).update(cx, |registry, cx| { +// registry.set_default_model( +// Some(ConfiguredModel { +// provider: Arc::new(FakeLanguageModelProvider), +// model: model.clone(), +// }), +// cx, +// ); +// }); +// }); + +// // Track thread events to verify cancellation +// let cancellation_events = Arc::new(std::sync::Mutex::new(Vec::new())); +// let new_request_events = Arc::new(std::sync::Mutex::new(Vec::new())); + +// let _subscription = cx.update(|_, cx| { +// let cancellation_events = cancellation_events.clone(); +// let new_request_events = new_request_events.clone(); +// cx.subscribe( +// &thread, +// move |_thread, event: &ThreadEvent, _cx| match event { +// ThreadEvent::CompletionCanceled => { +// cancellation_events.lock().unwrap().push(()); +// } +// ThreadEvent::NewRequest => { +// new_request_events.lock().unwrap().push(()); +// } +// _ => {} +// }, +// ) +// }); + +// // Insert a user message and start streaming a response +// let message = thread.update(cx, |thread, cx| { +// let message_id = thread.insert_user_message( +// "Hello, how are you?", +// ContextLoadResult::default(), +// None, +// vec![], +// cx, +// ); +// thread.advance_prompt_id(); +// thread.send_to_model( +// model.clone(), +// CompletionIntent::UserPrompt, +// cx.active_window(), +// cx, +// ); +// thread.message(message_id).cloned().unwrap() +// }); + +// cx.run_until_parked(); + +// // Verify that a completion is in progress +// assert!(cx.read(|cx| thread.read(cx).is_generating())); +// assert_eq!(new_request_events.lock().unwrap().len(), 1); + +// // Edit the message while the completion is still running +// active_thread.update_in(cx, |active_thread, window, cx| { +// if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) { +// active_thread.start_editing_message( +// message.id, +// message_text, +// message.creases.as_slice(), +// window, +// cx, +// ); +// } +// let editor = active_thread +// .editing_message +// .as_ref() +// .unwrap() +// .1 +// .editor +// .clone(); +// editor.update(cx, |editor, cx| { +// editor.set_text("What is the weather like?", window, cx); +// }); +// active_thread.confirm_editing_message(&Default::default(), window, cx); +// }); + +// cx.run_until_parked(); + +// // Verify that the previous completion was cancelled +// assert_eq!(cancellation_events.lock().unwrap().len(), 1); + +// // Verify that a new request was started after cancellation +// assert_eq!(new_request_events.lock().unwrap().len(), 2); + +// // Verify that the edited message contains the new text +// let edited_message = +// thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap()); +// match &edited_message.segments[0] { +// MessageSegment::Text(text) => { +// assert_eq!(text, "What is the weather like?"); +// } +// _ => panic!("Expected text segment"), +// } +// } + +// fn init_test_settings(cx: &mut TestAppContext) { +// cx.update(|cx| { +// let settings_store = SettingsStore::test(cx); +// cx.set_global(settings_store); +// language::init(cx); +// Project::init_settings(cx); +// AgentSettings::register(cx); +// prompt_store::init(cx); +// thread_store::init(cx); +// workspace::init_settings(cx); +// language_model::init_settings(cx); +// ThemeSettings::register(cx); +// EditorSettings::register(cx); +// ToolRegistry::default_global(cx); +// }); +// } + +// // Helper to create a test project with test files +// async fn create_test_project( +// cx: &mut TestAppContext, +// files: serde_json::Value, +// ) -> Entity { +// let fs = FakeFs::new(cx.executor()); +// fs.insert_tree(path!("/test"), files).await; +// Project::test(fs, [path!("/test").as_ref()], cx).await +// } + +// async fn setup_test_environment( +// cx: &mut TestAppContext, +// project: Entity, +// ) -> ( +// &mut VisualTestContext, +// Entity, +// Entity, +// Entity, +// Arc, +// ) { +// let (workspace, cx) = +// cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + +// let thread_store = cx +// .update(|_, cx| { +// ThreadStore::load( +// project.clone(), +// cx.new(|_| ToolWorkingSet::default()), +// None, +// Arc::new(PromptBuilder::new(None).unwrap()), +// cx, +// ) +// }) +// .await +// .unwrap(); + +// let text_thread_store = cx +// .update(|_, cx| { +// TextThreadStore::new( +// project.clone(), +// Arc::new(PromptBuilder::new(None).unwrap()), +// Default::default(), +// cx, +// ) +// }) +// .await +// .unwrap(); + +// let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); +// let context_store = +// cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade()))); + +// let model = FakeLanguageModel::default(); +// let model: Arc = Arc::new(model); + +// let language_registry = LanguageRegistry::new(cx.executor()); +// let language_registry = Arc::new(language_registry); + +// let active_thread = cx.update(|window, cx| { +// cx.new(|cx| { +// ActiveThread::new( +// thread.clone(), +// thread_store.clone(), +// text_thread_store, +// context_store.clone(), +// language_registry.clone(), +// workspace.downgrade(), +// window, +// cx, +// ) +// }) +// }); + +// (cx, active_thread, workspace, thread, model) +// } +// } diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 1a0f3ff27d83a98d343985b3f827aab26afd192a..7b386f91cf89dba6c99606779869c864b9722075 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -211,7 +211,7 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let new_title = self.thread.read(cx).title().unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -461,7 +461,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); + let summary = self.thread.read(cx).title().unwrap_or("Agent Changes"); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default @@ -1369,8 +1369,6 @@ impl AgentDiff { | ThreadEvent::MessageDeleted(_) | ThreadEvent::SummaryGenerated | ThreadEvent::SummaryChanged - | ThreadEvent::UsePendingTools { .. } - | ThreadEvent::ToolFinished { .. } | ThreadEvent::CheckpointChanged | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached @@ -1801,7 +1799,10 @@ mod tests { }) .await .unwrap(); - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let thread = thread_store + .update(cx, |store, cx| store.create_thread(cx)) + .await + .unwrap(); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let (workspace, cx) = @@ -1966,7 +1967,10 @@ mod tests { }) .await .unwrap(); - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); + let thread = thread_store + .update(cx, |store, cx| store.create_thread(cx)) + .await + .unwrap(); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); let (workspace, cx) = diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index f7b9157bbb9c07abac6a80dddfc014443165a712..642b82231eac9de7bc8c039feecf94494a675e91 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -45,7 +45,7 @@ impl AgentModelSelector { let registry = LanguageModelRegistry::read_global(cx); if let Some(provider) = registry.provider(&model.provider_id()) { - thread.set_configured_model( + thread.set_model( Some(ConfiguredModel { provider, model: model.clone(), diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 560e87b1c2ad9a86f8f83a0534f6b53091ebe2a2..c31e30fc39ab978baa5b92f04838dcdcb742b79a 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -26,7 +26,7 @@ use crate::{ ui::AgentOnboardingModal, }; use agent::{ - Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, + Thread, ThreadError, ThreadEvent, ThreadId, ThreadTitle, TokenUsageRatio, context_store::ContextStore, history_store::{HistoryEntryId, HistoryStore}, thread_store::{TextThreadStore, ThreadStore}, @@ -72,7 +72,7 @@ use zed_actions::{ agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding}, assistant::{OpenRulesLibrary, ToggleFocus}, }; -use zed_llm_client::{CompletionIntent, UsageLimit}; +use zed_llm_client::UsageLimit; const AGENT_PANEL_KEY: &str = "agent_panel"; @@ -252,7 +252,7 @@ impl ActiveView { thread.update(cx, |thread, cx| { thread.thread().update(cx, |thread, cx| { - thread.set_summary(new_summary, cx); + thread.set_title(new_summary, cx); }); }) } @@ -278,7 +278,7 @@ impl ActiveView { let editor = editor.clone(); move |_, thread, event, window, cx| match event { ThreadEvent::SummaryGenerated => { - let summary = thread.read(cx).summary().or_default(); + let summary = thread.read(cx).title().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -492,10 +492,15 @@ impl AgentPanel { None }; + let thread = thread_store + .update(cx, |this, cx| this.create_thread(cx))? + .await?; + let panel = workspace.update_in(cx, |workspace, window, cx| { let panel = cx.new(|cx| { Self::new( workspace, + thread, thread_store, context_store, prompt_store, @@ -518,13 +523,13 @@ impl AgentPanel { fn new( workspace: &Workspace, + thread: Entity, thread_store: Entity, context_store: Entity, prompt_store: Option>, window: &mut Window, cx: &mut Context, ) -> Self { - let thread = thread_store.update(cx, |this, cx| this.create_thread(cx)); let fs = workspace.app_state().fs.clone(); let user_store = workspace.app_state().user_store.clone(); let project = workspace.project(); @@ -647,11 +652,12 @@ impl AgentPanel { |this, _, event: &language_model::Event, cx| match event { language_model::Event::DefaultModelChanged => match &this.active_view { ActiveView::Thread { thread, .. } => { - thread - .read(cx) - .thread() - .clone() - .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); + // todo!(do we need this?); + // thread + // .read(cx) + // .thread() + // .clone() + // .update(cx, |thread, cx| thread.get_or_init_configured_model(cx)); } ActiveView::TextThread { .. } | ActiveView::History @@ -784,46 +790,61 @@ impl AgentPanel { .detach_and_log_err(cx); } - let active_thread = cx.new(|cx| { - ActiveThread::new( - thread.clone(), - self.thread_store.clone(), - self.context_store.clone(), - context_store.clone(), - self.language_registry.clone(), - self.workspace.clone(), - window, - cx, - ) - }); + let fs = self.fs.clone(); + let user_store = self.user_store.clone(); + let thread_store = self.thread_store.clone(); + let text_thread_store = self.context_store.clone(); + let prompt_store = self.prompt_store.clone(); + let language_registry = self.language_registry.clone(); + let workspace = self.workspace.clone(); - let message_editor = cx.new(|cx| { - MessageEditor::new( - self.fs.clone(), - self.workspace.clone(), - self.user_store.clone(), - context_store.clone(), - self.prompt_store.clone(), - self.thread_store.downgrade(), - self.context_store.downgrade(), - thread.clone(), - window, - cx, - ) - }); + cx.spawn_in(window, async move |this, cx| { + let thread = thread.await?; + let active_thread = cx.new_window_entity(|window, cx| { + ActiveThread::new( + thread.clone(), + thread_store.clone(), + text_thread_store.clone(), + context_store.clone(), + language_registry.clone(), + workspace.clone(), + window, + cx, + ) + })?; - if let Some(text) = preserved_text { - message_editor.update(cx, |editor, cx| { - editor.set_text(text, window, cx); - }); - } + let message_editor = cx.new_window_entity(|window, cx| { + MessageEditor::new( + fs.clone(), + workspace.clone(), + user_store.clone(), + context_store.clone(), + prompt_store.clone(), + thread_store.downgrade(), + text_thread_store.downgrade(), + thread.clone(), + window, + cx, + ) + })?; - message_editor.focus_handle(cx).focus(window); + if let Some(text) = preserved_text { + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text(text, window, cx); + }); + } - let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx); - self.set_active_view(thread_view, window, cx); + this.update_in(cx, |this, window, cx| { + message_editor.focus_handle(cx).focus(window); - AgentDiff::set_active_thread(&self.workspace, &thread, window, cx); + let thread_view = + ActiveView::thread(active_thread.clone(), message_editor, window, cx); + this.set_active_view(thread_view, window, cx); + + AgentDiff::set_active_thread(&this.workspace, &thread, window, cx); + }) + }) + .detach_and_log_err(cx); } fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { @@ -1254,23 +1275,11 @@ impl AgentPanel { return; } - let model = thread_state.configured_model().map(|cm| cm.model.clone()); - if let Some(model) = model { - thread.update(cx, |active_thread, cx| { - active_thread.thread().update(cx, |thread, cx| { - thread.insert_invisible_continue_message(cx); - thread.advance_prompt_id(); - thread.send_to_model( - model, - CompletionIntent::UserPrompt, - Some(window.window_handle()), - cx, - ); - }); - }); - } else { - log::warn!("No configured model available for continuation"); - } + thread.update(cx, |active_thread, cx| { + active_thread + .thread() + .update(cx, |thread, cx| thread.resume(window, cx)) + }); } fn toggle_burn_mode( @@ -1552,24 +1561,24 @@ impl AgentPanel { let state = { let active_thread = active_thread.read(cx); if active_thread.is_empty() { - &ThreadSummary::Pending + &ThreadTitle::Pending } else { active_thread.summary(cx) } }; match state { - ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone()) + ThreadTitle::Pending => Label::new(ThreadTitle::DEFAULT.clone()) .truncate() .into_any_element(), - ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) + ThreadTitle::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) .truncate() .into_any_element(), - ThreadSummary::Ready(_) => div() + ThreadTitle::Ready(_) => div() .w_full() .child(change_title_editor.clone()) .into_any_element(), - ThreadSummary::Error => h_flex() + ThreadTitle::Error => h_flex() .w_full() .child(change_title_editor.clone()) .child( @@ -2024,7 +2033,7 @@ impl AgentPanel { .read(cx) .thread() .read(cx) - .configured_model() + .model() .map_or(false, |model| { model.provider.id().0 == ZED_CLOUD_PROVIDER_ID }); @@ -2629,7 +2638,7 @@ impl AgentPanel { return None; } - let model = thread.configured_model()?.model; + let model = thread.model()?.model; let focus_handle = self.focus_handle(cx); diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 29a4f38487b34e134218b72e824b1d3aba439cb9..51f51f62470b953dc8a1a23b4e73328d9c563ec8 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -121,7 +121,7 @@ pub(crate) enum ModelUsageContext { impl ModelUsageContext { pub fn configured_model(&self, cx: &App) -> Option { match self { - Self::Thread(thread) => thread.read(cx).configured_model(), + Self::Thread(thread) => thread.read(cx).model(), Self::InlineAssistant => { LanguageModelRegistry::read_global(cx).inline_assistant_model() } diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index f303f34a52856a068f1d2da33cf1f0a4fb5813a5..f662fe01425b4295d119fd7c3554d10af143ea9e 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -670,7 +670,7 @@ fn recent_context_picker_entries( let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx) .filter(|(_, thread)| match thread { ThreadContextEntry::Thread { id, .. } => { - Some(id) != active_thread_id && !current_threads.contains(id) + Some(id) != active_thread_id.as_ref() && !current_threads.contains(id) } ThreadContextEntry::Context { .. } => true, }) diff --git a/crates/agent_ui/src/context_strip.rs b/crates/agent_ui/src/context_strip.rs index 080ffd2ea0108400b691c6a614fcdb4f81952856..a0e8bd3cecc94fd14e9728d2b47ea5ab120b73e5 100644 --- a/crates/agent_ui/src/context_strip.rs +++ b/crates/agent_ui/src/context_strip.rs @@ -169,13 +169,13 @@ impl ContextStrip { if self .context_store .read(cx) - .includes_thread(active_thread.id()) + .includes_thread(&active_thread.id()) { return None; } Some(SuggestedContext::Thread { - name: active_thread.summary().or_default(), + name: active_thread.title().or_default(), thread: weak_active_thread, }) } else if let Some(active_context_editor) = panel.active_context_editor() { diff --git a/crates/agent_ui/src/profile_selector.rs b/crates/agent_ui/src/profile_selector.rs index ddcb44d46b800f257314a8802ad01abc98560ce0..7b996f52ad190ccb43f1b9068a432b1b3bb8dc64 100644 --- a/crates/agent_ui/src/profile_selector.rs +++ b/crates/agent_ui/src/profile_selector.rs @@ -156,7 +156,7 @@ impl Render for ProfileSelector { .map(|profile| profile.name.clone()) .unwrap_or_else(|| "Unknown".into()); - let configured_model = self.thread.read(cx).configured_model().or_else(|| { + let configured_model = self.thread.read(cx).model().or_else(|| { let model_registry = LanguageModelRegistry::read_global(cx); model_registry.default_model() });