update_plan_tool.rs

  1use crate::{AgentTool, ToolCallEventStream, ToolInput};
  2use agent_client_protocol as acp;
  3use gpui::{App, SharedString, Task};
  4use schemars::JsonSchema;
  5use serde::{Deserialize, Serialize};
  6use std::sync::Arc;
  7
  8#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
  9#[serde(rename_all = "snake_case")]
 10#[schemars(inline)]
 11pub enum PlanEntryStatus {
 12    /// The task has not started yet.
 13    Pending,
 14    /// The task is currently being worked on.
 15    InProgress,
 16    /// The task has been successfully completed.
 17    Completed,
 18}
 19
 20impl From<PlanEntryStatus> for acp::PlanEntryStatus {
 21    fn from(value: PlanEntryStatus) -> Self {
 22        match value {
 23            PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
 24            PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
 25            PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
 26        }
 27    }
 28}
 29
 30#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
 31#[serde(rename_all = "snake_case")]
 32#[schemars(inline)]
 33pub enum PlanEntryPriority {
 34    High,
 35    #[default]
 36    Medium,
 37    Low,
 38}
 39
 40impl From<PlanEntryPriority> for acp::PlanEntryPriority {
 41    fn from(value: PlanEntryPriority) -> Self {
 42        match value {
 43            PlanEntryPriority::High => acp::PlanEntryPriority::High,
 44            PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
 45            PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
 46        }
 47    }
 48}
 49
 50#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
 51pub struct PlanItem {
 52    /// Human-readable description of what this task aims to accomplish.
 53    pub step: String,
 54    /// The current status of this task.
 55    pub status: PlanEntryStatus,
 56    /// The relative importance of this task. Defaults to medium when omitted.
 57    #[serde(default)]
 58    pub priority: PlanEntryPriority,
 59}
 60
 61impl From<PlanItem> for acp::PlanEntry {
 62    fn from(value: PlanItem) -> Self {
 63        acp::PlanEntry::new(value.step, value.priority.into(), value.status.into())
 64    }
 65}
 66
 67/// Updates the task plan.
 68/// Provide a list of plan entries, each with step, status, and optional priority.
 69#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
 70pub struct UpdatePlanToolInput {
 71    /// The list of plan entries and their current statuses.
 72    pub plan: Vec<PlanItem>,
 73}
 74
 75pub struct UpdatePlanTool;
 76
 77impl UpdatePlanTool {
 78    fn to_plan(input: UpdatePlanToolInput) -> acp::Plan {
 79        acp::Plan::new(input.plan.into_iter().map(Into::into).collect())
 80    }
 81}
 82
 83impl AgentTool for UpdatePlanTool {
 84    type Input = UpdatePlanToolInput;
 85    type Output = String;
 86
 87    const NAME: &'static str = "update_plan";
 88
 89    fn kind() -> acp::ToolKind {
 90        acp::ToolKind::Think
 91    }
 92
 93    fn initial_title(
 94        &self,
 95        input: Result<Self::Input, serde_json::Value>,
 96        _cx: &mut App,
 97    ) -> SharedString {
 98        match input {
 99            Ok(input) if input.plan.is_empty() => "Clear plan".into(),
100            Ok(_) | Err(_) => "Update plan".into(),
101        }
102    }
103
104    fn run(
105        self: Arc<Self>,
106        input: ToolInput<Self::Input>,
107        event_stream: ToolCallEventStream,
108        cx: &mut App,
109    ) -> Task<Result<Self::Output, Self::Output>> {
110        cx.spawn(async move |_cx| {
111            let input = input
112                .recv()
113                .await
114                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
115
116            event_stream.update_plan(Self::to_plan(input));
117
118            Ok("Plan updated".to_string())
119        })
120    }
121
122    fn replay(
123        &self,
124        input: Self::Input,
125        _output: Self::Output,
126        event_stream: ToolCallEventStream,
127        _cx: &mut App,
128    ) -> anyhow::Result<()> {
129        event_stream.update_plan(Self::to_plan(input));
130        Ok(())
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolCallEventStream;
    use gpui::TestAppContext;
    use pretty_assertions::assert_eq;

    /// A three-entry plan that covers every status and every priority.
    fn sample_input() -> UpdatePlanToolInput {
        UpdatePlanToolInput {
            plan: vec![
                PlanItem {
                    step: "Inspect the existing tool wiring".to_string(),
                    status: PlanEntryStatus::Completed,
                    priority: PlanEntryPriority::High,
                },
                PlanItem {
                    step: "Implement the update_plan tool".to_string(),
                    status: PlanEntryStatus::InProgress,
                    priority: PlanEntryPriority::Medium,
                },
                PlanItem {
                    step: "Add tests".to_string(),
                    status: PlanEntryStatus::Pending,
                    priority: PlanEntryPriority::Low,
                },
            ],
        }
    }

    /// The ACP plan that `sample_input()` is expected to be converted into.
    fn sample_plan() -> acp::Plan {
        acp::Plan::new(vec![
            acp::PlanEntry::new(
                "Inspect the existing tool wiring",
                acp::PlanEntryPriority::High,
                acp::PlanEntryStatus::Completed,
            ),
            acp::PlanEntry::new(
                "Implement the update_plan tool",
                acp::PlanEntryPriority::Medium,
                acp::PlanEntryStatus::InProgress,
            ),
            acp::PlanEntry::new(
                "Add tests",
                acp::PlanEntryPriority::Low,
                acp::PlanEntryStatus::Pending,
            ),
        ])
    }

    /// Deserializes tool input from JSON, the way it arrives from the model, so the
    /// serde attributes (`rename_all`, `default`) are actually exercised.
    fn parse_input(json: serde_json::Value) -> UpdatePlanToolInput {
        serde_json::from_value(json).expect("input should deserialize")
    }

    #[gpui::test]
    async fn test_run_emits_plan_event(cx: &mut TestAppContext) {
        let tool = Arc::new(UpdatePlanTool);
        let (event_stream, mut event_rx) = ToolCallEventStream::test();

        let result = cx
            .update(|cx| tool.run(ToolInput::resolved(sample_input()), event_stream, cx))
            .await
            .expect("tool should succeed");
        assert_eq!(result, "Plan updated".to_string());

        let plan = event_rx.expect_plan().await;
        assert_eq!(plan, sample_plan());
    }

    #[gpui::test]
    async fn test_replay_emits_plan_event(cx: &mut TestAppContext) {
        let tool = UpdatePlanTool;
        let (event_stream, mut event_rx) = ToolCallEventStream::test();

        cx.update(|cx| {
            tool.replay(sample_input(), "Plan updated".to_string(), event_stream, cx)
                .expect("replay should succeed");
        });

        let plan = event_rx.expect_plan().await;
        assert_eq!(plan, sample_plan());
    }

    #[test]
    fn test_input_deserializes_snake_case_names() {
        let input = parse_input(serde_json::json!({
            "plan": [
                {
                    "step": "Inspect the existing tool wiring",
                    "status": "completed",
                    "priority": "high"
                },
                {
                    "step": "Implement the update_plan tool",
                    "status": "in_progress",
                    "priority": "medium"
                },
                {
                    "step": "Add tests",
                    "status": "pending",
                    "priority": "low"
                }
            ]
        }));
        assert_eq!(input, sample_input());
    }

    #[gpui::test]
    async fn test_run_defaults_priority_to_medium(cx: &mut TestAppContext) {
        let tool = Arc::new(UpdatePlanTool);
        let (event_stream, mut event_rx) = ToolCallEventStream::test();

        // `priority` is left out of the JSON entirely, so this covers the
        // `#[serde(default)]` on `PlanItem::priority` rather than just calling
        // `PlanEntryPriority::default()` by hand.
        let input = parse_input(serde_json::json!({
            "plan": [
                { "step": "First", "status": "in_progress" },
                { "step": "Second", "status": "in_progress" }
            ]
        }));

        let result = cx
            .update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx))
            .await
            .expect("tool should succeed");
        assert_eq!(result, "Plan updated".to_string());

        let plan = event_rx.expect_plan().await;
        assert_eq!(
            plan,
            acp::Plan::new(vec![
                acp::PlanEntry::new(
                    "First",
                    acp::PlanEntryPriority::Medium,
                    acp::PlanEntryStatus::InProgress,
                ),
                acp::PlanEntry::new(
                    "Second",
                    acp::PlanEntryPriority::Medium,
                    acp::PlanEntryStatus::InProgress,
                ),
            ])
        );
    }

    #[gpui::test]
    async fn test_initial_title(cx: &mut TestAppContext) {
        let tool = UpdatePlanTool;

        let title = cx.update(|cx| tool.initial_title(Ok(sample_input()), cx));
        assert_eq!(title, SharedString::from("Update plan"));

        let title =
            cx.update(|cx| tool.initial_title(Ok(UpdatePlanToolInput { plan: Vec::new() }), cx));
        assert_eq!(title, SharedString::from("Clear plan"));

        // Input that failed to parse still gets the generic title.
        let title = cx.update(|cx| tool.initial_title(Err(serde_json::Value::Null), cx));
        assert_eq!(title, SharedString::from("Update plan"));
    }
}