Detailed changes
@@ -5,17 +5,15 @@ pub mod context_server_tool;
pub mod context_store;
pub mod history_store;
pub mod thread;
-mod thread2;
pub mod thread_store;
-pub mod tool_use;
mod zed_agent;
pub use agent2::*;
pub use context::{AgentContext, ContextId, ContextLoadResult};
pub use context_store::ContextStore;
pub use thread::{
- LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
- ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio,
+ LastRestoreCheckpoint, Message, MessageCrease, Thread, ThreadError, ThreadEvent,
+ ThreadFeedback, ThreadTitle, TokenUsageRatio,
};
pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore};
pub use zed_agent::*;
@@ -2,13 +2,42 @@ use anyhow::Result;
use assistant_tool::{Tool, ToolResultOutput};
use futures::{channel::oneshot, future::BoxFuture, stream::BoxStream};
use gpui::SharedString;
-use std::sync::Arc;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::{
+ fmt::{self, Display},
+ sync::Arc,
+};
-#[derive(Debug, Clone)]
-pub struct AgentThreadId(SharedString);
+#[derive(
+ Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, JsonSchema,
+)]
+pub struct ThreadId(SharedString);
+
+impl ThreadId {
+ pub fn as_str(&self) -> &str {
+ &self.0
+ }
+
+ pub fn to_string(&self) -> String {
+ self.0.to_string()
+ }
+}
+
+impl From<&str> for ThreadId {
+ fn from(value: &str) -> Self {
+ ThreadId(SharedString::from(value.to_string()))
+ }
+}
+
+impl Display for ThreadId {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
-#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
-pub struct AgentThreadMessageId(usize);
+#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize)]
+pub struct MessageId(pub usize);
#[derive(Debug, Clone)]
pub struct AgentThreadToolCallId(SharedString);
@@ -31,11 +60,11 @@ pub enum AgentThreadResponseEvent {
pub enum AgentThreadMessage {
User {
- id: AgentThreadMessageId,
+ id: MessageId,
chunks: Vec<AgentThreadUserMessageChunk>,
},
Assistant {
- id: AgentThreadMessageId,
+ id: MessageId,
chunks: Vec<AgentThreadAssistantMessageChunk>,
},
}
@@ -56,20 +85,20 @@ pub enum AgentThreadAssistantMessageChunk {
},
}
-struct AgentThreadResponse {
- user_message_id: AgentThreadMessageId,
- events: BoxStream<'static, Result<AgentThreadResponseEvent>>,
+pub struct AgentThreadResponse {
+ pub user_message_id: MessageId,
+ pub events: BoxStream<'static, Result<AgentThreadResponseEvent>>,
}
pub trait AgentThread {
- fn id(&self) -> AgentThreadId;
+ fn id(&self) -> ThreadId;
fn title(&self) -> BoxFuture<'static, Result<String>>;
fn summary(&self) -> BoxFuture<'static, Result<String>>;
fn messages(&self) -> BoxFuture<'static, Result<Vec<AgentThreadMessage>>>;
- fn truncate(&self, message_id: AgentThreadMessageId) -> BoxFuture<'static, Result<()>>;
+ fn truncate(&self, message_id: MessageId) -> BoxFuture<'static, Result<()>>;
fn edit(
&self,
- message_id: AgentThreadMessageId,
+ message_id: MessageId,
content: Vec<AgentThreadUserMessageChunk>,
max_iterations: usize,
) -> BoxFuture<'static, Result<AgentThreadResponse>>;
@@ -581,7 +581,7 @@ impl ThreadContextHandle {
}
pub fn title(&self, cx: &App) -> SharedString {
- self.thread.read(cx).summary().or_default()
+ self.thread.read(cx).title().or_default()
}
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
@@ -589,7 +589,7 @@ impl ThreadContextHandle {
let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?;
let title = self
.thread
- .read_with(cx, |thread, _cx| thread.summary().or_default())
+ .read_with(cx, |thread, _cx| thread.title().or_default())
.ok()?;
let context = AgentContext::Thread(ThreadContext {
title,
@@ -1,10 +1,11 @@
use crate::{
+ MessageId, ThreadId,
context::{
AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
},
- thread::{MessageId, Thread, ThreadId},
+ thread::Thread,
thread_store::ThreadStore,
};
use anyhow::{Context as _, Result, anyhow};
@@ -71,6 +72,7 @@ impl ContextStore {
) -> Vec<AgentContextHandle> {
let existing_context = thread
.messages()
+ .iter()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
.flat_map(|message| {
message
@@ -441,7 +443,7 @@ impl ContextStore {
match context {
AgentContextHandle::Thread(thread_context) => {
self.context_thread_ids
- .remove(thread_context.thread.read(cx).id());
+ .remove(&thread_context.thread.read(cx).id());
}
AgentContextHandle::TextThread(text_thread_context) => {
if let Some(path) = text_thread_context.context.read(cx).path() {
@@ -1,12 +1,8 @@
use crate::{
+ AgentThread, AgentThreadUserMessageChunk, MessageId, ThreadId,
agent_profile::AgentProfile,
context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
- thread_store::{
- SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
- SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext,
- ThreadStore,
- },
- tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState},
+ thread_store::{SharedProjectContext, ThreadStore},
};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
@@ -15,7 +11,7 @@ use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use collections::{HashMap, HashSet};
use feature_flags::{self, FeatureFlagAppExt};
-use futures::{FutureExt, StreamExt as _, future::Shared};
+use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
use git::repository::DiffType;
use gpui::{
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
@@ -26,8 +22,7 @@ use language_model::{
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
- ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
- TokenUsage,
+ ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::{
@@ -36,7 +31,6 @@ use project::{
};
use prompt_store::{ModelContext, PromptBuilder};
use proto::Plan;
-use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
@@ -47,66 +41,8 @@ use std::{
};
use thiserror::Error;
use util::{ResultExt as _, post_inc};
-use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
-const MAX_RETRY_ATTEMPTS: u8 = 3;
-const BASE_RETRY_DELAY_SECS: u64 = 5;
-
-#[derive(
- Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
-)]
-pub struct ThreadId(Arc<str>);
-
-impl ThreadId {
- pub fn new() -> Self {
- Self(Uuid::new_v4().to_string().into())
- }
-}
-
-impl std::fmt::Display for ThreadId {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-impl From<&str> for ThreadId {
- fn from(value: &str) -> Self {
- Self(value.into())
- }
-}
-
-/// The ID of the user prompt that initiated a request.
-///
-/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
-pub struct PromptId(Arc<str>);
-
-impl PromptId {
- pub fn new() -> Self {
- Self(Uuid::new_v4().to_string().into())
- }
-}
-
-impl std::fmt::Display for PromptId {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
-pub struct MessageId(pub(crate) usize);
-
-impl MessageId {
- fn post_inc(&mut self) -> Self {
- Self(post_inc(&mut self.0))
- }
-
- pub fn as_usize(&self) -> usize {
- self.0
- }
-}
-
/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
#[derive(Clone, Debug)]
pub struct MessageCrease {
@@ -117,105 +53,38 @@ pub struct MessageCrease {
pub context: Option<AgentContextHandle>,
}
+pub enum MessageTool {
+ Pending {
+ tool: Arc<dyn Tool>,
+ input: serde_json::Value,
+ },
+ NeedsConfirmation {
+ tool: Arc<dyn Tool>,
+ input_json: serde_json::Value,
+ confirm_tx: oneshot::Sender<bool>,
+ },
+ Confirmed {
+ card: AnyToolCard,
+ },
+ Declined {
+ tool: Arc<dyn Tool>,
+ input_json: serde_json::Value,
+ },
+}
+
/// A message in a [`Thread`].
-#[derive(Debug, Clone)]
pub struct Message {
pub id: MessageId,
pub role: Role,
- pub segments: Vec<MessageSegment>,
+ pub thinking: String,
+ pub text: String,
+ pub tools: Vec<MessageTool>,
pub loaded_context: LoadedContext,
pub creases: Vec<MessageCrease>,
pub is_hidden: bool,
pub ui_only: bool,
}
-impl Message {
- /// Returns whether the message contains any meaningful text that should be displayed
- /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
- pub fn should_display_content(&self) -> bool {
- self.segments.iter().all(|segment| segment.should_display())
- }
-
- pub fn push_thinking(&mut self, text: &str, signature: Option<String>) {
- if let Some(MessageSegment::Thinking {
- text: segment,
- signature: current_signature,
- }) = self.segments.last_mut()
- {
- if let Some(signature) = signature {
- *current_signature = Some(signature);
- }
- segment.push_str(text);
- } else {
- self.segments.push(MessageSegment::Thinking {
- text: text.to_string(),
- signature,
- });
- }
- }
-
- pub fn push_redacted_thinking(&mut self, data: String) {
- self.segments.push(MessageSegment::RedactedThinking(data));
- }
-
- pub fn push_text(&mut self, text: &str) {
- if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
- segment.push_str(text);
- } else {
- self.segments.push(MessageSegment::Text(text.to_string()));
- }
- }
-
- pub fn to_string(&self) -> String {
- let mut result = String::new();
-
- if !self.loaded_context.text.is_empty() {
- result.push_str(&self.loaded_context.text);
- }
-
- for segment in &self.segments {
- match segment {
- MessageSegment::Text(text) => result.push_str(text),
- MessageSegment::Thinking { text, .. } => {
- result.push_str("<think>\n");
- result.push_str(text);
- result.push_str("\n</think>");
- }
- MessageSegment::RedactedThinking(_) => {}
- }
- }
-
- result
- }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum MessageSegment {
- Text(String),
- Thinking {
- text: String,
- signature: Option<String>,
- },
- RedactedThinking(String),
-}
-
-impl MessageSegment {
- pub fn should_display(&self) -> bool {
- match self {
- Self::Text(text) => text.is_empty(),
- Self::Thinking { text, .. } => text.is_empty(),
- Self::RedactedThinking(_) => false,
- }
- }
-
- pub fn text(&self) -> Option<&str> {
- match self {
- MessageSegment::Text(text) => Some(text),
- _ => None,
- }
- }
-}
-
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
pub worktree_snapshots: Vec<WorktreeSnapshot>,
@@ -345,25 +214,17 @@ pub enum QueueState {
/// A thread of conversation with the LLM.
pub struct Thread {
- id: ThreadId,
- updated_at: DateTime<Utc>,
- summary: ThreadSummary,
+ agent_thread: Arc<dyn AgentThread>,
+ title: ThreadTitle,
+ pending_send: Option<Task<Result<()>>>,
pending_summary: Task<Option<()>>,
detailed_summary_task: Task<Option<()>>,
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
completion_mode: agent_settings::CompletionMode,
messages: Vec<Message>,
- next_message_id: MessageId,
- last_prompt_id: PromptId,
- project_context: SharedProjectContext,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
- completion_count: usize,
- pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
- prompt_builder: Arc<PromptBuilder>,
- tools: Entity<ToolWorkingSet>,
- tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
pending_checkpoint: Option<ThreadCheckpoint>,
@@ -372,35 +233,22 @@ pub struct Thread {
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
tool_use_limit_reached: bool,
+ // todo!(keep track of retries from the underlying agent)
feedback: Option<ThreadFeedback>,
- retry_state: Option<RetryState>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
last_received_chunk_at: Option<Instant>,
- request_callback: Option<
- Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
- >,
- remaining_turns: u32,
- configured_model: Option<ConfiguredModel>,
- profile: AgentProfile,
-}
-
-#[derive(Clone, Debug)]
-struct RetryState {
- attempt: u8,
- max_attempts: u8,
- intent: CompletionIntent,
}
#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum ThreadSummary {
+pub enum ThreadTitle {
Pending,
Generating,
Ready(SharedString),
Error,
}
-impl ThreadSummary {
+impl ThreadTitle {
pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
pub fn or_default(&self) -> SharedString {
@@ -413,8 +261,8 @@ impl ThreadSummary {
pub fn ready(&self) -> Option<SharedString> {
match self {
- ThreadSummary::Ready(summary) => Some(summary.clone()),
- ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
+ ThreadTitle::Ready(summary) => Some(summary.clone()),
+ ThreadTitle::Pending | ThreadTitle::Generating | ThreadTitle::Error => None,
}
}
}
@@ -428,39 +276,26 @@ pub struct ExceededWindowError {
}
impl Thread {
- pub fn new(
+ pub fn load(
+ agent_thread: Arc<dyn AgentThread>,
project: Entity<Project>,
- tools: Entity<ToolWorkingSet>,
- prompt_builder: Arc<PromptBuilder>,
- system_prompt: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
- let configured_model = LanguageModelRegistry::read_global(cx).default_model();
- let profile_id = AgentSettings::get_global(cx).default_profile.clone();
-
Self {
- id: ThreadId::new(),
- updated_at: Utc::now(),
- summary: ThreadSummary::Pending,
+ agent_thread,
+ title: ThreadTitle::Pending,
+ pending_send: None,
pending_summary: Task::ready(None),
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
detailed_summary_rx,
completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
- messages: Vec::new(),
- next_message_id: MessageId(0),
- last_prompt_id: PromptId::new(),
- project_context: system_prompt,
+ messages: todo!("read from agent"),
checkpoints_by_message: HashMap::default(),
- completion_count: 0,
- pending_completions: Vec::new(),
project: project.clone(),
- prompt_builder,
- tools: tools.clone(),
last_restore_checkpoint: None,
pending_checkpoint: None,
- tool_use: ToolUseState::new(tools.clone()),
action_log: cx.new(|_| ActionLog::new(project.clone())),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx);
@@ -473,221 +308,64 @@ impl Thread {
exceeded_window_error: None,
tool_use_limit_reached: false,
feedback: None,
- retry_state: None,
- message_feedback: HashMap::default(),
- last_auto_capture_at: None,
- last_received_chunk_at: None,
- request_callback: None,
- remaining_turns: u32::MAX,
- configured_model,
- profile: AgentProfile::new(profile_id, tools),
- }
- }
-
- pub fn deserialize(
- id: ThreadId,
- serialized: SerializedThread,
- project: Entity<Project>,
- tools: Entity<ToolWorkingSet>,
- prompt_builder: Arc<PromptBuilder>,
- project_context: SharedProjectContext,
- window: Option<&mut Window>, // None in headless mode
- cx: &mut Context<Self>,
- ) -> Self {
- let next_message_id = MessageId(
- serialized
- .messages
- .last()
- .map(|message| message.id.0 + 1)
- .unwrap_or(0),
- );
- let tool_use = ToolUseState::from_serialized_messages(
- tools.clone(),
- &serialized.messages,
- project.clone(),
- window,
- cx,
- );
- let (detailed_summary_tx, detailed_summary_rx) =
- postage::watch::channel_with(serialized.detailed_summary_state);
-
- let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- serialized
- .model
- .and_then(|model| {
- let model = SelectedModel {
- provider: model.provider.clone().into(),
- model: model.model.clone().into(),
- };
- registry.select_model(&model, cx)
- })
- .or_else(|| registry.default_model())
- });
-
- let completion_mode = serialized
- .completion_mode
- .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
- let profile_id = serialized
- .profile
- .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
-
- Self {
- id,
- updated_at: serialized.updated_at,
- summary: ThreadSummary::Ready(serialized.summary),
- pending_summary: Task::ready(None),
- detailed_summary_task: Task::ready(None),
- detailed_summary_tx,
- detailed_summary_rx,
- completion_mode,
- retry_state: None,
- messages: serialized
- .messages
- .into_iter()
- .map(|message| Message {
- id: message.id,
- role: message.role,
- segments: message
- .segments
- .into_iter()
- .map(|segment| match segment {
- SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
- SerializedMessageSegment::Thinking { text, signature } => {
- MessageSegment::Thinking { text, signature }
- }
- SerializedMessageSegment::RedactedThinking { data } => {
- MessageSegment::RedactedThinking(data)
- }
- })
- .collect(),
- loaded_context: LoadedContext {
- contexts: Vec::new(),
- text: message.context,
- images: Vec::new(),
- },
- creases: message
- .creases
- .into_iter()
- .map(|crease| MessageCrease {
- range: crease.start..crease.end,
- icon_path: crease.icon_path,
- label: crease.label,
- context: None,
- })
- .collect(),
- is_hidden: message.is_hidden,
- ui_only: false, // UI-only messages are not persisted
- })
- .collect(),
- next_message_id,
- last_prompt_id: PromptId::new(),
- project_context,
- checkpoints_by_message: HashMap::default(),
- completion_count: 0,
- pending_completions: Vec::new(),
- last_restore_checkpoint: None,
- pending_checkpoint: None,
- project: project.clone(),
- prompt_builder,
- tools: tools.clone(),
- tool_use,
- action_log: cx.new(|_| ActionLog::new(project)),
- initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
- request_token_usage: serialized.request_token_usage,
- cumulative_token_usage: serialized.cumulative_token_usage,
- exceeded_window_error: None,
- tool_use_limit_reached: serialized.tool_use_limit_reached,
- feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
last_received_chunk_at: None,
- request_callback: None,
- remaining_turns: u32::MAX,
- configured_model,
- profile: AgentProfile::new(profile_id, tools),
}
}
- pub fn set_request_callback(
- &mut self,
- callback: impl 'static
- + FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
- ) {
- self.request_callback = Some(Box::new(callback));
- }
-
- pub fn id(&self) -> &ThreadId {
- &self.id
+ pub fn id(&self) -> ThreadId {
+ self.agent_thread.id()
}
pub fn profile(&self) -> &AgentProfile {
- &self.profile
+ todo!()
}
pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
- if &id != self.profile.id() {
- self.profile = AgentProfile::new(id, self.tools.clone());
- cx.emit(ThreadEvent::ProfileChanged);
- }
+ todo!()
+ // if &id != self.profile.id() {
+ // self.profile = AgentProfile::new(id, self.tools.clone());
+ // cx.emit(ThreadEvent::ProfileChanged);
+ // }
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
- pub fn updated_at(&self) -> DateTime<Utc> {
- self.updated_at
- }
-
- pub fn touch_updated_at(&mut self) {
- self.updated_at = Utc::now();
- }
-
- pub fn advance_prompt_id(&mut self) {
- self.last_prompt_id = PromptId::new();
- }
-
pub fn project_context(&self) -> SharedProjectContext {
- self.project_context.clone()
- }
-
- pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
- if self.configured_model.is_none() {
- self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
- }
- self.configured_model.clone()
- }
-
- pub fn configured_model(&self) -> Option<ConfiguredModel> {
- self.configured_model.clone()
+ todo!()
+ // self.project_context.clone()
}
- pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
- self.configured_model = model;
- cx.notify();
+ pub fn title(&self) -> &ThreadTitle {
+ &self.title
}
- pub fn summary(&self) -> &ThreadSummary {
- &self.summary
- }
+ pub fn set_title(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
+ todo!()
+ // let current_summary = match &self.summary {
+ // ThreadSummary::Pending | ThreadSummary::Generating => return,
+ // ThreadSummary::Ready(summary) => summary,
+ // ThreadSummary::Error => &ThreadSummary::DEFAULT,
+ // };
- pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
- let current_summary = match &self.summary {
- ThreadSummary::Pending | ThreadSummary::Generating => return,
- ThreadSummary::Ready(summary) => summary,
- ThreadSummary::Error => &ThreadSummary::DEFAULT,
- };
+ // let mut new_summary = new_summary.into();
- let mut new_summary = new_summary.into();
+ // if new_summary.is_empty() {
+ // new_summary = ThreadSummary::DEFAULT;
+ // }
- if new_summary.is_empty() {
- new_summary = ThreadSummary::DEFAULT;
- }
+ // if current_summary != &new_summary {
+ // self.summary = ThreadSummary::Ready(new_summary);
+ // cx.emit(ThreadEvent::SummaryChanged);
+ // }
+ }
- if current_summary != &new_summary {
- self.summary = ThreadSummary::Ready(new_summary);
- cx.emit(ThreadEvent::SummaryChanged);
- }
+ pub fn regenerate_summary(&self, cx: &mut Context<Self>) {
+ todo!()
+ // self.summarize(cx);
}
pub fn completion_mode(&self) -> CompletionMode {
@@ -707,12 +385,12 @@ impl Thread {
self.messages.get(index)
}
- pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
- self.messages.iter()
+ pub fn messages(&self) -> &[Message] {
+ &self.messages
}
pub fn is_generating(&self) -> bool {
- !self.pending_completions.is_empty() || !self.all_tools_finished()
+ self.pending_send.is_some()
}
/// Indicates whether streaming of language model events is stale.
@@ -728,34 +406,6 @@ impl Thread {
self.last_received_chunk_at = Some(Instant::now());
}
- pub fn queue_state(&self) -> Option<QueueState> {
- self.pending_completions
- .first()
- .map(|pending_completion| pending_completion.queue_state)
- }
-
- pub fn tools(&self) -> &Entity<ToolWorkingSet> {
- &self.tools
- }
-
- pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
- self.tool_use
- .pending_tool_uses()
- .into_iter()
- .find(|tool_use| &tool_use.id == id)
- }
-
- pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
- self.tool_use
- .pending_tool_uses()
- .into_iter()
- .filter(|tool_use| tool_use.status.needs_confirmation())
- }
-
- pub fn has_pending_tool_uses(&self) -> bool {
- !self.tool_use.pending_tool_uses().is_empty()
- }
-
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
self.checkpoints_by_message.get(&id).cloned()
}
@@ -855,6 +505,7 @@ impl Thread {
}
pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
+ todo!("call truncate on the agent");
let Some(message_ix) = self
.messages
.iter()
@@ -868,248 +519,203 @@ impl Thread {
cx.notify();
}
- pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
- self.messages
- .iter()
- .find(|message| message.id == id)
- .into_iter()
- .flat_map(|message| message.loaded_context.contexts.iter())
- }
-
pub fn is_turn_end(&self, ix: usize) -> bool {
- if self.messages.is_empty() {
- return false;
- }
+ todo!()
+ // if self.messages.is_empty() {
+ // return false;
+ // }
- if !self.is_generating() && ix == self.messages.len() - 1 {
- return true;
- }
+ // if !self.is_generating() && ix == self.messages.len() - 1 {
+ // return true;
+ // }
- let Some(message) = self.messages.get(ix) else {
- return false;
- };
+ // let Some(message) = self.messages.get(ix) else {
+ // return false;
+ // };
- if message.role != Role::Assistant {
- return false;
- }
+ // if message.role != Role::Assistant {
+ // return false;
+ // }
- self.messages
- .get(ix + 1)
- .and_then(|message| {
- self.message(message.id)
- .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
- })
- .unwrap_or(false)
+ // self.messages
+ // .get(ix + 1)
+ // .and_then(|message| {
+ // self.message(message.id)
+ // .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
+ // })
+ // .unwrap_or(false)
}
pub fn tool_use_limit_reached(&self) -> bool {
self.tool_use_limit_reached
}
- /// Returns whether all of the tool uses have finished running.
- pub fn all_tools_finished(&self) -> bool {
- // If the only pending tool uses left are the ones with errors, then
- // that means that we've finished running all of the pending tools.
- self.tool_use
- .pending_tool_uses()
- .iter()
- .all(|pending_tool_use| pending_tool_use.status.is_error())
- }
-
/// Returns whether any pending tool uses may perform edits
pub fn has_pending_edit_tool_uses(&self) -> bool {
- self.tool_use
- .pending_tool_uses()
- .iter()
- .filter(|pending_tool_use| !pending_tool_use.status.is_error())
- .any(|pending_tool_use| pending_tool_use.may_perform_edits)
- }
-
- pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
- self.tool_use.tool_uses_for_message(id, cx)
- }
-
- pub fn tool_results_for_message(
- &self,
- assistant_message_id: MessageId,
- ) -> Vec<&LanguageModelToolResult> {
- self.tool_use.tool_results_for_message(assistant_message_id)
- }
-
- pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
- self.tool_use.tool_result(id)
- }
-
- pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
- match &self.tool_use.tool_result(id)?.content {
- LanguageModelToolResultContent::Text(text) => Some(text),
- LanguageModelToolResultContent::Image(_) => {
- // TODO: We should display image
- None
- }
- }
- }
-
- pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
- self.tool_use.tool_result_card(id).cloned()
- }
-
- /// Return tools that are both enabled and supported by the model
- pub fn available_tools(
- &self,
- cx: &App,
- model: Arc<dyn LanguageModel>,
- ) -> Vec<LanguageModelRequestTool> {
- if model.supports_tools() {
- resolve_tool_name_conflicts(self.profile.enabled_tools(cx).as_slice())
- .into_iter()
- .filter_map(|(name, tool)| {
- // Skip tools that cannot be supported
- let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
- Some(LanguageModelRequestTool {
- name,
- description: tool.description(),
- input_schema,
- })
- })
- .collect()
- } else {
- Vec::default()
- }
- }
-
- pub fn insert_user_message(
+ todo!()
+ }
+
+ // pub fn insert_user_message(
+ // &mut self,
+ // text: impl Into<String>,
+ // loaded_context: ContextLoadResult,
+ // git_checkpoint: Option<GitStoreCheckpoint>,
+ // creases: Vec<MessageCrease>,
+ // cx: &mut Context<Self>,
+ // ) -> AgentThreadMessageId {
+ // todo!("move this logic into send")
+ // if !loaded_context.referenced_buffers.is_empty() {
+ // self.action_log.update(cx, |log, cx| {
+ // for buffer in loaded_context.referenced_buffers {
+ // log.buffer_read(buffer, cx);
+ // }
+ // });
+ // }
+
+ // let message_id = self.insert_message(
+ // Role::User,
+ // vec![MessageSegment::Text(text.into())],
+ // loaded_context.loaded_context,
+ // creases,
+ // false,
+ // cx,
+ // );
+
+ // if let Some(git_checkpoint) = git_checkpoint {
+ // self.pending_checkpoint = Some(ThreadCheckpoint {
+ // message_id,
+ // git_checkpoint,
+ // });
+ // }
+
+ // self.auto_capture_telemetry(cx);
+
+ // message_id
+ // }
+
+ pub fn set_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
+ todo!()
+ }
+
+ pub fn model(&self) -> Option<ConfiguredModel> {
+ todo!()
+ }
+
+ pub fn send(
&mut self,
- text: impl Into<String>,
- loaded_context: ContextLoadResult,
- git_checkpoint: Option<GitStoreCheckpoint>,
- creases: Vec<MessageCrease>,
+ message: Vec<AgentThreadUserMessageChunk>,
+ window: &mut Window,
cx: &mut Context<Self>,
- ) -> MessageId {
- if !loaded_context.referenced_buffers.is_empty() {
- self.action_log.update(cx, |log, cx| {
- for buffer in loaded_context.referenced_buffers {
- log.buffer_read(buffer, cx);
- }
- });
- }
-
- let message_id = self.insert_message(
- Role::User,
- vec![MessageSegment::Text(text.into())],
- loaded_context.loaded_context,
- creases,
- false,
- cx,
- );
-
- if let Some(git_checkpoint) = git_checkpoint {
- self.pending_checkpoint = Some(ThreadCheckpoint {
- message_id,
- git_checkpoint,
- });
- }
-
- self.auto_capture_telemetry(cx);
-
- message_id
- }
-
- pub fn insert_invisible_continue_message(&mut self, cx: &mut Context<Self>) -> MessageId {
- let id = self.insert_message(
- Role::User,
- vec![MessageSegment::Text("Continue where you left off".into())],
- LoadedContext::default(),
- vec![],
- true,
- cx,
- );
- self.pending_checkpoint = None;
-
- id
- }
-
- pub fn insert_assistant_message(
- &mut self,
- segments: Vec<MessageSegment>,
- cx: &mut Context<Self>,
- ) -> MessageId {
- self.insert_message(
- Role::Assistant,
- segments,
- LoadedContext::default(),
- Vec::new(),
- false,
- cx,
- )
+ ) {
+ todo!()
}
- pub fn insert_message(
- &mut self,
- role: Role,
- segments: Vec<MessageSegment>,
- loaded_context: LoadedContext,
- creases: Vec<MessageCrease>,
- is_hidden: bool,
- cx: &mut Context<Self>,
- ) -> MessageId {
- let id = self.next_message_id.post_inc();
- self.messages.push(Message {
- id,
- role,
- segments,
- loaded_context,
- creases,
- is_hidden,
- ui_only: false,
- });
- self.touch_updated_at();
- cx.emit(ThreadEvent::MessageAdded(id));
- id
+ pub fn resume(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ todo!()
}
- pub fn edit_message(
+ pub fn edit(
&mut self,
- id: MessageId,
- new_role: Role,
- new_segments: Vec<MessageSegment>,
- creases: Vec<MessageCrease>,
- loaded_context: Option<LoadedContext>,
- checkpoint: Option<GitStoreCheckpoint>,
+ message_id: MessageId,
+ message: Vec<AgentThreadUserMessageChunk>,
+ window: &mut Window,
cx: &mut Context<Self>,
- ) -> bool {
- let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
- return false;
- };
- message.role = new_role;
- message.segments = new_segments;
- message.creases = creases;
- if let Some(context) = loaded_context {
- message.loaded_context = context;
- }
- if let Some(git_checkpoint) = checkpoint {
- self.checkpoints_by_message.insert(
- id,
- ThreadCheckpoint {
- message_id: id,
- git_checkpoint,
- },
- );
- }
- self.touch_updated_at();
- cx.emit(ThreadEvent::MessageEdited(id));
- true
- }
-
- pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
- let Some(index) = self.messages.iter().position(|message| message.id == id) else {
- return false;
- };
- self.messages.remove(index);
- self.touch_updated_at();
- cx.emit(ThreadEvent::MessageDeleted(id));
- true
- }
+ ) {
+ todo!()
+ }
+
+ pub fn cancel(&mut self, window: &mut Window, cx: &mut Context<Self>) -> bool {
+ todo!()
+ }
+
+ // pub fn insert_invisible_continue_message(
+ // &mut self,
+ // cx: &mut Context<Self>,
+ // ) -> AgentThreadMessageId {
+ // let id = self.insert_message(
+ // Role::User,
+ // vec![MessageSegment::Text("Continue where you left off".into())],
+ // LoadedContext::default(),
+ // vec![],
+ // true,
+ // cx,
+ // );
+ // self.pending_checkpoint = None;
+
+ // id
+ // }
+
+ // pub fn insert_assistant_message(
+ // &mut self,
+ // segments: Vec<MessageSegment>,
+ // cx: &mut Context<Self>,
+ // ) -> AgentThreadMessageId {
+ // self.insert_message(
+ // Role::Assistant,
+ // segments,
+ // LoadedContext::default(),
+ // Vec::new(),
+ // false,
+ // cx,
+ // )
+ // }
+
+ // pub fn insert_message(
+ // &mut self,
+ // role: Role,
+ // segments: Vec<MessageSegment>,
+ // loaded_context: LoadedContext,
+ // creases: Vec<MessageCrease>,
+ // is_hidden: bool,
+ // cx: &mut Context<Self>,
+ // ) -> AgentThreadMessageId {
+ // let id = self.next_message_id.post_inc();
+ // self.messages.push(Message {
+ // id,
+ // role,
+ // segments,
+ // loaded_context,
+ // creases,
+ // is_hidden,
+ // ui_only: false,
+ // });
+ // self.touch_updated_at();
+ // cx.emit(ThreadEvent::MessageAdded(id));
+ // id
+ // }
+
+ // pub fn edit_message(
+ // &mut self,
+ // id: AgentThreadMessageId,
+ // new_role: Role,
+ // new_segments: Vec<MessageSegment>,
+ // creases: Vec<MessageCrease>,
+ // loaded_context: Option<LoadedContext>,
+ // checkpoint: Option<GitStoreCheckpoint>,
+ // cx: &mut Context<Self>,
+ // ) -> bool {
+ // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
+ // return false;
+ // };
+ // message.role = new_role;
+ // message.segments = new_segments;
+ // message.creases = creases;
+ // if let Some(context) = loaded_context {
+ // message.loaded_context = context;
+ // }
+ // if let Some(git_checkpoint) = checkpoint {
+ // self.checkpoints_by_message.insert(
+ // id,
+ // ThreadCheckpoint {
+ // message_id: id,
+ // git_checkpoint,
+ // },
+ // );
+ // }
+ // self.touch_updated_at();
+ // cx.emit(ThreadEvent::MessageEdited(id));
+ // true
+ // }
/// Returns the representation of this [`Thread`] in a textual form.
///
@@ -1,1449 +0,0 @@
-use crate::{
- AgentThread, AgentThreadId, AgentThreadMessageId, AgentThreadUserMessageChunk,
- agent_profile::AgentProfile,
- context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext},
- thread_store::{SharedProjectContext, ThreadStore},
-};
-use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
-use anyhow::{Result, anyhow};
-use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
-use chrono::{DateTime, Utc};
-use client::{ModelRequestUsage, RequestUsage};
-use collections::{HashMap, HashSet};
-use feature_flags::{self, FeatureFlagAppExt};
-use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
-use git::repository::DiffType;
-use gpui::{
- AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task,
- WeakEntity,
-};
-use language_model::{
- ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
- ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
-};
-use postage::stream::Stream as _;
-use project::{
- Project,
- git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
-};
-use prompt_store::{ModelContext, PromptBuilder};
-use proto::Plan;
-use serde::{Deserialize, Serialize};
-use settings::Settings;
-use std::{
- io::Write,
- ops::Range,
- sync::Arc,
- time::{Duration, Instant},
-};
-use thiserror::Error;
-use util::{ResultExt as _, post_inc};
-use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
-
-/// Stored information that can be used to resurrect a context crease when creating an editor for a past message.
-#[derive(Clone, Debug)]
-pub struct MessageCrease {
- pub range: Range<usize>,
- pub icon_path: SharedString,
- pub label: SharedString,
- /// None for a deserialized message, Some otherwise.
- pub context: Option<AgentContextHandle>,
-}
-
-pub enum MessageTool {
- Pending {
- tool: Arc<dyn Tool>,
- input: serde_json::Value,
- },
- NeedsConfirmation {
- tool: Arc<dyn Tool>,
- input_json: serde_json::Value,
- confirm_tx: oneshot::Sender<bool>,
- },
- Confirmed {
- card: AnyToolCard,
- },
- Declined {
- tool: Arc<dyn Tool>,
- input_json: serde_json::Value,
- },
-}
-
-/// A message in a [`Thread`].
-pub struct Message {
- pub id: AgentThreadMessageId,
- pub role: Role,
- pub thinking: String,
- pub text: String,
- pub tools: Vec<MessageTool>,
- pub loaded_context: LoadedContext,
- pub creases: Vec<MessageCrease>,
- pub is_hidden: bool,
- pub ui_only: bool,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct ProjectSnapshot {
- pub worktree_snapshots: Vec<WorktreeSnapshot>,
- pub unsaved_buffer_paths: Vec<String>,
- pub timestamp: DateTime<Utc>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct WorktreeSnapshot {
- pub worktree_path: String,
- pub git_state: Option<GitState>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct GitState {
- pub remote_url: Option<String>,
- pub head_sha: Option<String>,
- pub current_branch: Option<String>,
- pub diff: Option<String>,
-}
-
-#[derive(Clone, Debug)]
-pub struct ThreadCheckpoint {
- message_id: AgentThreadMessageId,
- git_checkpoint: GitStoreCheckpoint,
-}
-
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-pub enum ThreadFeedback {
- Positive,
- Negative,
-}
-
-pub enum LastRestoreCheckpoint {
- Pending {
- message_id: AgentThreadMessageId,
- },
- Error {
- message_id: AgentThreadMessageId,
- error: String,
- },
-}
-
-impl LastRestoreCheckpoint {
- pub fn message_id(&self) -> AgentThreadMessageId {
- match self {
- LastRestoreCheckpoint::Pending { message_id } => *message_id,
- LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
- }
- }
-}
-
-#[derive(Clone, Debug, Default)]
-pub enum DetailedSummaryState {
- #[default]
- NotGenerated,
- Generating {
- message_id: AgentThreadMessageId,
- },
- Generated {
- text: SharedString,
- message_id: AgentThreadMessageId,
- },
-}
-
-impl DetailedSummaryState {
- fn text(&self) -> Option<SharedString> {
- if let Self::Generated { text, .. } = self {
- Some(text.clone())
- } else {
- None
- }
- }
-}
-
-#[derive(Default, Debug)]
-pub struct TotalTokenUsage {
- pub total: u64,
- pub max: u64,
-}
-
-impl TotalTokenUsage {
- pub fn ratio(&self) -> TokenUsageRatio {
- #[cfg(debug_assertions)]
- let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
- .unwrap_or("0.8".to_string())
- .parse()
- .unwrap();
- #[cfg(not(debug_assertions))]
- let warning_threshold: f32 = 0.8;
-
- // When the maximum is unknown because there is no selected model,
- // avoid showing the token limit warning.
- if self.max == 0 {
- TokenUsageRatio::Normal
- } else if self.total >= self.max {
- TokenUsageRatio::Exceeded
- } else if self.total as f32 / self.max as f32 >= warning_threshold {
- TokenUsageRatio::Warning
- } else {
- TokenUsageRatio::Normal
- }
- }
-
- pub fn add(&self, tokens: u64) -> TotalTokenUsage {
- TotalTokenUsage {
- total: self.total + tokens,
- max: self.max,
- }
- }
-}
-
-#[derive(Debug, Default, PartialEq, Eq)]
-pub enum TokenUsageRatio {
- #[default]
- Normal,
- Warning,
- Exceeded,
-}
-
-#[derive(Debug, Clone, Copy)]
-pub enum QueueState {
- Sending,
- Queued { position: usize },
- Started,
-}
-
-/// A thread of conversation with the LLM.
-pub struct Thread {
- agent_thread: Arc<dyn AgentThread>,
- summary: ThreadSummary,
- pending_send: Option<Task<Result<()>>>,
- pending_summary: Task<Option<()>>,
- detailed_summary_task: Task<Option<()>>,
- detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
- detailed_summary_rx: postage::watch::Receiver<DetailedSummaryState>,
- completion_mode: agent_settings::CompletionMode,
- messages: Vec<Message>,
- checkpoints_by_message: HashMap<AgentThreadMessageId, ThreadCheckpoint>,
- project: Entity<Project>,
- action_log: Entity<ActionLog>,
- last_restore_checkpoint: Option<LastRestoreCheckpoint>,
- pending_checkpoint: Option<ThreadCheckpoint>,
- initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
- request_token_usage: Vec<TokenUsage>,
- cumulative_token_usage: TokenUsage,
- exceeded_window_error: Option<ExceededWindowError>,
- tool_use_limit_reached: bool,
- // todo!(keep track of retries from the underlying agent)
- feedback: Option<ThreadFeedback>,
- message_feedback: HashMap<AgentThreadMessageId, ThreadFeedback>,
- last_auto_capture_at: Option<Instant>,
- last_received_chunk_at: Option<Instant>,
-}
-
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum ThreadSummary {
- Pending,
- Generating,
- Ready(SharedString),
- Error,
-}
-
-impl ThreadSummary {
- pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
-
- pub fn or_default(&self) -> SharedString {
- self.unwrap_or(Self::DEFAULT)
- }
-
- pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
- self.ready().unwrap_or_else(|| message.into())
- }
-
- pub fn ready(&self) -> Option<SharedString> {
- match self {
- ThreadSummary::Ready(summary) => Some(summary.clone()),
- ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
- }
- }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-pub struct ExceededWindowError {
- /// Model used when last message exceeded context window
- model_id: LanguageModelId,
- /// Token count including last message
- token_count: u64,
-}
-
-impl Thread {
- pub fn load(
- agent_thread: Arc<dyn AgentThread>,
- project: Entity<Project>,
- cx: &mut Context<Self>,
- ) -> Self {
- let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
- Self {
- agent_thread,
- summary: ThreadSummary::Pending,
- pending_send: None,
- pending_summary: Task::ready(None),
- detailed_summary_task: Task::ready(None),
- detailed_summary_tx,
- detailed_summary_rx,
- completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
- messages: todo!("read from agent"),
- checkpoints_by_message: HashMap::default(),
- project: project.clone(),
- last_restore_checkpoint: None,
- pending_checkpoint: None,
- action_log: cx.new(|_| ActionLog::new(project.clone())),
- initial_project_snapshot: {
- let project_snapshot = Self::project_snapshot(project, cx);
- cx.foreground_executor()
- .spawn(async move { Some(project_snapshot.await) })
- .shared()
- },
- request_token_usage: Vec::new(),
- cumulative_token_usage: TokenUsage::default(),
- exceeded_window_error: None,
- tool_use_limit_reached: false,
- feedback: None,
- message_feedback: HashMap::default(),
- last_auto_capture_at: None,
- last_received_chunk_at: None,
- }
- }
-
- pub fn id(&self) -> AgentThreadId {
- self.agent_thread.id()
- }
-
- pub fn profile(&self) -> &AgentProfile {
- todo!()
- }
-
- pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
- todo!()
- // if &id != self.profile.id() {
- // self.profile = AgentProfile::new(id, self.tools.clone());
- // cx.emit(ThreadEvent::ProfileChanged);
- // }
- }
-
- pub fn is_empty(&self) -> bool {
- self.messages.is_empty()
- }
-
- pub fn advance_prompt_id(&mut self) {
- todo!()
- // self.last_prompt_id = PromptId::new();
- }
-
- pub fn project_context(&self) -> SharedProjectContext {
- todo!()
- // self.project_context.clone()
- }
-
- pub fn summary(&self) -> &ThreadSummary {
- &self.summary
- }
-
- pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
- todo!()
- // let current_summary = match &self.summary {
- // ThreadSummary::Pending | ThreadSummary::Generating => return,
- // ThreadSummary::Ready(summary) => summary,
- // ThreadSummary::Error => &ThreadSummary::DEFAULT,
- // };
-
- // let mut new_summary = new_summary.into();
-
- // if new_summary.is_empty() {
- // new_summary = ThreadSummary::DEFAULT;
- // }
-
- // if current_summary != &new_summary {
- // self.summary = ThreadSummary::Ready(new_summary);
- // cx.emit(ThreadEvent::SummaryChanged);
- // }
- }
-
- pub fn completion_mode(&self) -> CompletionMode {
- self.completion_mode
- }
-
- pub fn set_completion_mode(&mut self, mode: CompletionMode) {
- self.completion_mode = mode;
- }
-
- pub fn message(&self, id: AgentThreadMessageId) -> Option<&Message> {
- let index = self
- .messages
- .binary_search_by(|message| message.id.cmp(&id))
- .ok()?;
-
- self.messages.get(index)
- }
-
- pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
- self.messages.iter()
- }
-
- pub fn is_generating(&self) -> bool {
- self.pending_send.is_some()
- }
-
- /// Indicates whether streaming of language model events is stale.
- /// When `is_generating()` is false, this method returns `None`.
- pub fn is_generation_stale(&self) -> Option<bool> {
- const STALE_THRESHOLD: u128 = 250;
-
- self.last_received_chunk_at
- .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD)
- }
-
- fn received_chunk(&mut self) {
- self.last_received_chunk_at = Some(Instant::now());
- }
-
- pub fn checkpoint_for_message(&self, id: AgentThreadMessageId) -> Option<ThreadCheckpoint> {
- self.checkpoints_by_message.get(&id).cloned()
- }
-
- pub fn restore_checkpoint(
- &mut self,
- checkpoint: ThreadCheckpoint,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
- message_id: checkpoint.message_id,
- });
- cx.emit(ThreadEvent::CheckpointChanged);
- cx.notify();
-
- let git_store = self.project().read(cx).git_store().clone();
- let restore = git_store.update(cx, |git_store, cx| {
- git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
- });
-
- cx.spawn(async move |this, cx| {
- let result = restore.await;
- this.update(cx, |this, cx| {
- if let Err(err) = result.as_ref() {
- this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
- message_id: checkpoint.message_id,
- error: err.to_string(),
- });
- } else {
- this.truncate(checkpoint.message_id, cx);
- this.last_restore_checkpoint = None;
- }
- this.pending_checkpoint = None;
- cx.emit(ThreadEvent::CheckpointChanged);
- cx.notify();
- })?;
- result
- })
- }
-
- fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
- let pending_checkpoint = if self.is_generating() {
- return;
- } else if let Some(checkpoint) = self.pending_checkpoint.take() {
- checkpoint
- } else {
- return;
- };
-
- self.finalize_checkpoint(pending_checkpoint, cx);
- }
-
- fn finalize_checkpoint(
- &mut self,
- pending_checkpoint: ThreadCheckpoint,
- cx: &mut Context<Self>,
- ) {
- let git_store = self.project.read(cx).git_store().clone();
- let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
- cx.spawn(async move |this, cx| match final_checkpoint.await {
- Ok(final_checkpoint) => {
- let equal = git_store
- .update(cx, |store, cx| {
- store.compare_checkpoints(
- pending_checkpoint.git_checkpoint.clone(),
- final_checkpoint.clone(),
- cx,
- )
- })?
- .await
- .unwrap_or(false);
-
- if !equal {
- this.update(cx, |this, cx| {
- this.insert_checkpoint(pending_checkpoint, cx)
- })?;
- }
-
- Ok(())
- }
- Err(_) => this.update(cx, |this, cx| {
- this.insert_checkpoint(pending_checkpoint, cx)
- }),
- })
- .detach();
- }
-
- fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
- self.checkpoints_by_message
- .insert(checkpoint.message_id, checkpoint);
- cx.emit(ThreadEvent::CheckpointChanged);
- cx.notify();
- }
-
- pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
- self.last_restore_checkpoint.as_ref()
- }
-
- pub fn truncate(&mut self, message_id: AgentThreadMessageId, cx: &mut Context<Self>) {
- todo!("call truncate on the agent");
- let Some(message_ix) = self
- .messages
- .iter()
- .rposition(|message| message.id == message_id)
- else {
- return;
- };
- for deleted_message in self.messages.drain(message_ix..) {
- self.checkpoints_by_message.remove(&deleted_message.id);
- }
- cx.notify();
- }
-
- pub fn is_turn_end(&self, ix: usize) -> bool {
- todo!()
- // if self.messages.is_empty() {
- // return false;
- // }
-
- // if !self.is_generating() && ix == self.messages.len() - 1 {
- // return true;
- // }
-
- // let Some(message) = self.messages.get(ix) else {
- // return false;
- // };
-
- // if message.role != Role::Assistant {
- // return false;
- // }
-
- // self.messages
- // .get(ix + 1)
- // .and_then(|message| {
- // self.message(message.id)
- // .map(|next_message| next_message.role == Role::User && !next_message.is_hidden)
- // })
- // .unwrap_or(false)
- }
-
- pub fn tool_use_limit_reached(&self) -> bool {
- self.tool_use_limit_reached
- }
-
- /// Returns whether any pending tool uses may perform edits
- pub fn has_pending_edit_tool_uses(&self) -> bool {
- todo!()
- }
-
- // pub fn insert_user_message(
- // &mut self,
- // text: impl Into<String>,
- // loaded_context: ContextLoadResult,
- // git_checkpoint: Option<GitStoreCheckpoint>,
- // creases: Vec<MessageCrease>,
- // cx: &mut Context<Self>,
- // ) -> AgentThreadMessageId {
- // todo!("move this logic into send")
- // if !loaded_context.referenced_buffers.is_empty() {
- // self.action_log.update(cx, |log, cx| {
- // for buffer in loaded_context.referenced_buffers {
- // log.buffer_read(buffer, cx);
- // }
- // });
- // }
-
- // let message_id = self.insert_message(
- // Role::User,
- // vec![MessageSegment::Text(text.into())],
- // loaded_context.loaded_context,
- // creases,
- // false,
- // cx,
- // );
-
- // if let Some(git_checkpoint) = git_checkpoint {
- // self.pending_checkpoint = Some(ThreadCheckpoint {
- // message_id,
- // git_checkpoint,
- // });
- // }
-
- // self.auto_capture_telemetry(cx);
-
- // message_id
- // }
-
- pub fn send(&mut self, message: Vec<AgentThreadUserMessageChunk>, cx: &mut Context<Self>) {}
-
- pub fn resume(&mut self, cx: &mut Context<Self>) {
- todo!()
- }
-
- pub fn edit(
- &mut self,
- message_id: AgentThreadMessageId,
- message: Vec<AgentThreadUserMessageChunk>,
- cx: &mut Context<Self>,
- ) {
- todo!()
- }
-
- pub fn cancel(&mut self, cx: &mut Context<Self>) {
- todo!()
- }
-
- // pub fn insert_invisible_continue_message(
- // &mut self,
- // cx: &mut Context<Self>,
- // ) -> AgentThreadMessageId {
- // let id = self.insert_message(
- // Role::User,
- // vec![MessageSegment::Text("Continue where you left off".into())],
- // LoadedContext::default(),
- // vec![],
- // true,
- // cx,
- // );
- // self.pending_checkpoint = None;
-
- // id
- // }
-
- // pub fn insert_assistant_message(
- // &mut self,
- // segments: Vec<MessageSegment>,
- // cx: &mut Context<Self>,
- // ) -> AgentThreadMessageId {
- // self.insert_message(
- // Role::Assistant,
- // segments,
- // LoadedContext::default(),
- // Vec::new(),
- // false,
- // cx,
- // )
- // }
-
- // pub fn insert_message(
- // &mut self,
- // role: Role,
- // segments: Vec<MessageSegment>,
- // loaded_context: LoadedContext,
- // creases: Vec<MessageCrease>,
- // is_hidden: bool,
- // cx: &mut Context<Self>,
- // ) -> AgentThreadMessageId {
- // let id = self.next_message_id.post_inc();
- // self.messages.push(Message {
- // id,
- // role,
- // segments,
- // loaded_context,
- // creases,
- // is_hidden,
- // ui_only: false,
- // });
- // self.touch_updated_at();
- // cx.emit(ThreadEvent::MessageAdded(id));
- // id
- // }
-
- // pub fn edit_message(
- // &mut self,
- // id: AgentThreadMessageId,
- // new_role: Role,
- // new_segments: Vec<MessageSegment>,
- // creases: Vec<MessageCrease>,
- // loaded_context: Option<LoadedContext>,
- // checkpoint: Option<GitStoreCheckpoint>,
- // cx: &mut Context<Self>,
- // ) -> bool {
- // let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
- // return false;
- // };
- // message.role = new_role;
- // message.segments = new_segments;
- // message.creases = creases;
- // if let Some(context) = loaded_context {
- // message.loaded_context = context;
- // }
- // if let Some(git_checkpoint) = checkpoint {
- // self.checkpoints_by_message.insert(
- // id,
- // ThreadCheckpoint {
- // message_id: id,
- // git_checkpoint,
- // },
- // );
- // }
- // self.touch_updated_at();
- // cx.emit(ThreadEvent::MessageEdited(id));
- // true
- // }
-
- /// Returns the representation of this [`Thread`] in a textual form.
- ///
- /// This is the representation we use when attaching a thread as context to another thread.
- pub fn text(&self) -> String {
- let mut text = String::new();
-
- for message in &self.messages {
- text.push_str(match message.role {
- language_model::Role::User => "User:",
- language_model::Role::Assistant => "Agent:",
- language_model::Role::System => "System:",
- });
- text.push('\n');
-
- text.push_str("<think>");
- text.push_str(&message.thinking);
- text.push_str("</think>");
- text.push_str(&message.text);
-
- // todo!('what about tools?');
-
- text.push('\n');
- }
-
- text
- }
-
- pub fn used_tools_since_last_user_message(&self) -> bool {
- todo!()
- // for message in self.messages.iter().rev() {
- // if self.tool_use.message_has_tool_results(message.id) {
- // return true;
- // } else if message.role == Role::User {
- // return false;
- // }
- // }
-
- // false
- }
-
- pub fn start_generating_detailed_summary_if_needed(
- &mut self,
- thread_store: WeakEntity<ThreadStore>,
- cx: &mut Context<Self>,
- ) {
- let Some(last_message_id) = self.messages.last().map(|message| message.id) else {
- return;
- };
-
- match &*self.detailed_summary_rx.borrow() {
- DetailedSummaryState::Generating { message_id, .. }
- | DetailedSummaryState::Generated { message_id, .. }
- if *message_id == last_message_id =>
- {
- // Already up-to-date
- return;
- }
- _ => {}
- }
-
- let summary = self.agent_thread.summary();
-
- *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
- message_id: last_message_id,
- };
-
- // Replace the detailed summarization task if there is one, cancelling it. It would probably
- // be better to allow the old task to complete, but this would require logic for choosing
- // which result to prefer (the old task could complete after the new one, resulting in a
- // stale summary).
- self.detailed_summary_task = cx.spawn(async move |thread, cx| {
- let Some(summary) = summary.await.log_err() else {
- thread
- .update(cx, |thread, _cx| {
- *thread.detailed_summary_tx.borrow_mut() =
- DetailedSummaryState::NotGenerated;
- })
- .ok()?;
- return None;
- };
-
- thread
- .update(cx, |thread, _cx| {
- *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated {
- text: summary.into(),
- message_id: last_message_id,
- };
- })
- .ok()?;
-
- Some(())
- });
- }
-
- pub async fn wait_for_detailed_summary_or_text(
- this: &Entity<Self>,
- cx: &mut AsyncApp,
- ) -> Option<SharedString> {
- let mut detailed_summary_rx = this
- .read_with(cx, |this, _cx| this.detailed_summary_rx.clone())
- .ok()?;
- loop {
- match detailed_summary_rx.recv().await? {
- DetailedSummaryState::Generating { .. } => {}
- DetailedSummaryState::NotGenerated => {
- return this.read_with(cx, |this, _cx| this.text().into()).ok();
- }
- DetailedSummaryState::Generated { text, .. } => return Some(text),
- }
- }
- }
-
- pub fn latest_detailed_summary_or_text(&self) -> SharedString {
- self.detailed_summary_rx
- .borrow()
- .text()
- .unwrap_or_else(|| self.text().into())
- }
-
- pub fn is_generating_detailed_summary(&self) -> bool {
- matches!(
- &*self.detailed_summary_rx.borrow(),
- DetailedSummaryState::Generating { .. }
- )
- }
-
- pub fn feedback(&self) -> Option<ThreadFeedback> {
- self.feedback
- }
-
- pub fn message_feedback(&self, message_id: AgentThreadMessageId) -> Option<ThreadFeedback> {
- self.message_feedback.get(&message_id).copied()
- }
-
- pub fn report_message_feedback(
- &mut self,
- message_id: AgentThreadMessageId,
- feedback: ThreadFeedback,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- todo!()
- // if self.message_feedback.get(&message_id) == Some(&feedback) {
- // return Task::ready(Ok(()));
- // }
-
- // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
- // let serialized_thread = self.serialize(cx);
- // let thread_id = self.id().clone();
- // let client = self.project.read(cx).client();
-
- // let enabled_tool_names: Vec<String> = self
- // .profile
- // .enabled_tools(cx)
- // .iter()
- // .map(|tool| tool.name())
- // .collect();
-
- // self.message_feedback.insert(message_id, feedback);
-
- // cx.notify();
-
- // let message_content = self
- // .message(message_id)
- // .map(|msg| msg.to_string())
- // .unwrap_or_default();
-
- // cx.background_spawn(async move {
- // let final_project_snapshot = final_project_snapshot.await;
- // let serialized_thread = serialized_thread.await?;
- // let thread_data =
- // serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
-
- // let rating = match feedback {
- // ThreadFeedback::Positive => "positive",
- // ThreadFeedback::Negative => "negative",
- // };
- // telemetry::event!(
- // "Assistant Thread Rated",
- // rating,
- // thread_id,
- // enabled_tool_names,
- // message_id = message_id,
- // message_content,
- // thread_data,
- // final_project_snapshot
- // );
- // client.telemetry().flush_events().await;
-
- // Ok(())
- // })
- }
-
- pub fn report_feedback(
- &mut self,
- feedback: ThreadFeedback,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- todo!()
- // let last_assistant_message_id = self
- // .messages
- // .iter()
- // .rev()
- // .find(|msg| msg.role == Role::Assistant)
- // .map(|msg| msg.id);
-
- // if let Some(message_id) = last_assistant_message_id {
- // self.report_message_feedback(message_id, feedback, cx)
- // } else {
- // let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
- // let serialized_thread = self.serialize(cx);
- // let thread_id = self.id().clone();
- // let client = self.project.read(cx).client();
- // self.feedback = Some(feedback);
- // cx.notify();
-
- // cx.background_spawn(async move {
- // let final_project_snapshot = final_project_snapshot.await;
- // let serialized_thread = serialized_thread.await?;
- // let thread_data = serde_json::to_value(serialized_thread)
- // .unwrap_or_else(|_| serde_json::Value::Null);
-
- // let rating = match feedback {
- // ThreadFeedback::Positive => "positive",
- // ThreadFeedback::Negative => "negative",
- // };
- // telemetry::event!(
- // "Assistant Thread Rated",
- // rating,
- // thread_id,
- // thread_data,
- // final_project_snapshot
- // );
- // client.telemetry().flush_events().await;
-
- // Ok(())
- // })
- // }
- }
-
- /// Create a snapshot of the current project state including git information and unsaved buffers.
- fn project_snapshot(
- project: Entity<Project>,
- cx: &mut Context<Self>,
- ) -> Task<Arc<ProjectSnapshot>> {
- let git_store = project.read(cx).git_store().clone();
- let worktree_snapshots: Vec<_> = project
- .read(cx)
- .visible_worktrees(cx)
- .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
- .collect();
-
- cx.spawn(async move |_, cx| {
- let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
-
- let mut unsaved_buffers = Vec::new();
- cx.update(|app_cx| {
- let buffer_store = project.read(app_cx).buffer_store();
- for buffer_handle in buffer_store.read(app_cx).buffers() {
- let buffer = buffer_handle.read(app_cx);
- if buffer.is_dirty() {
- if let Some(file) = buffer.file() {
- let path = file.path().to_string_lossy().to_string();
- unsaved_buffers.push(path);
- }
- }
- }
- })
- .ok();
-
- Arc::new(ProjectSnapshot {
- worktree_snapshots,
- unsaved_buffer_paths: unsaved_buffers,
- timestamp: Utc::now(),
- })
- })
- }
-
- fn worktree_snapshot(
- worktree: Entity<project::Worktree>,
- git_store: Entity<GitStore>,
- cx: &App,
- ) -> Task<WorktreeSnapshot> {
- cx.spawn(async move |cx| {
- // Get worktree path and snapshot
- let worktree_info = cx.update(|app_cx| {
- let worktree = worktree.read(app_cx);
- let path = worktree.abs_path().to_string_lossy().to_string();
- let snapshot = worktree.snapshot();
- (path, snapshot)
- });
-
- let Ok((worktree_path, _snapshot)) = worktree_info else {
- return WorktreeSnapshot {
- worktree_path: String::new(),
- git_state: None,
- };
- };
-
- let git_state = git_store
- .update(cx, |git_store, cx| {
- git_store
- .repositories()
- .values()
- .find(|repo| {
- repo.read(cx)
- .abs_path_to_repo_path(&worktree.read(cx).abs_path())
- .is_some()
- })
- .cloned()
- })
- .ok()
- .flatten()
- .map(|repo| {
- repo.update(cx, |repo, _| {
- let current_branch =
- repo.branch.as_ref().map(|branch| branch.name().to_owned());
- repo.send_job(None, |state, _| async move {
- let RepositoryState::Local { backend, .. } = state else {
- return GitState {
- remote_url: None,
- head_sha: None,
- current_branch,
- diff: None,
- };
- };
-
- let remote_url = backend.remote_url("origin");
- let head_sha = backend.head_sha().await;
- let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
-
- GitState {
- remote_url,
- head_sha,
- current_branch,
- diff,
- }
- })
- })
- });
-
- let git_state = match git_state {
- Some(git_state) => match git_state.ok() {
- Some(git_state) => git_state.await.ok(),
- None => None,
- },
- None => None,
- };
-
- WorktreeSnapshot {
- worktree_path,
- git_state,
- }
- })
- }
-
- pub fn to_markdown(&self, cx: &App) -> Result<String> {
- todo!()
- // let mut markdown = Vec::new();
-
- // let summary = self.summary().or_default();
- // writeln!(markdown, "# {summary}\n")?;
-
- // for message in self.messages() {
- // writeln!(
- // markdown,
- // "## {role}\n",
- // role = match message.role {
- // Role::User => "User",
- // Role::Assistant => "Agent",
- // Role::System => "System",
- // }
- // )?;
-
- // if !message.loaded_context.text.is_empty() {
- // writeln!(markdown, "{}", message.loaded_context.text)?;
- // }
-
- // if !message.loaded_context.images.is_empty() {
- // writeln!(
- // markdown,
- // "\n{} images attached as context.\n",
- // message.loaded_context.images.len()
- // )?;
- // }
-
- // for segment in &message.segments {
- // match segment {
- // MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
- // MessageSegment::Thinking { text, .. } => {
- // writeln!(markdown, "<think>\n{}\n</think>\n", text)?
- // }
- // MessageSegment::RedactedThinking(_) => {}
- // }
- // }
-
- // for tool_use in self.tool_uses_for_message(message.id, cx) {
- // writeln!(
- // markdown,
- // "**Use Tool: {} ({})**",
- // tool_use.name, tool_use.id
- // )?;
- // writeln!(markdown, "```json")?;
- // writeln!(
- // markdown,
- // "{}",
- // serde_json::to_string_pretty(&tool_use.input)?
- // )?;
- // writeln!(markdown, "```")?;
- // }
-
- // for tool_result in self.tool_results_for_message(message.id) {
- // write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
- // if tool_result.is_error {
- // write!(markdown, " (Error)")?;
- // }
-
- // writeln!(markdown, "**\n")?;
- // match &tool_result.content {
- // LanguageModelToolResultContent::Text(text) => {
- // writeln!(markdown, "{text}")?;
- // }
- // LanguageModelToolResultContent::Image(image) => {
- // writeln!(markdown, "", image.source)?;
- // }
- // }
-
- // if let Some(output) = tool_result.output.as_ref() {
- // writeln!(
- // markdown,
- // "\n\nDebug Output:\n\n```json\n{}\n```\n",
- // serde_json::to_string_pretty(output)?
- // )?;
- // }
- // }
- // }
-
- // Ok(String::from_utf8_lossy(&markdown).to_string())
- }
-
- pub fn keep_edits_in_range(
- &mut self,
- buffer: Entity<language::Buffer>,
- buffer_range: Range<language::Anchor>,
- cx: &mut Context<Self>,
- ) {
- self.action_log.update(cx, |action_log, cx| {
- action_log.keep_edits_in_range(buffer, buffer_range, cx)
- });
- }
-
- pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
- self.action_log
- .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
- }
-
- pub fn reject_edits_in_ranges(
- &mut self,
- buffer: Entity<language::Buffer>,
- buffer_ranges: Vec<Range<language::Anchor>>,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- self.action_log.update(cx, |action_log, cx| {
- action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
- })
- }
-
- pub fn action_log(&self) -> &Entity<ActionLog> {
- &self.action_log
- }
-
- pub fn project(&self) -> &Entity<Project> {
- &self.project
- }
-
- pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
- todo!()
- // if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
- // return;
- // }
-
- // let now = Instant::now();
- // if let Some(last) = self.last_auto_capture_at {
- // if now.duration_since(last).as_secs() < 10 {
- // return;
- // }
- // }
-
- // self.last_auto_capture_at = Some(now);
-
- // let thread_id = self.id().clone();
- // let github_login = self
- // .project
- // .read(cx)
- // .user_store()
- // .read(cx)
- // .current_user()
- // .map(|user| user.github_login.clone());
- // let client = self.project.read(cx).client();
- // let serialize_task = self.serialize(cx);
-
- // cx.background_executor()
- // .spawn(async move {
- // if let Ok(serialized_thread) = serialize_task.await {
- // if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
- // telemetry::event!(
- // "Agent Thread Auto-Captured",
- // thread_id = thread_id.to_string(),
- // thread_data = thread_data,
- // auto_capture_reason = "tracked_user",
- // github_login = github_login
- // );
-
- // client.telemetry().flush_events().await;
- // }
- // }
- // })
- // .detach();
- }
-
- pub fn cumulative_token_usage(&self) -> TokenUsage {
- self.cumulative_token_usage
- }
-
- pub fn token_usage_up_to_message(&self, message_id: AgentThreadMessageId) -> TotalTokenUsage {
- todo!()
- // let Some(model) = self.configured_model.as_ref() else {
- // return TotalTokenUsage::default();
- // };
-
- // let max = model.model.max_token_count();
-
- // let index = self
- // .messages
- // .iter()
- // .position(|msg| msg.id == message_id)
- // .unwrap_or(0);
-
- // if index == 0 {
- // return TotalTokenUsage { total: 0, max };
- // }
-
- // let token_usage = &self
- // .request_token_usage
- // .get(index - 1)
- // .cloned()
- // .unwrap_or_default();
-
- // TotalTokenUsage {
- // total: token_usage.total_tokens(),
- // max,
- // }
- }
-
- pub fn total_token_usage(&self) -> Option<TotalTokenUsage> {
- todo!()
- // let model = self.configured_model.as_ref()?;
-
- // let max = model.model.max_token_count();
-
- // if let Some(exceeded_error) = &self.exceeded_window_error {
- // if model.model.id() == exceeded_error.model_id {
- // return Some(TotalTokenUsage {
- // total: exceeded_error.token_count,
- // max,
- // });
- // }
- // }
-
- // let total = self
- // .token_usage_at_last_message()
- // .unwrap_or_default()
- // .total_tokens();
-
- // Some(TotalTokenUsage { total, max })
- }
-
- fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
- self.request_token_usage
- .get(self.messages.len().saturating_sub(1))
- .or_else(|| self.request_token_usage.last())
- .cloned()
- }
-
- fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
- let placeholder = self.token_usage_at_last_message().unwrap_or_default();
- self.request_token_usage
- .resize(self.messages.len(), placeholder);
-
- if let Some(last) = self.request_token_usage.last_mut() {
- *last = token_usage;
- }
- }
-
- fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
- self.project.update(cx, |project, cx| {
- project.user_store().update(cx, |user_store, cx| {
- user_store.update_model_request_usage(
- ModelRequestUsage(RequestUsage {
- amount: amount as i32,
- limit,
- }),
- cx,
- )
- })
- });
- }
-}
-
-#[derive(Debug, Clone, Error)]
-pub enum ThreadError {
- #[error("Payment required")]
- PaymentRequired,
- #[error("Model request limit reached")]
- ModelRequestLimitReached { plan: Plan },
- #[error("Message {header}: {message}")]
- Message {
- header: SharedString,
- message: SharedString,
- },
-}
-
-#[derive(Debug, Clone)]
-pub enum ThreadEvent {
- ShowError(ThreadError),
- StreamedCompletion,
- ReceivedTextChunk,
- NewRequest,
- StreamedAssistantText(AgentThreadMessageId, String),
- StreamedAssistantThinking(AgentThreadMessageId, String),
- StreamedToolUse {
- tool_use_id: LanguageModelToolUseId,
- ui_text: Arc<str>,
- input: serde_json::Value,
- },
- MissingToolUse {
- tool_use_id: LanguageModelToolUseId,
- ui_text: Arc<str>,
- },
- InvalidToolInput {
- tool_use_id: LanguageModelToolUseId,
- ui_text: Arc<str>,
- invalid_input_json: Arc<str>,
- },
- Stopped(Result<StopReason, Arc<anyhow::Error>>),
- MessageAdded(AgentThreadMessageId),
- MessageEdited(AgentThreadMessageId),
- MessageDeleted(AgentThreadMessageId),
- SummaryGenerated,
- SummaryChanged,
- CheckpointChanged,
- ToolConfirmationNeeded,
- ToolUseLimitReached,
- CancelEditing,
- CompletionCanceled,
- ProfileChanged,
- RetriesFailed {
- message: SharedString,
- },
-}
-
-impl EventEmitter<ThreadEvent> for Thread {}
-
-struct PendingCompletion {
- id: usize,
- queue_state: QueueState,
- _task: Task<()>,
-}
-
-/// Resolves tool name conflicts by ensuring all tool names are unique.
-///
-/// When multiple tools have the same name, this function applies the following rules:
-/// 1. Native tools always keep their original name
-/// 2. Context server tools get prefixed with their server ID and an underscore
-/// 3. All tool names are truncated to MAX_TOOL_NAME_LENGTH (64 characters)
-/// 4. If conflicts still exist after prefixing, the conflicting tools are filtered out
-///
-/// Note: This function assumes that built-in tools occur before MCP tools in the tools list.
-fn resolve_tool_name_conflicts(tools: &[Arc<dyn Tool>]) -> Vec<(String, Arc<dyn Tool>)> {
- fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
- let mut tool_name = tool.name();
- tool_name.truncate(MAX_TOOL_NAME_LENGTH);
- tool_name
- }
-
- const MAX_TOOL_NAME_LENGTH: usize = 64;
-
- let mut duplicated_tool_names = HashSet::default();
- let mut seen_tool_names = HashSet::default();
- for tool in tools {
- let tool_name = resolve_tool_name(tool);
- if seen_tool_names.contains(&tool_name) {
- debug_assert!(
- tool.source() != assistant_tool::ToolSource::Native,
- "There are two built-in tools with the same name: {}",
- tool_name
- );
- duplicated_tool_names.insert(tool_name);
- } else {
- seen_tool_names.insert(tool_name);
- }
- }
-
- if duplicated_tool_names.is_empty() {
- return tools
- .into_iter()
- .map(|tool| (resolve_tool_name(tool), tool.clone()))
- .collect();
- }
-
- tools
- .into_iter()
- .filter_map(|tool| {
- let mut tool_name = resolve_tool_name(tool);
- if !duplicated_tool_names.contains(&tool_name) {
- return Some((tool_name, tool.clone()));
- }
- match tool.source() {
- assistant_tool::ToolSource::Native => {
- // Built-in tools always keep their original name
- Some((tool_name, tool.clone()))
- }
- assistant_tool::ToolSource::ContextServer { id } => {
- // Context server tools are prefixed with the context server ID, and truncated if necessary
- tool_name.insert(0, '_');
- if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
- let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
- let mut id = id.to_string();
- id.truncate(len);
- tool_name.insert_str(0, &id);
- } else {
- tool_name.insert_str(0, &id);
- }
-
- tool_name.truncate(MAX_TOOL_NAME_LENGTH);
-
- if seen_tool_names.contains(&tool_name) {
- log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
- None
- } else {
- Some((tool_name, tool.clone()))
- }
- }
- }
- })
- .collect()
-}
@@ -1,8 +1,7 @@
use crate::{
+ MessageId, ThreadId,
context_server_tool::ContextServerTool,
- thread::{
- DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
- },
+ thread::{DetailedSummaryState, ExceededWindowError, ProjectSnapshot, Thread},
};
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
@@ -400,35 +399,17 @@ impl ThreadStore {
self.threads.iter()
}
- pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
- cx.new(|cx| {
- Thread::new(
- self.project.clone(),
- self.tools.clone(),
- self.prompt_builder.clone(),
- self.project_context.clone(),
- cx,
- )
- })
- }
-
- pub fn create_thread_from_serialized(
- &mut self,
- serialized: SerializedThread,
- cx: &mut Context<Self>,
- ) -> Entity<Thread> {
- cx.new(|cx| {
- Thread::deserialize(
- ThreadId::new(),
- serialized,
- self.project.clone(),
- self.tools.clone(),
- self.prompt_builder.clone(),
- self.project_context.clone(),
- None,
- cx,
- )
- })
+ pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
+ todo!()
+ // cx.new(|cx| {
+ // Thread::new(
+ // self.project.clone(),
+ // self.tools.clone(),
+ // self.prompt_builder.clone(),
+ // self.project_context.clone(),
+ // cx,
+ // )
+ // })
}
pub fn open_thread(
@@ -447,51 +428,53 @@ impl ThreadStore {
.await?
.with_context(|| format!("no thread found with ID: {id:?}"))?;
- let thread = this.update_in(cx, |this, window, cx| {
- cx.new(|cx| {
- Thread::deserialize(
- id.clone(),
- thread,
- this.project.clone(),
- this.tools.clone(),
- this.prompt_builder.clone(),
- this.project_context.clone(),
- Some(window),
- cx,
- )
- })
- })?;
-
- Ok(thread)
+ todo!();
+ // let thread = this.update_in(cx, |this, window, cx| {
+ // cx.new(|cx| {
+ // Thread::deserialize(
+ // id.clone(),
+ // thread,
+ // this.project.clone(),
+ // this.tools.clone(),
+ // this.prompt_builder.clone(),
+ // this.project_context.clone(),
+ // Some(window),
+ // cx,
+ // )
+ // })
+ // })?;
+ // Ok(thread)
})
}
pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
- let (metadata, serialized_thread) =
- thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
-
- let database_future = ThreadsDatabase::global_future(cx);
- cx.spawn(async move |this, cx| {
- let serialized_thread = serialized_thread.await?;
- let database = database_future.await.map_err(|err| anyhow!(err))?;
- database.save_thread(metadata, serialized_thread).await?;
-
- this.update(cx, |this, cx| this.reload(cx))?.await
- })
+ todo!()
+ // let (metadata, serialized_thread) =
+ // thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
+
+ // let database_future = ThreadsDatabase::global_future(cx);
+ // cx.spawn(async move |this, cx| {
+ // let serialized_thread = serialized_thread.await?;
+ // let database = database_future.await.map_err(|err| anyhow!(err))?;
+ // database.save_thread(metadata, serialized_thread).await?;
+
+ // this.update(cx, |this, cx| this.reload(cx))?.await
+ // })
}
pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
- let id = id.clone();
- let database_future = ThreadsDatabase::global_future(cx);
- cx.spawn(async move |this, cx| {
- let database = database_future.await.map_err(|err| anyhow!(err))?;
- database.delete_thread(id.clone()).await?;
-
- this.update(cx, |this, cx| {
- this.threads.retain(|thread| thread.id != id);
- cx.notify();
- })
- })
+ todo!()
+ // let id = id.clone();
+ // let database_future = ThreadsDatabase::global_future(cx);
+ // cx.spawn(async move |this, cx| {
+ // let database = database_future.await.map_err(|err| anyhow!(err))?;
+ // database.delete_thread(id.clone()).await?;
+
+ // this.update(cx, |this, cx| {
+ // this.threads.retain(|thread| thread.id != id);
+ // cx.notify();
+ // })
+ // })
}
pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
@@ -1067,7 +1050,7 @@ impl ThreadsDatabase {
#[cfg(test)]
mod tests {
use super::*;
- use crate::thread::{DetailedSummaryState, MessageId};
+ use crate::{MessageId, thread::DetailedSummaryState};
use chrono::Utc;
use language_model::{Role, TokenUsage};
use pretty_assertions::assert_eq;
@@ -1,567 +0,0 @@
-use crate::{
- thread::{MessageId, PromptId, ThreadId},
- thread_store::SerializedMessage,
-};
-use anyhow::Result;
-use assistant_tool::{
- AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
-};
-use collections::HashMap;
-use futures::{FutureExt as _, future::Shared};
-use gpui::{App, Entity, SharedString, Task, Window};
-use icons::IconName;
-use language_model::{
- ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
- LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
-};
-use project::Project;
-use std::sync::Arc;
-use util::truncate_lines_to_byte_limit;
-
-#[derive(Debug)]
-pub struct ToolUse {
- pub id: LanguageModelToolUseId,
- pub name: SharedString,
- pub ui_text: SharedString,
- pub status: ToolUseStatus,
- pub input: serde_json::Value,
- pub icon: icons::IconName,
- pub needs_confirmation: bool,
-}
-
-pub struct ToolUseState {
- tools: Entity<ToolWorkingSet>,
- tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
- tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
- pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
- tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
- tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
-}
-
-impl ToolUseState {
- pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
- Self {
- tools,
- tool_uses_by_assistant_message: HashMap::default(),
- tool_results: HashMap::default(),
- pending_tool_uses_by_id: HashMap::default(),
- tool_result_cards: HashMap::default(),
- tool_use_metadata_by_id: HashMap::default(),
- }
- }
-
- /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
- ///
- /// Accepts a function to filter the tools that should be used to populate the state.
- ///
- /// If `window` is `None` (e.g., when in headless mode or when running evals),
- /// tool cards won't be deserialized
- pub fn from_serialized_messages(
- tools: Entity<ToolWorkingSet>,
- messages: &[SerializedMessage],
- project: Entity<Project>,
- window: Option<&mut Window>, // None in headless mode
- cx: &mut App,
- ) -> Self {
- let mut this = Self::new(tools);
- let mut tool_names_by_id = HashMap::default();
- let mut window = window;
-
- for message in messages {
- match message.role {
- Role::Assistant => {
- if !message.tool_uses.is_empty() {
- let tool_uses = message
- .tool_uses
- .iter()
- .map(|tool_use| LanguageModelToolUse {
- id: tool_use.id.clone(),
- name: tool_use.name.clone().into(),
- raw_input: tool_use.input.to_string(),
- input: tool_use.input.clone(),
- is_input_complete: true,
- })
- .collect::<Vec<_>>();
-
- tool_names_by_id.extend(
- tool_uses
- .iter()
- .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
- );
-
- this.tool_uses_by_assistant_message
- .insert(message.id, tool_uses);
-
- for tool_result in &message.tool_results {
- let tool_use_id = tool_result.tool_use_id.clone();
- let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
- log::warn!("no tool name found for tool use: {tool_use_id:?}");
- continue;
- };
-
- this.tool_results.insert(
- tool_use_id.clone(),
- LanguageModelToolResult {
- tool_use_id: tool_use_id.clone(),
- tool_name: tool_use.clone(),
- is_error: tool_result.is_error,
- content: tool_result.content.clone(),
- output: tool_result.output.clone(),
- },
- );
-
- if let Some(window) = &mut window {
- if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
- if let Some(output) = tool_result.output.clone() {
- if let Some(card) = tool.deserialize_card(
- output,
- project.clone(),
- window,
- cx,
- ) {
- this.tool_result_cards.insert(tool_use_id, card);
- }
- }
- }
- }
- }
- }
- }
- Role::System | Role::User => {}
- }
- }
-
- this
- }
-
- pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
- let mut cancelled_tool_uses = Vec::new();
- self.pending_tool_uses_by_id
- .retain(|tool_use_id, tool_use| {
- if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
- return true;
- }
-
- let content = "Tool canceled by user".into();
- self.tool_results.insert(
- tool_use_id.clone(),
- LanguageModelToolResult {
- tool_use_id: tool_use_id.clone(),
- tool_name: tool_use.name.clone(),
- content,
- output: None,
- is_error: true,
- },
- );
- cancelled_tool_uses.push(tool_use.clone());
- false
- });
- cancelled_tool_uses
- }
-
- pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
- self.pending_tool_uses_by_id.values().collect()
- }
-
- pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
- let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
- return Vec::new();
- };
-
- let mut tool_uses = Vec::new();
-
- for tool_use in tool_uses_for_message.iter() {
- let tool_result = self.tool_results.get(&tool_use.id);
-
- let status = (|| {
- if let Some(tool_result) = tool_result {
- let content = tool_result
- .content
- .to_str()
- .map(|str| str.to_owned().into())
- .unwrap_or_default();
-
- return if tool_result.is_error {
- ToolUseStatus::Error(content)
- } else {
- ToolUseStatus::Finished(content)
- };
- }
-
- if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
- match pending_tool_use.status {
- PendingToolUseStatus::Idle => ToolUseStatus::Pending,
- PendingToolUseStatus::NeedsConfirmation { .. } => {
- ToolUseStatus::NeedsConfirmation
- }
- PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
- PendingToolUseStatus::Error(ref err) => {
- ToolUseStatus::Error(err.clone().into())
- }
- PendingToolUseStatus::InputStillStreaming => {
- ToolUseStatus::InputStillStreaming
- }
- }
- } else {
- ToolUseStatus::Pending
- }
- })();
-
- let (icon, needs_confirmation) =
- if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
- (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
- } else {
- (IconName::Cog, false)
- };
-
- tool_uses.push(ToolUse {
- id: tool_use.id.clone(),
- name: tool_use.name.clone().into(),
- ui_text: self.tool_ui_label(
- &tool_use.name,
- &tool_use.input,
- tool_use.is_input_complete,
- cx,
- ),
- input: tool_use.input.clone(),
- status,
- icon,
- needs_confirmation,
- })
- }
-
- tool_uses
- }
-
- pub fn tool_ui_label(
- &self,
- tool_name: &str,
- input: &serde_json::Value,
- is_input_complete: bool,
- cx: &App,
- ) -> SharedString {
- if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
- if is_input_complete {
- tool.ui_text(input).into()
- } else {
- tool.still_streaming_ui_text(input).into()
- }
- } else {
- format!("Unknown tool {tool_name:?}").into()
- }
- }
-
- pub fn tool_results_for_message(
- &self,
- assistant_message_id: MessageId,
- ) -> Vec<&LanguageModelToolResult> {
- let Some(tool_uses) = self
- .tool_uses_by_assistant_message
- .get(&assistant_message_id)
- else {
- return Vec::new();
- };
-
- tool_uses
- .iter()
- .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
- .collect()
- }
-
- pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
- self.tool_uses_by_assistant_message
- .get(&assistant_message_id)
- .map_or(false, |results| !results.is_empty())
- }
-
- pub fn tool_result(
- &self,
- tool_use_id: &LanguageModelToolUseId,
- ) -> Option<&LanguageModelToolResult> {
- self.tool_results.get(tool_use_id)
- }
-
- pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
- self.tool_result_cards.get(tool_use_id)
- }
-
- pub fn insert_tool_result_card(
- &mut self,
- tool_use_id: LanguageModelToolUseId,
- card: AnyToolCard,
- ) {
- self.tool_result_cards.insert(tool_use_id, card);
- }
-
- pub fn request_tool_use(
- &mut self,
- assistant_message_id: MessageId,
- tool_use: LanguageModelToolUse,
- metadata: ToolUseMetadata,
- cx: &App,
- ) -> Arc<str> {
- let tool_uses = self
- .tool_uses_by_assistant_message
- .entry(assistant_message_id)
- .or_default();
-
- let mut existing_tool_use_found = false;
-
- for existing_tool_use in tool_uses.iter_mut() {
- if existing_tool_use.id == tool_use.id {
- *existing_tool_use = tool_use.clone();
- existing_tool_use_found = true;
- }
- }
-
- if !existing_tool_use_found {
- tool_uses.push(tool_use.clone());
- }
-
- let status = if tool_use.is_input_complete {
- self.tool_use_metadata_by_id
- .insert(tool_use.id.clone(), metadata);
-
- PendingToolUseStatus::Idle
- } else {
- PendingToolUseStatus::InputStillStreaming
- };
-
- let ui_text: Arc<str> = self
- .tool_ui_label(
- &tool_use.name,
- &tool_use.input,
- tool_use.is_input_complete,
- cx,
- )
- .into();
-
- let may_perform_edits = self
- .tools
- .read(cx)
- .tool(&tool_use.name, cx)
- .is_some_and(|tool| tool.may_perform_edits());
-
- self.pending_tool_uses_by_id.insert(
- tool_use.id.clone(),
- PendingToolUse {
- assistant_message_id,
- id: tool_use.id,
- name: tool_use.name.clone(),
- ui_text: ui_text.clone(),
- input: tool_use.input,
- may_perform_edits,
- status,
- },
- );
-
- ui_text
- }
-
- pub fn run_pending_tool(
- &mut self,
- tool_use_id: LanguageModelToolUseId,
- ui_text: SharedString,
- task: Task<()>,
- ) {
- if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
- tool_use.ui_text = ui_text.into();
- tool_use.status = PendingToolUseStatus::Running {
- _task: task.shared(),
- };
- }
- }
-
- pub fn confirm_tool_use(
- &mut self,
- tool_use_id: LanguageModelToolUseId,
- ui_text: impl Into<Arc<str>>,
- input: serde_json::Value,
- request: Arc<LanguageModelRequest>,
- tool: Arc<dyn Tool>,
- ) {
- if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
- let ui_text = ui_text.into();
- tool_use.ui_text = ui_text.clone();
- let confirmation = Confirmation {
- tool_use_id,
- input,
- request,
- tool,
- ui_text,
- };
- tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
- }
- }
-
- pub fn insert_tool_output(
- &mut self,
- tool_use_id: LanguageModelToolUseId,
- tool_name: Arc<str>,
- output: Result<ToolResultOutput>,
- configured_model: Option<&ConfiguredModel>,
- ) -> Option<PendingToolUse> {
- let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
-
- telemetry::event!(
- "Agent Tool Finished",
- model = metadata
- .as_ref()
- .map(|metadata| metadata.model.telemetry_id()),
- model_provider = metadata
- .as_ref()
- .map(|metadata| metadata.model.provider_id().to_string()),
- thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
- prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
- tool_name,
- success = output.is_ok()
- );
-
- match output {
- Ok(output) => {
- let tool_result = output.content;
- const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
-
- let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
-
- // Protect from overly large output
- let tool_output_limit = configured_model
- .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
- .unwrap_or(usize::MAX);
-
- let content = match tool_result {
- ToolResultContent::Text(text) => {
- let text = if text.len() < tool_output_limit {
- text
- } else {
- let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
- format!(
- "Tool result too long. The first {} bytes:\n\n{}",
- truncated.len(),
- truncated
- )
- };
- LanguageModelToolResultContent::Text(text.into())
- }
- ToolResultContent::Image(language_model_image) => {
- if language_model_image.estimate_tokens() < tool_output_limit {
- LanguageModelToolResultContent::Image(language_model_image)
- } else {
- self.tool_results.insert(
- tool_use_id.clone(),
- LanguageModelToolResult {
- tool_use_id: tool_use_id.clone(),
- tool_name,
- content: "Tool responded with an image that would exceeded the remaining tokens".into(),
- is_error: true,
- output: None,
- },
- );
-
- return old_use;
- }
- }
- };
-
- self.tool_results.insert(
- tool_use_id.clone(),
- LanguageModelToolResult {
- tool_use_id: tool_use_id.clone(),
- tool_name,
- content,
- is_error: false,
- output: output.output,
- },
- );
-
- old_use
- }
- Err(err) => {
- self.tool_results.insert(
- tool_use_id.clone(),
- LanguageModelToolResult {
- tool_use_id: tool_use_id.clone(),
- tool_name,
- content: LanguageModelToolResultContent::Text(err.to_string().into()),
- is_error: true,
- output: None,
- },
- );
-
- if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
- tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
- }
-
- self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
- }
- }
- }
-
- pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
- self.tool_uses_by_assistant_message
- .contains_key(&assistant_message_id)
- }
-
- pub fn tool_results(
- &self,
- assistant_message_id: MessageId,
- ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
- self.tool_uses_by_assistant_message
- .get(&assistant_message_id)
- .into_iter()
- .flatten()
- .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct PendingToolUse {
- pub id: LanguageModelToolUseId,
- /// The ID of the Assistant message in which the tool use was requested.
- #[allow(unused)]
- pub assistant_message_id: MessageId,
- pub name: Arc<str>,
- pub ui_text: Arc<str>,
- pub input: serde_json::Value,
- pub status: PendingToolUseStatus,
- pub may_perform_edits: bool,
-}
-
-#[derive(Debug, Clone)]
-pub struct Confirmation {
- pub tool_use_id: LanguageModelToolUseId,
- pub input: serde_json::Value,
- pub ui_text: Arc<str>,
- pub request: Arc<LanguageModelRequest>,
- pub tool: Arc<dyn Tool>,
-}
-
-#[derive(Debug, Clone)]
-pub enum PendingToolUseStatus {
- InputStillStreaming,
- Idle,
- NeedsConfirmation(Arc<Confirmation>),
- Running { _task: Shared<Task<()>> },
- Error(#[allow(unused)] Arc<str>),
-}
-
-impl PendingToolUseStatus {
- pub fn is_idle(&self) -> bool {
- matches!(self, PendingToolUseStatus::Idle)
- }
-
- pub fn is_error(&self) -> bool {
- matches!(self, PendingToolUseStatus::Error(_))
- }
-
- pub fn needs_confirmation(&self) -> bool {
- matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
- }
-}
-
-#[derive(Clone)]
-pub struct ToolUseMetadata {
- pub model: Arc<dyn LanguageModel>,
- pub thread_id: ThreadId,
- pub prompt_id: PromptId,
-}
@@ -7,7 +7,7 @@ use crate::ui::{
use crate::{AgentPanel, ModelUsageContext};
use agent::{
ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore,
- Thread, ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadSummary,
+ Thread, ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadTitle,
context::{self, AgentContextHandle, RULES_ICON},
thread_store::RulesLoadingError,
tool_use::{PendingToolUseStatus, ToolUse},
@@ -816,23 +816,24 @@ impl ActiveThread {
_load_edited_message_context_task: None,
};
- for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
- let rendered_message = RenderedMessage::from_segments(
- &message.segments,
- this.language_registry.clone(),
- cx,
- );
- this.push_rendered_message(message.id, rendered_message);
-
- for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
- this.render_tool_use_markdown(
- tool_use.id.clone(),
- tool_use.ui_text.clone(),
- &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
- tool_use.status.text(),
- cx,
- );
- }
+ for message in thread.read(cx).messages() {
+ todo!()
+ // let rendered_message = RenderedMessage::from_segments(
+ // &message.segments,
+ // this.language_registry.clone(),
+ // cx,
+ // );
+ // this.push_rendered_message(message.id, rendered_message);
+
+ // for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
+ // this.render_tool_use_markdown(
+ // tool_use.id.clone(),
+ // tool_use.ui_text.clone(),
+ // &serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
+ // tool_use.status.text(),
+ // cx,
+ // );
+ // }
}
this
@@ -846,19 +847,18 @@ impl ActiveThread {
self.messages.is_empty()
}
- pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary {
- self.thread.read(cx).summary()
+ pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadTitle {
+ self.thread.read(cx).title()
}
pub fn regenerate_summary(&self, cx: &mut App) {
- self.thread.update(cx, |thread, cx| thread.summarize(cx))
+ self.thread
+ .update(cx, |thread, cx| thread.regenerate_summary(cx))
}
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
self.last_error.take();
- self.thread.update(cx, |thread, cx| {
- thread.cancel_last_completion(Some(window.window_handle()), cx)
- })
+ self.thread.update(cx, |thread, cx| thread.cancel(cx))
}
pub fn last_error(&self) -> Option<ThreadError> {
@@ -1185,7 +1185,7 @@ impl ActiveThread {
return;
}
- let title = self.thread.read(cx).summary().unwrap_or("Agent Panel");
+ let title = self.thread.read(cx).title().unwrap_or("Agent Panel");
match AgentSettings::get_global(cx).notify_when_agent_waiting {
NotifyWhenAgentWaiting::PrimaryScreen => {
@@ -3605,7 +3605,7 @@ pub(crate) fn open_active_thread_as_markdown(
workspace.update_in(cx, |workspace, window, cx| {
let thread = thread.read(cx);
let markdown = thread.to_markdown(cx)?;
- let thread_summary = thread.summary().or_default().to_string();
+ let thread_summary = thread.title().or_default().to_string();
let project = workspace.project().clone();
@@ -3776,357 +3776,357 @@ fn open_editor_at_position(
})
}
-#[cfg(test)]
-mod tests {
- use super::*;
- use agent::{MessageSegment, context::ContextLoadResult, thread_store};
- use assistant_tool::{ToolRegistry, ToolWorkingSet};
- use editor::EditorSettings;
- use fs::FakeFs;
- use gpui::{AppContext, TestAppContext, VisualTestContext};
- use language_model::{
- ConfiguredModel, LanguageModel, LanguageModelRegistry,
- fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
- };
- use project::Project;
- use prompt_store::PromptBuilder;
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
- use workspace::CollaboratorId;
-
- #[gpui::test]
- async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) {
- init_test_settings(cx);
-
- let project = create_test_project(
- cx,
- json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
- )
- .await;
-
- let (cx, _active_thread, workspace, thread, model) =
- setup_test_environment(cx, project.clone()).await;
-
- // Insert user message without any context (empty context vector)
- thread.update(cx, |thread, cx| {
- thread.insert_user_message(
- "What is the best way to learn Rust?",
- ContextLoadResult::default(),
- None,
- vec![],
- cx,
- );
- });
-
- // Stream response to user message
- thread.update(cx, |thread, cx| {
- let intent = CompletionIntent::UserPrompt;
- let request = thread.to_completion_request(model.clone(), intent, cx);
- thread.stream_completion(request, model, intent, cx.active_window(), cx)
- });
- // Follow the agent
- cx.update(|window, cx| {
- workspace.update(cx, |workspace, cx| {
- workspace.follow(CollaboratorId::Agent, window, cx);
- })
- });
- assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
-
- // Cancel the current completion
- thread.update(cx, |thread, cx| {
- thread.cancel_last_completion(cx.active_window(), cx)
- });
-
- cx.executor().run_until_parked();
-
- // No longer following the agent
- assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
- }
-
- #[gpui::test]
- async fn test_reinserting_creases_for_edited_message(cx: &mut TestAppContext) {
- init_test_settings(cx);
-
- let project = create_test_project(cx, json!({})).await;
-
- let (cx, active_thread, _, thread, model) =
- setup_test_environment(cx, project.clone()).await;
- cx.update(|_, cx| {
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.set_default_model(
- Some(ConfiguredModel {
- provider: Arc::new(FakeLanguageModelProvider),
- model,
- }),
- cx,
- );
- });
- });
-
- let creases = vec![MessageCrease {
- range: 14..22,
- icon_path: "icon".into(),
- label: "foo.txt".into(),
- context: None,
- }];
-
- let message = thread.update(cx, |thread, cx| {
- let message_id = thread.insert_user_message(
- "Tell me about @foo.txt",
- ContextLoadResult::default(),
- None,
- creases,
- cx,
- );
- thread.message(message_id).cloned().unwrap()
- });
-
- active_thread.update_in(cx, |active_thread, window, cx| {
- if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
- active_thread.start_editing_message(
- message.id,
- message_text,
- message.creases.as_slice(),
- window,
- cx,
- );
- }
- let editor = active_thread
- .editing_message
- .as_ref()
- .unwrap()
- .1
- .editor
- .clone();
- editor.update(cx, |editor, cx| editor.edit([(0..13, "modified")], cx));
- active_thread.confirm_editing_message(&Default::default(), window, cx);
- });
- cx.run_until_parked();
-
- let message = thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
- active_thread.update_in(cx, |active_thread, window, cx| {
- if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
- active_thread.start_editing_message(
- message.id,
- message_text,
- message.creases.as_slice(),
- window,
- cx,
- );
- }
- let editor = active_thread
- .editing_message
- .as_ref()
- .unwrap()
- .1
- .editor
- .clone();
- let text = editor.update(cx, |editor, cx| editor.text(cx));
- assert_eq!(text, "modified @foo.txt");
- });
- }
-
- #[gpui::test]
- async fn test_editing_message_cancels_previous_completion(cx: &mut TestAppContext) {
- init_test_settings(cx);
-
- let project = create_test_project(cx, json!({})).await;
-
- let (cx, active_thread, _, thread, model) =
- setup_test_environment(cx, project.clone()).await;
-
- cx.update(|_, cx| {
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.set_default_model(
- Some(ConfiguredModel {
- provider: Arc::new(FakeLanguageModelProvider),
- model: model.clone(),
- }),
- cx,
- );
- });
- });
-
- // Track thread events to verify cancellation
- let cancellation_events = Arc::new(std::sync::Mutex::new(Vec::new()));
- let new_request_events = Arc::new(std::sync::Mutex::new(Vec::new()));
-
- let _subscription = cx.update(|_, cx| {
- let cancellation_events = cancellation_events.clone();
- let new_request_events = new_request_events.clone();
- cx.subscribe(
- &thread,
- move |_thread, event: &ThreadEvent, _cx| match event {
- ThreadEvent::CompletionCanceled => {
- cancellation_events.lock().unwrap().push(());
- }
- ThreadEvent::NewRequest => {
- new_request_events.lock().unwrap().push(());
- }
- _ => {}
- },
- )
- });
-
- // Insert a user message and start streaming a response
- let message = thread.update(cx, |thread, cx| {
- let message_id = thread.insert_user_message(
- "Hello, how are you?",
- ContextLoadResult::default(),
- None,
- vec![],
- cx,
- );
- thread.advance_prompt_id();
- thread.send_to_model(
- model.clone(),
- CompletionIntent::UserPrompt,
- cx.active_window(),
- cx,
- );
- thread.message(message_id).cloned().unwrap()
- });
-
- cx.run_until_parked();
-
- // Verify that a completion is in progress
- assert!(cx.read(|cx| thread.read(cx).is_generating()));
- assert_eq!(new_request_events.lock().unwrap().len(), 1);
-
- // Edit the message while the completion is still running
- active_thread.update_in(cx, |active_thread, window, cx| {
- if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
- active_thread.start_editing_message(
- message.id,
- message_text,
- message.creases.as_slice(),
- window,
- cx,
- );
- }
- let editor = active_thread
- .editing_message
- .as_ref()
- .unwrap()
- .1
- .editor
- .clone();
- editor.update(cx, |editor, cx| {
- editor.set_text("What is the weather like?", window, cx);
- });
- active_thread.confirm_editing_message(&Default::default(), window, cx);
- });
-
- cx.run_until_parked();
-
- // Verify that the previous completion was cancelled
- assert_eq!(cancellation_events.lock().unwrap().len(), 1);
-
- // Verify that a new request was started after cancellation
- assert_eq!(new_request_events.lock().unwrap().len(), 2);
-
- // Verify that the edited message contains the new text
- let edited_message =
- thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
- match &edited_message.segments[0] {
- MessageSegment::Text(text) => {
- assert_eq!(text, "What is the weather like?");
- }
- _ => panic!("Expected text segment"),
- }
- }
-
- fn init_test_settings(cx: &mut TestAppContext) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- language::init(cx);
- Project::init_settings(cx);
- AgentSettings::register(cx);
- prompt_store::init(cx);
- thread_store::init(cx);
- workspace::init_settings(cx);
- language_model::init_settings(cx);
- ThemeSettings::register(cx);
- EditorSettings::register(cx);
- ToolRegistry::default_global(cx);
- });
- }
-
- // Helper to create a test project with test files
- async fn create_test_project(
- cx: &mut TestAppContext,
- files: serde_json::Value,
- ) -> Entity<Project> {
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(path!("/test"), files).await;
- Project::test(fs, [path!("/test").as_ref()], cx).await
- }
-
- async fn setup_test_environment(
- cx: &mut TestAppContext,
- project: Entity<Project>,
- ) -> (
- &mut VisualTestContext,
- Entity<ActiveThread>,
- Entity<Workspace>,
- Entity<Thread>,
- Arc<dyn LanguageModel>,
- ) {
- let (workspace, cx) =
- cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
-
- let thread_store = cx
- .update(|_, cx| {
- ThreadStore::load(
- project.clone(),
- cx.new(|_| ToolWorkingSet::default()),
- None,
- Arc::new(PromptBuilder::new(None).unwrap()),
- cx,
- )
- })
- .await
- .unwrap();
-
- let text_thread_store = cx
- .update(|_, cx| {
- TextThreadStore::new(
- project.clone(),
- Arc::new(PromptBuilder::new(None).unwrap()),
- Default::default(),
- cx,
- )
- })
- .await
- .unwrap();
-
- let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
- let context_store =
- cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
-
- let model = FakeLanguageModel::default();
- let model: Arc<dyn LanguageModel> = Arc::new(model);
-
- let language_registry = LanguageRegistry::new(cx.executor());
- let language_registry = Arc::new(language_registry);
-
- let active_thread = cx.update(|window, cx| {
- cx.new(|cx| {
- ActiveThread::new(
- thread.clone(),
- thread_store.clone(),
- text_thread_store,
- context_store.clone(),
- language_registry.clone(),
- workspace.downgrade(),
- window,
- cx,
- )
- })
- });
-
- (cx, active_thread, workspace, thread, model)
- }
-}
+// #[cfg(test)]
+// mod tests {
+// use super::*;
+// use agent::{MessageSegment, context::ContextLoadResult, thread_store};
+// use assistant_tool::{ToolRegistry, ToolWorkingSet};
+// use editor::EditorSettings;
+// use fs::FakeFs;
+// use gpui::{AppContext, TestAppContext, VisualTestContext};
+// use language_model::{
+// ConfiguredModel, LanguageModel, LanguageModelRegistry,
+// fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
+// };
+// use project::Project;
+// use prompt_store::PromptBuilder;
+// use serde_json::json;
+// use settings::SettingsStore;
+// use util::path;
+// use workspace::CollaboratorId;
+
+// #[gpui::test]
+// async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) {
+// init_test_settings(cx);
+
+// let project = create_test_project(
+// cx,
+// json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
+// )
+// .await;
+
+// let (cx, _active_thread, workspace, thread, model) =
+// setup_test_environment(cx, project.clone()).await;
+
+// // Insert user message without any context (empty context vector)
+// thread.update(cx, |thread, cx| {
+// thread.insert_user_message(
+// "What is the best way to learn Rust?",
+// ContextLoadResult::default(),
+// None,
+// vec![],
+// cx,
+// );
+// });
+
+// // Stream response to user message
+// thread.update(cx, |thread, cx| {
+// let intent = CompletionIntent::UserPrompt;
+// let request = thread.to_completion_request(model.clone(), intent, cx);
+// thread.stream_completion(request, model, intent, cx.active_window(), cx)
+// });
+// // Follow the agent
+// cx.update(|window, cx| {
+// workspace.update(cx, |workspace, cx| {
+// workspace.follow(CollaboratorId::Agent, window, cx);
+// })
+// });
+// assert!(cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
+
+// // Cancel the current completion
+// thread.update(cx, |thread, cx| {
+// thread.cancel_last_completion(cx.active_window(), cx)
+// });
+
+// cx.executor().run_until_parked();
+
+// // No longer following the agent
+// assert!(!cx.read(|cx| workspace.read(cx).is_being_followed(CollaboratorId::Agent)));
+// }
+
+// #[gpui::test]
+// async fn test_reinserting_creases_for_edited_message(cx: &mut TestAppContext) {
+// init_test_settings(cx);
+
+// let project = create_test_project(cx, json!({})).await;
+
+// let (cx, active_thread, _, thread, model) =
+// setup_test_environment(cx, project.clone()).await;
+// cx.update(|_, cx| {
+// LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+// registry.set_default_model(
+// Some(ConfiguredModel {
+// provider: Arc::new(FakeLanguageModelProvider),
+// model,
+// }),
+// cx,
+// );
+// });
+// });
+
+// let creases = vec![MessageCrease {
+// range: 14..22,
+// icon_path: "icon".into(),
+// label: "foo.txt".into(),
+// context: None,
+// }];
+
+// let message = thread.update(cx, |thread, cx| {
+// let message_id = thread.insert_user_message(
+// "Tell me about @foo.txt",
+// ContextLoadResult::default(),
+// None,
+// creases,
+// cx,
+// );
+// thread.message(message_id).cloned().unwrap()
+// });
+
+// active_thread.update_in(cx, |active_thread, window, cx| {
+// if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
+// active_thread.start_editing_message(
+// message.id,
+// message_text,
+// message.creases.as_slice(),
+// window,
+// cx,
+// );
+// }
+// let editor = active_thread
+// .editing_message
+// .as_ref()
+// .unwrap()
+// .1
+// .editor
+// .clone();
+// editor.update(cx, |editor, cx| editor.edit([(0..13, "modified")], cx));
+// active_thread.confirm_editing_message(&Default::default(), window, cx);
+// });
+// cx.run_until_parked();
+
+// let message = thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
+// active_thread.update_in(cx, |active_thread, window, cx| {
+// if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
+// active_thread.start_editing_message(
+// message.id,
+// message_text,
+// message.creases.as_slice(),
+// window,
+// cx,
+// );
+// }
+// let editor = active_thread
+// .editing_message
+// .as_ref()
+// .unwrap()
+// .1
+// .editor
+// .clone();
+// let text = editor.update(cx, |editor, cx| editor.text(cx));
+// assert_eq!(text, "modified @foo.txt");
+// });
+// }
+
+// #[gpui::test]
+// async fn test_editing_message_cancels_previous_completion(cx: &mut TestAppContext) {
+// init_test_settings(cx);
+
+// let project = create_test_project(cx, json!({})).await;
+
+// let (cx, active_thread, _, thread, model) =
+// setup_test_environment(cx, project.clone()).await;
+
+// cx.update(|_, cx| {
+// LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+// registry.set_default_model(
+// Some(ConfiguredModel {
+// provider: Arc::new(FakeLanguageModelProvider),
+// model: model.clone(),
+// }),
+// cx,
+// );
+// });
+// });
+
+// // Track thread events to verify cancellation
+// let cancellation_events = Arc::new(std::sync::Mutex::new(Vec::new()));
+// let new_request_events = Arc::new(std::sync::Mutex::new(Vec::new()));
+
+// let _subscription = cx.update(|_, cx| {
+// let cancellation_events = cancellation_events.clone();
+// let new_request_events = new_request_events.clone();
+// cx.subscribe(
+// &thread,
+// move |_thread, event: &ThreadEvent, _cx| match event {
+// ThreadEvent::CompletionCanceled => {
+// cancellation_events.lock().unwrap().push(());
+// }
+// ThreadEvent::NewRequest => {
+// new_request_events.lock().unwrap().push(());
+// }
+// _ => {}
+// },
+// )
+// });
+
+// // Insert a user message and start streaming a response
+// let message = thread.update(cx, |thread, cx| {
+// let message_id = thread.insert_user_message(
+// "Hello, how are you?",
+// ContextLoadResult::default(),
+// None,
+// vec![],
+// cx,
+// );
+// thread.advance_prompt_id();
+// thread.send_to_model(
+// model.clone(),
+// CompletionIntent::UserPrompt,
+// cx.active_window(),
+// cx,
+// );
+// thread.message(message_id).cloned().unwrap()
+// });
+
+// cx.run_until_parked();
+
+// // Verify that a completion is in progress
+// assert!(cx.read(|cx| thread.read(cx).is_generating()));
+// assert_eq!(new_request_events.lock().unwrap().len(), 1);
+
+// // Edit the message while the completion is still running
+// active_thread.update_in(cx, |active_thread, window, cx| {
+// if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
+// active_thread.start_editing_message(
+// message.id,
+// message_text,
+// message.creases.as_slice(),
+// window,
+// cx,
+// );
+// }
+// let editor = active_thread
+// .editing_message
+// .as_ref()
+// .unwrap()
+// .1
+// .editor
+// .clone();
+// editor.update(cx, |editor, cx| {
+// editor.set_text("What is the weather like?", window, cx);
+// });
+// active_thread.confirm_editing_message(&Default::default(), window, cx);
+// });
+
+// cx.run_until_parked();
+
+// // Verify that the previous completion was cancelled
+// assert_eq!(cancellation_events.lock().unwrap().len(), 1);
+
+// // Verify that a new request was started after cancellation
+// assert_eq!(new_request_events.lock().unwrap().len(), 2);
+
+// // Verify that the edited message contains the new text
+// let edited_message =
+// thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
+// match &edited_message.segments[0] {
+// MessageSegment::Text(text) => {
+// assert_eq!(text, "What is the weather like?");
+// }
+// _ => panic!("Expected text segment"),
+// }
+// }
+
+// fn init_test_settings(cx: &mut TestAppContext) {
+// cx.update(|cx| {
+// let settings_store = SettingsStore::test(cx);
+// cx.set_global(settings_store);
+// language::init(cx);
+// Project::init_settings(cx);
+// AgentSettings::register(cx);
+// prompt_store::init(cx);
+// thread_store::init(cx);
+// workspace::init_settings(cx);
+// language_model::init_settings(cx);
+// ThemeSettings::register(cx);
+// EditorSettings::register(cx);
+// ToolRegistry::default_global(cx);
+// });
+// }
+
+// // Helper to create a test project with test files
+// async fn create_test_project(
+// cx: &mut TestAppContext,
+// files: serde_json::Value,
+// ) -> Entity<Project> {
+// let fs = FakeFs::new(cx.executor());
+// fs.insert_tree(path!("/test"), files).await;
+// Project::test(fs, [path!("/test").as_ref()], cx).await
+// }
+
+// async fn setup_test_environment(
+// cx: &mut TestAppContext,
+// project: Entity<Project>,
+// ) -> (
+// &mut VisualTestContext,
+// Entity<ActiveThread>,
+// Entity<Workspace>,
+// Entity<Thread>,
+// Arc<dyn LanguageModel>,
+// ) {
+// let (workspace, cx) =
+// cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+// let thread_store = cx
+// .update(|_, cx| {
+// ThreadStore::load(
+// project.clone(),
+// cx.new(|_| ToolWorkingSet::default()),
+// None,
+// Arc::new(PromptBuilder::new(None).unwrap()),
+// cx,
+// )
+// })
+// .await
+// .unwrap();
+
+// let text_thread_store = cx
+// .update(|_, cx| {
+// TextThreadStore::new(
+// project.clone(),
+// Arc::new(PromptBuilder::new(None).unwrap()),
+// Default::default(),
+// cx,
+// )
+// })
+// .await
+// .unwrap();
+
+// let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
+// let context_store =
+// cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade())));
+
+// let model = FakeLanguageModel::default();
+// let model: Arc<dyn LanguageModel> = Arc::new(model);
+
+// let language_registry = LanguageRegistry::new(cx.executor());
+// let language_registry = Arc::new(language_registry);
+
+// let active_thread = cx.update(|window, cx| {
+// cx.new(|cx| {
+// ActiveThread::new(
+// thread.clone(),
+// thread_store.clone(),
+// text_thread_store,
+// context_store.clone(),
+// language_registry.clone(),
+// workspace.downgrade(),
+// window,
+// cx,
+// )
+// })
+// });
+
+// (cx, active_thread, workspace, thread, model)
+// }
+// }
@@ -211,7 +211,7 @@ impl AgentDiffPane {
}
fn update_title(&mut self, cx: &mut Context<Self>) {
- let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes");
+ let new_title = self.thread.read(cx).title().unwrap_or("Agent Changes");
if new_title != self.title {
self.title = new_title;
cx.emit(EditorEvent::TitleChanged);
@@ -461,7 +461,7 @@ impl Item for AgentDiffPane {
}
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
- let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes");
+ let summary = self.thread.read(cx).title().unwrap_or("Agent Changes");
Label::new(format!("Review: {}", summary))
.color(if params.selected {
Color::Default
@@ -1369,8 +1369,6 @@ impl AgentDiff {
| ThreadEvent::MessageDeleted(_)
| ThreadEvent::SummaryGenerated
| ThreadEvent::SummaryChanged
- | ThreadEvent::UsePendingTools { .. }
- | ThreadEvent::ToolFinished { .. }
| ThreadEvent::CheckpointChanged
| ThreadEvent::ToolConfirmationNeeded
| ThreadEvent::ToolUseLimitReached
@@ -1801,7 +1799,10 @@ mod tests {
})
.await
.unwrap();
- let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
+ let thread = thread_store
+ .update(cx, |store, cx| store.create_thread(cx))
+ .await
+ .unwrap();
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let (workspace, cx) =
@@ -1966,7 +1967,10 @@ mod tests {
})
.await
.unwrap();
- let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
+ let thread = thread_store
+ .update(cx, |store, cx| store.create_thread(cx))
+ .await
+ .unwrap();
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let (workspace, cx) =
@@ -45,7 +45,7 @@ impl AgentModelSelector {
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id())
{
- thread.set_configured_model(
+ thread.set_model(
Some(ConfiguredModel {
provider,
model: model.clone(),
@@ -26,7 +26,7 @@ use crate::{
ui::AgentOnboardingModal,
};
use agent::{
- Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
+ Thread, ThreadError, ThreadEvent, ThreadId, ThreadTitle, TokenUsageRatio,
context_store::ContextStore,
history_store::{HistoryEntryId, HistoryStore},
thread_store::{TextThreadStore, ThreadStore},
@@ -72,7 +72,7 @@ use zed_actions::{
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding},
assistant::{OpenRulesLibrary, ToggleFocus},
};
-use zed_llm_client::{CompletionIntent, UsageLimit};
+use zed_llm_client::UsageLimit;
const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -252,7 +252,7 @@ impl ActiveView {
thread.update(cx, |thread, cx| {
thread.thread().update(cx, |thread, cx| {
- thread.set_summary(new_summary, cx);
+ thread.set_title(new_summary, cx);
});
})
}
@@ -278,7 +278,7 @@ impl ActiveView {
let editor = editor.clone();
move |_, thread, event, window, cx| match event {
ThreadEvent::SummaryGenerated => {
- let summary = thread.read(cx).summary().or_default();
+ let summary = thread.read(cx).title().or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
@@ -492,10 +492,15 @@ impl AgentPanel {
None
};
+ let thread = thread_store
+ .update(cx, |this, cx| this.create_thread(cx))?
+ .await?;
+
let panel = workspace.update_in(cx, |workspace, window, cx| {
let panel = cx.new(|cx| {
Self::new(
workspace,
+ thread,
thread_store,
context_store,
prompt_store,
@@ -518,13 +523,13 @@ impl AgentPanel {
fn new(
workspace: &Workspace,
+ thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
context_store: Entity<TextThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
- let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone();
let project = workspace.project();
@@ -647,11 +652,12 @@ impl AgentPanel {
|this, _, event: &language_model::Event, cx| match event {
language_model::Event::DefaultModelChanged => match &this.active_view {
ActiveView::Thread { thread, .. } => {
- thread
- .read(cx)
- .thread()
- .clone()
- .update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
+ // todo!(do we need this?);
+ // thread
+ // .read(cx)
+ // .thread()
+ // .clone()
+ // .update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
}
ActiveView::TextThread { .. }
| ActiveView::History
@@ -784,46 +790,61 @@ impl AgentPanel {
.detach_and_log_err(cx);
}
- let active_thread = cx.new(|cx| {
- ActiveThread::new(
- thread.clone(),
- self.thread_store.clone(),
- self.context_store.clone(),
- context_store.clone(),
- self.language_registry.clone(),
- self.workspace.clone(),
- window,
- cx,
- )
- });
+ let fs = self.fs.clone();
+ let user_store = self.user_store.clone();
+ let thread_store = self.thread_store.clone();
+ let text_thread_store = self.context_store.clone();
+ let prompt_store = self.prompt_store.clone();
+ let language_registry = self.language_registry.clone();
+ let workspace = self.workspace.clone();
- let message_editor = cx.new(|cx| {
- MessageEditor::new(
- self.fs.clone(),
- self.workspace.clone(),
- self.user_store.clone(),
- context_store.clone(),
- self.prompt_store.clone(),
- self.thread_store.downgrade(),
- self.context_store.downgrade(),
- thread.clone(),
- window,
- cx,
- )
- });
+ cx.spawn_in(window, async move |this, cx| {
+ let thread = thread.await?;
+ let active_thread = cx.new_window_entity(|window, cx| {
+ ActiveThread::new(
+ thread.clone(),
+ thread_store.clone(),
+ text_thread_store.clone(),
+ context_store.clone(),
+ language_registry.clone(),
+ workspace.clone(),
+ window,
+ cx,
+ )
+ })?;
- if let Some(text) = preserved_text {
- message_editor.update(cx, |editor, cx| {
- editor.set_text(text, window, cx);
- });
- }
+ let message_editor = cx.new_window_entity(|window, cx| {
+ MessageEditor::new(
+ fs.clone(),
+ workspace.clone(),
+ user_store.clone(),
+ context_store.clone(),
+ prompt_store.clone(),
+ thread_store.downgrade(),
+ text_thread_store.downgrade(),
+ thread.clone(),
+ window,
+ cx,
+ )
+ })?;
- message_editor.focus_handle(cx).focus(window);
+ if let Some(text) = preserved_text {
+ message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text(text, window, cx);
+ });
+ }
- let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx);
- self.set_active_view(thread_view, window, cx);
+ this.update_in(cx, |this, window, cx| {
+ message_editor.focus_handle(cx).focus(window);
- AgentDiff::set_active_thread(&self.workspace, &thread, window, cx);
+ let thread_view =
+ ActiveView::thread(active_thread.clone(), message_editor, window, cx);
+ this.set_active_view(thread_view, window, cx);
+
+ AgentDiff::set_active_thread(&this.workspace, &thread, window, cx);
+ })
+ })
+ .detach_and_log_err(cx);
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -1254,23 +1275,11 @@ impl AgentPanel {
return;
}
- let model = thread_state.configured_model().map(|cm| cm.model.clone());
- if let Some(model) = model {
- thread.update(cx, |active_thread, cx| {
- active_thread.thread().update(cx, |thread, cx| {
- thread.insert_invisible_continue_message(cx);
- thread.advance_prompt_id();
- thread.send_to_model(
- model,
- CompletionIntent::UserPrompt,
- Some(window.window_handle()),
- cx,
- );
- });
- });
- } else {
- log::warn!("No configured model available for continuation");
- }
+ thread.update(cx, |active_thread, cx| {
+ active_thread
+ .thread()
+ .update(cx, |thread, cx| thread.resume(window, cx))
+ });
}
fn toggle_burn_mode(
@@ -1552,24 +1561,24 @@ impl AgentPanel {
let state = {
let active_thread = active_thread.read(cx);
if active_thread.is_empty() {
- &ThreadSummary::Pending
+ &ThreadTitle::Pending
} else {
active_thread.summary(cx)
}
};
match state {
- ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone())
+ ThreadTitle::Pending => Label::new(ThreadTitle::DEFAULT.clone())
.truncate()
.into_any_element(),
- ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER)
+ ThreadTitle::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER)
.truncate()
.into_any_element(),
- ThreadSummary::Ready(_) => div()
+ ThreadTitle::Ready(_) => div()
.w_full()
.child(change_title_editor.clone())
.into_any_element(),
- ThreadSummary::Error => h_flex()
+ ThreadTitle::Error => h_flex()
.w_full()
.child(change_title_editor.clone())
.child(
@@ -2024,7 +2033,7 @@ impl AgentPanel {
.read(cx)
.thread()
.read(cx)
- .configured_model()
+ .model()
.map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
});
@@ -2629,7 +2638,7 @@ impl AgentPanel {
return None;
}
- let model = thread.configured_model()?.model;
+ let model = thread.model()?.model;
let focus_handle = self.focus_handle(cx);
@@ -121,7 +121,7 @@ pub(crate) enum ModelUsageContext {
impl ModelUsageContext {
pub fn configured_model(&self, cx: &App) -> Option<ConfiguredModel> {
match self {
- Self::Thread(thread) => thread.read(cx).configured_model(),
+ Self::Thread(thread) => thread.read(cx).model(),
Self::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
@@ -670,7 +670,7 @@ fn recent_context_picker_entries(
let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx)
.filter(|(_, thread)| match thread {
ThreadContextEntry::Thread { id, .. } => {
- Some(id) != active_thread_id && !current_threads.contains(id)
+ Some(id) != active_thread_id.as_ref() && !current_threads.contains(id)
}
ThreadContextEntry::Context { .. } => true,
})
@@ -169,13 +169,13 @@ impl ContextStrip {
if self
.context_store
.read(cx)
- .includes_thread(active_thread.id())
+ .includes_thread(&active_thread.id())
{
return None;
}
Some(SuggestedContext::Thread {
- name: active_thread.summary().or_default(),
+ name: active_thread.title().or_default(),
thread: weak_active_thread,
})
} else if let Some(active_context_editor) = panel.active_context_editor() {
@@ -156,7 +156,7 @@ impl Render for ProfileSelector {
.map(|profile| profile.name.clone())
.unwrap_or_else(|| "Unknown".into());
- let configured_model = self.thread.read(cx).configured_model().or_else(|| {
+ let configured_model = self.thread.read(cx).model().or_else(|| {
let model_registry = LanguageModelRegistry::read_global(cx);
model_registry.default_model()
});