1mod mcp_server;
2pub mod tools;
3
4use collections::HashMap;
5use context_server::listener::McpServerTool;
6use language_models::provider::anthropic::AnthropicLanguageModelProvider;
7use project::Project;
8use settings::SettingsStore;
9use smol::process::Child;
10use std::any::Any;
11use std::cell::RefCell;
12use std::fmt::Display;
13use std::path::Path;
14use std::rc::Rc;
15use uuid::Uuid;
16
17use agent_client_protocol as acp;
18use anyhow::{Context as _, Result, anyhow};
19use futures::channel::oneshot;
20use futures::{AsyncBufReadExt, AsyncWriteExt};
21use futures::{
22 AsyncRead, AsyncWrite, FutureExt, StreamExt,
23 channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
24 io::BufReader,
25 select_biased,
26};
27use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
28use serde::{Deserialize, Serialize};
29use util::{ResultExt, debug_panic};
30
31use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
32use crate::claude::tools::ClaudeTool;
33use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
34use acp_thread::{AcpThread, AgentConnection, AuthRequired};
35
/// The Claude Code agent server.
///
/// Carries no state of its own: [`AgentServer::connect`] returns a fresh
/// connection, and a `claude` process is spawned per thread by that
/// connection's `new_thread`.
#[derive(Clone)]
pub struct ClaudeCode;
38
39impl AgentServer for ClaudeCode {
40 fn name(&self) -> &'static str {
41 "Claude Code"
42 }
43
44 fn empty_state_headline(&self) -> &'static str {
45 self.name()
46 }
47
48 fn empty_state_message(&self) -> &'static str {
49 "How can I help you today?"
50 }
51
52 fn logo(&self) -> ui::IconName {
53 ui::IconName::AiClaude
54 }
55
56 fn connect(
57 &self,
58 _root_dir: &Path,
59 _project: &Entity<Project>,
60 _cx: &mut App,
61 ) -> Task<Result<Rc<dyn AgentConnection>>> {
62 let connection = ClaudeAgentConnection {
63 sessions: Default::default(),
64 };
65
66 Task::ready(Ok(Rc::new(connection) as _))
67 }
68}
69
/// Connection to Claude Code, holding one [`ClaudeAgentSession`] per thread.
struct ClaudeAgentConnection {
    /// Live sessions keyed by ACP session id. Entries are inserted by
    /// `new_thread` and looked up by `prompt` and `cancel`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
73
74impl AgentConnection for ClaudeAgentConnection {
75 fn new_thread(
76 self: Rc<Self>,
77 project: Entity<Project>,
78 cwd: &Path,
79 cx: &mut App,
80 ) -> Task<Result<Entity<AcpThread>>> {
81 let cwd = cwd.to_owned();
82 cx.spawn(async move |cx| {
83 let settings = cx.read_global(|settings: &SettingsStore, _| {
84 settings.get::<AllAgentServersSettings>(None).claude.clone()
85 })?;
86
87 let Some(command) = AgentServerCommand::resolve(
88 "claude",
89 &[],
90 Some(&util::paths::home_dir().join(".claude/local/claude")),
91 settings,
92 &project,
93 cx,
94 )
95 .await
96 else {
97 anyhow::bail!("Failed to find claude binary");
98 };
99
100 let api_key =
101 cx.update(AnthropicLanguageModelProvider::api_key)?
102 .await
103 .map_err(|err| {
104 if err.is::<language_model::AuthenticateError>() {
105 anyhow!(AuthRequired::new().with_language_model_provider(
106 language_model::ANTHROPIC_PROVIDER_ID
107 ))
108 } else {
109 anyhow!(err)
110 }
111 })?;
112
113 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
114 let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
115 let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
116
117 let mut mcp_servers = HashMap::default();
118 mcp_servers.insert(
119 mcp_server::SERVER_NAME.to_string(),
120 permission_mcp_server.server_config()?,
121 );
122 let mcp_config = McpConfig { mcp_servers };
123
124 let mcp_config_file = tempfile::NamedTempFile::new()?;
125 let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
126
127 let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
128 mcp_config_file
129 .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
130 .await?;
131 mcp_config_file.flush().await?;
132
133 let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
134 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
135
136 let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
137
138 log::trace!("Starting session with id: {}", session_id);
139
140 let mut child = spawn_claude(
141 &command,
142 ClaudeSessionMode::Start,
143 session_id.clone(),
144 api_key,
145 &mcp_config_path,
146 &cwd,
147 )?;
148
149 let stdout = child.stdout.take().context("Failed to take stdout")?;
150 let stdin = child.stdin.take().context("Failed to take stdin")?;
151 let stderr = child.stderr.take().context("Failed to take stderr")?;
152
153 let pid = child.id();
154 log::trace!("Spawned (pid: {})", pid);
155
156 cx.background_spawn(async move {
157 let mut stderr = BufReader::new(stderr);
158 let mut line = String::new();
159 while let Ok(n) = stderr.read_line(&mut line).await
160 && n > 0
161 {
162 log::warn!("agent stderr: {}", &line);
163 line.clear();
164 }
165 })
166 .detach();
167
168 cx.background_spawn(async move {
169 let mut outgoing_rx = Some(outgoing_rx);
170
171 ClaudeAgentSession::handle_io(
172 outgoing_rx.take().unwrap(),
173 incoming_message_tx.clone(),
174 stdin,
175 stdout,
176 )
177 .await?;
178
179 log::trace!("Stopped (pid: {})", pid);
180
181 drop(mcp_config_path);
182 anyhow::Ok(())
183 })
184 .detach();
185
186 let turn_state = Rc::new(RefCell::new(TurnState::None));
187
188 let handler_task = cx.spawn({
189 let turn_state = turn_state.clone();
190 let mut thread_rx = thread_rx.clone();
191 async move |cx| {
192 while let Some(message) = incoming_message_rx.next().await {
193 ClaudeAgentSession::handle_message(
194 thread_rx.clone(),
195 message,
196 turn_state.clone(),
197 cx,
198 )
199 .await
200 }
201
202 if let Some(status) = child.status().await.log_err() {
203 if let Some(thread) = thread_rx.recv().await.ok() {
204 thread
205 .update(cx, |thread, cx| {
206 thread.emit_server_exited(status, cx);
207 })
208 .ok();
209 }
210 }
211 }
212 });
213
214 let thread = cx.new(|cx| {
215 AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
216 })?;
217
218 thread_tx.send(thread.downgrade())?;
219
220 let session = ClaudeAgentSession {
221 outgoing_tx,
222 turn_state,
223 _handler_task: handler_task,
224 _mcp_server: Some(permission_mcp_server),
225 };
226
227 self.sessions.borrow_mut().insert(session_id, session);
228
229 Ok(thread)
230 })
231 }
232
233 fn auth_methods(&self) -> &[acp::AuthMethod] {
234 &[]
235 }
236
237 fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
238 Task::ready(Err(anyhow!("Authentication not supported")))
239 }
240
241 fn prompt(
242 &self,
243 _id: Option<acp_thread::UserMessageId>,
244 params: acp::PromptRequest,
245 cx: &mut App,
246 ) -> Task<Result<acp::PromptResponse>> {
247 let sessions = self.sessions.borrow();
248 let Some(session) = sessions.get(¶ms.session_id) else {
249 return Task::ready(Err(anyhow!(
250 "Attempted to send message to nonexistent session {}",
251 params.session_id
252 )));
253 };
254
255 let (end_tx, end_rx) = oneshot::channel();
256 session.turn_state.replace(TurnState::InProgress { end_tx });
257
258 let mut content = String::new();
259 for chunk in params.prompt {
260 match chunk {
261 acp::ContentBlock::Text(text_content) => {
262 content.push_str(&text_content.text);
263 }
264 acp::ContentBlock::ResourceLink(resource_link) => {
265 content.push_str(&format!("@{}", resource_link.uri));
266 }
267 acp::ContentBlock::Audio(_)
268 | acp::ContentBlock::Image(_)
269 | acp::ContentBlock::Resource(_) => {
270 // TODO
271 }
272 }
273 }
274
275 if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
276 message: Message {
277 role: Role::User,
278 content: Content::UntaggedText(content),
279 id: None,
280 model: None,
281 stop_reason: None,
282 stop_sequence: None,
283 usage: None,
284 },
285 session_id: Some(params.session_id.to_string()),
286 }) {
287 return Task::ready(Err(anyhow!(err)));
288 }
289
290 cx.foreground_executor().spawn(async move { end_rx.await? })
291 }
292
293 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
294 let sessions = self.sessions.borrow();
295 let Some(session) = sessions.get(session_id) else {
296 log::warn!("Attempted to cancel nonexistent session {}", session_id);
297 return;
298 };
299
300 let request_id = new_request_id();
301
302 let turn_state = session.turn_state.take();
303 let TurnState::InProgress { end_tx } = turn_state else {
304 // Already canceled or idle, put it back
305 session.turn_state.replace(turn_state);
306 return;
307 };
308
309 session.turn_state.replace(TurnState::CancelRequested {
310 end_tx,
311 request_id: request_id.clone(),
312 });
313
314 session
315 .outgoing_tx
316 .unbounded_send(SdkMessage::ControlRequest {
317 request_id,
318 request: ControlRequest::Interrupt,
319 })
320 .log_err();
321 }
322
323 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
324 self
325 }
326}
327
/// How `spawn_claude` attaches the process to a Claude session.
#[derive(Clone, Copy)]
enum ClaudeSessionMode {
    /// Start a new session with the given id (`--session-id`).
    Start,
    /// Resume an existing session by id (`--resume`); not constructed yet.
    #[expect(dead_code)]
    Resume,
}
334
335fn spawn_claude(
336 command: &AgentServerCommand,
337 mode: ClaudeSessionMode,
338 session_id: acp::SessionId,
339 api_key: language_models::provider::anthropic::ApiKey,
340 mcp_config_path: &Path,
341 root_dir: &Path,
342) -> Result<Child> {
343 let child = util::command::new_smol_command(&command.path)
344 .args([
345 "--input-format",
346 "stream-json",
347 "--output-format",
348 "stream-json",
349 "--print",
350 "--verbose",
351 "--mcp-config",
352 mcp_config_path.to_string_lossy().as_ref(),
353 "--permission-prompt-tool",
354 &format!(
355 "mcp__{}__{}",
356 mcp_server::SERVER_NAME,
357 mcp_server::PermissionTool::NAME,
358 ),
359 "--allowedTools",
360 &format!(
361 "mcp__{}__{},mcp__{}__{}",
362 mcp_server::SERVER_NAME,
363 mcp_server::EditTool::NAME,
364 mcp_server::SERVER_NAME,
365 mcp_server::ReadTool::NAME
366 ),
367 "--disallowedTools",
368 "Read,Edit",
369 ])
370 .args(match mode {
371 ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
372 ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
373 })
374 .args(command.args.iter().map(|arg| arg.as_str()))
375 .envs(command.env.iter().flatten())
376 .env("ANTHROPIC_API_KEY", api_key.key)
377 .current_dir(root_dir)
378 .stdin(std::process::Stdio::piped())
379 .stdout(std::process::Stdio::piped())
380 .stderr(std::process::Stdio::piped())
381 .kill_on_drop(true)
382 .spawn()?;
383
384 Ok(child)
385}
386
/// Per-thread state for a running `claude` process.
struct ClaudeAgentSession {
    /// Messages sent here are serialized to the process' stdin by the I/O task.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current prompt turn, shared with the message handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// Zed's MCP server for this session; never read, held so it lives as long
    /// as the session (presumably dropping it stops the server — confirm).
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Task applying incoming messages to the thread. Its future owns the
    /// `claude` child (spawned with `kill_on_drop`), so dropping the task
    /// also drops the child.
    _handler_task: Task<()>,
}
393
/// Lifecycle of a single prompt turn.
///
/// Every state except `None` carries the sender that resolves the pending
/// `prompt` future when a `Result` message arrives.
#[derive(Debug, Default)]
enum TurnState {
    /// No prompt is in flight.
    #[default]
    None,
    /// A prompt was sent and no cancellation has been requested.
    InProgress {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
    /// An `Interrupt` control request with `request_id` was sent but not yet
    /// acknowledged.
    CancelRequested {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
        request_id: String,
    },
    /// The agent acknowledged the interrupt; the turn completes on the next
    /// `Result` message.
    CancelConfirmed {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
}
409
410impl TurnState {
411 fn is_canceled(&self) -> bool {
412 matches!(self, TurnState::CancelConfirmed { .. })
413 }
414
415 fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
416 match self {
417 TurnState::None => None,
418 TurnState::InProgress { end_tx, .. } => Some(end_tx),
419 TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
420 TurnState::CancelConfirmed { end_tx } => Some(end_tx),
421 }
422 }
423
424 fn confirm_cancellation(self, id: &str) -> Self {
425 match self {
426 TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
427 TurnState::CancelConfirmed { end_tx }
428 }
429 _ => self,
430 }
431 }
432}
433
434impl ClaudeAgentSession {
435 async fn handle_message(
436 mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
437 message: SdkMessage,
438 turn_state: Rc<RefCell<TurnState>>,
439 cx: &mut AsyncApp,
440 ) {
441 match message {
442 // we should only be sending these out, they don't need to be in the thread
443 SdkMessage::ControlRequest { .. } => {}
444 SdkMessage::User {
445 message,
446 session_id: _,
447 } => {
448 let Some(thread) = thread_rx
449 .recv()
450 .await
451 .log_err()
452 .and_then(|entity| entity.upgrade())
453 else {
454 log::error!("Received an SDK message but thread is gone");
455 return;
456 };
457
458 for chunk in message.content.chunks() {
459 match chunk {
460 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
461 if !turn_state.borrow().is_canceled() {
462 thread
463 .update(cx, |thread, cx| {
464 thread.push_user_content_block(None, text.into(), cx)
465 })
466 .log_err();
467 }
468 }
469 ContentChunk::ToolResult {
470 content,
471 tool_use_id,
472 } => {
473 let content = content.to_string();
474 thread
475 .update(cx, |thread, cx| {
476 thread.update_tool_call(
477 acp::ToolCallUpdate {
478 id: acp::ToolCallId(tool_use_id.into()),
479 fields: acp::ToolCallUpdateFields {
480 status: if turn_state.borrow().is_canceled() {
481 // Do not set to completed if turn was canceled
482 None
483 } else {
484 Some(acp::ToolCallStatus::Completed)
485 },
486 content: (!content.is_empty())
487 .then(|| vec![content.into()]),
488 ..Default::default()
489 },
490 },
491 cx,
492 )
493 })
494 .log_err();
495 }
496 ContentChunk::Thinking { .. }
497 | ContentChunk::RedactedThinking
498 | ContentChunk::ToolUse { .. } => {
499 debug_panic!(
500 "Should not get {:?} with role: assistant. should we handle this?",
501 chunk
502 );
503 }
504
505 ContentChunk::Image
506 | ContentChunk::Document
507 | ContentChunk::WebSearchToolResult => {
508 thread
509 .update(cx, |thread, cx| {
510 thread.push_assistant_content_block(
511 format!("Unsupported content: {:?}", chunk).into(),
512 false,
513 cx,
514 )
515 })
516 .log_err();
517 }
518 }
519 }
520 }
521 SdkMessage::Assistant {
522 message,
523 session_id: _,
524 } => {
525 let Some(thread) = thread_rx
526 .recv()
527 .await
528 .log_err()
529 .and_then(|entity| entity.upgrade())
530 else {
531 log::error!("Received an SDK message but thread is gone");
532 return;
533 };
534
535 for chunk in message.content.chunks() {
536 match chunk {
537 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
538 thread
539 .update(cx, |thread, cx| {
540 thread.push_assistant_content_block(text.into(), false, cx)
541 })
542 .log_err();
543 }
544 ContentChunk::Thinking { thinking } => {
545 thread
546 .update(cx, |thread, cx| {
547 thread.push_assistant_content_block(thinking.into(), true, cx)
548 })
549 .log_err();
550 }
551 ContentChunk::RedactedThinking => {
552 thread
553 .update(cx, |thread, cx| {
554 thread.push_assistant_content_block(
555 "[REDACTED]".into(),
556 true,
557 cx,
558 )
559 })
560 .log_err();
561 }
562 ContentChunk::ToolUse { id, name, input } => {
563 let claude_tool = ClaudeTool::infer(&name, input);
564
565 thread
566 .update(cx, |thread, cx| {
567 if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
568 thread.update_plan(
569 acp::Plan {
570 entries: params
571 .todos
572 .into_iter()
573 .map(Into::into)
574 .collect(),
575 },
576 cx,
577 )
578 } else {
579 thread.upsert_tool_call(
580 claude_tool.as_acp(acp::ToolCallId(id.into())),
581 cx,
582 )?;
583 }
584 anyhow::Ok(())
585 })
586 .log_err();
587 }
588 ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
589 debug_panic!(
590 "Should not get tool results with role: assistant. should we handle this?"
591 );
592 }
593 ContentChunk::Image | ContentChunk::Document => {
594 thread
595 .update(cx, |thread, cx| {
596 thread.push_assistant_content_block(
597 format!("Unsupported content: {:?}", chunk).into(),
598 false,
599 cx,
600 )
601 })
602 .log_err();
603 }
604 }
605 }
606 }
607 SdkMessage::Result {
608 is_error,
609 subtype,
610 result,
611 ..
612 } => {
613 let turn_state = turn_state.take();
614 let was_canceled = turn_state.is_canceled();
615 let Some(end_turn_tx) = turn_state.end_tx() else {
616 debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
617 return;
618 };
619
620 if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
621 end_turn_tx
622 .send(Err(anyhow!(
623 "Error: {}",
624 result.unwrap_or_else(|| subtype.to_string())
625 )))
626 .ok();
627 } else {
628 let stop_reason = match subtype {
629 ResultErrorType::Success => acp::StopReason::EndTurn,
630 ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
631 ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
632 };
633 end_turn_tx
634 .send(Ok(acp::PromptResponse { stop_reason }))
635 .ok();
636 }
637 }
638 SdkMessage::ControlResponse { response } => {
639 if matches!(response.subtype, ResultErrorType::Success) {
640 let new_state = turn_state.take().confirm_cancellation(&response.request_id);
641 turn_state.replace(new_state);
642 } else {
643 log::error!("Control response error: {:?}", response);
644 }
645 }
646 SdkMessage::System { .. } => {}
647 }
648 }
649
650 async fn handle_io(
651 mut outgoing_rx: UnboundedReceiver<SdkMessage>,
652 incoming_tx: UnboundedSender<SdkMessage>,
653 mut outgoing_bytes: impl Unpin + AsyncWrite,
654 incoming_bytes: impl Unpin + AsyncRead,
655 ) -> Result<UnboundedReceiver<SdkMessage>> {
656 let mut output_reader = BufReader::new(incoming_bytes);
657 let mut outgoing_line = Vec::new();
658 let mut incoming_line = String::new();
659 loop {
660 select_biased! {
661 message = outgoing_rx.next() => {
662 if let Some(message) = message {
663 outgoing_line.clear();
664 serde_json::to_writer(&mut outgoing_line, &message)?;
665 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
666 outgoing_line.push(b'\n');
667 outgoing_bytes.write_all(&outgoing_line).await.ok();
668 } else {
669 break;
670 }
671 }
672 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
673 if bytes_read? == 0 {
674 break
675 }
676 log::trace!("recv: {}", &incoming_line);
677 match serde_json::from_str::<SdkMessage>(&incoming_line) {
678 Ok(message) => {
679 incoming_tx.unbounded_send(message).log_err();
680 }
681 Err(error) => {
682 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
683 }
684 }
685 incoming_line.clear();
686 }
687 }
688 }
689
690 Ok(outgoing_rx)
691 }
692}
693
/// A chat message as exchanged with the Claude Code CLI (the comments on
/// [`SdkMessage`] say it comes from the Anthropic SDK). Optional fields are
/// omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
709
/// Message content: either a bare JSON string or a list of typed chunks
/// (`untagged`, so the JSON shape decides the variant).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
716
717impl Content {
718 pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
719 match self {
720 Self::Chunks(chunks) => chunks.into_iter(),
721 Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
722 }
723 }
724}
725
726impl Display for Content {
727 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
728 match self {
729 Content::UntaggedText(txt) => write!(f, "{}", txt),
730 Content::Chunks(chunks) => {
731 for chunk in chunks {
732 write!(f, "{}", chunk)?;
733 }
734 Ok(())
735 }
736 }
737 }
738}
739
/// One piece of message content, tagged by its `type` field (snake_case).
/// A bare JSON string falls through to `UntaggedText`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    Text {
        text: String,
    },
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    Thinking {
        thinking: String,
    },
    RedactedThinking,
    // TODO: the payloads of these chunk types are not parsed yet.
    Image,
    Document,
    WebSearchToolResult,
    #[serde(untagged)]
    UntaggedText(String),
}
766
767impl Display for ContentChunk {
768 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
769 match self {
770 ContentChunk::Text { text } => write!(f, "{}", text),
771 ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
772 ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
773 ContentChunk::UntaggedText(text) => write!(f, "{}", text),
774 ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
775 ContentChunk::Image
776 | ContentChunk::Document
777 | ContentChunk::ToolUse { .. }
778 | ContentChunk::WebSearchToolResult => {
779 write!(f, "\n{:?}\n", &self)
780 }
781 }
782 }
783}
784
/// Token usage attached to a [`Message`]. Deserialized but not read in this
/// file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}
793
/// Author of a [`Message`], serialized in snake_case (`"user"`, ...).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
801
/// A role plus plain-text content.
// NOTE(review): not referenced anywhere in this file; confirm it is still
// needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
807
/// One line of the `stream-json` protocol spoken over the `claude` process'
/// stdin/stdout, tagged by its `type` field (snake_case).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    // An assistant message
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    // A user message (sent by us for prompts; also received from the agent,
    // e.g. carrying tool results)
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    // Emitted as the last message in a conversation
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    // Emitted as the first message at the start of a conversation
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
855
/// Payload of [`SdkMessage::ControlRequest`], tagged by `subtype`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}
862
/// Acknowledgement of a control request; `request_id` echoes the id we sent.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
868
/// Outcome `subtype` of a `Result` message or control response, in
/// snake_case on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
876
877impl Display for ResultErrorType {
878 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
879 match self {
880 ResultErrorType::Success => write!(f, "success"),
881 ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
882 ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
883 }
884 }
885}
886
887fn new_request_id() -> String {
888 use rand::Rng;
889 // In the Claude Code TS SDK they just generate a random 12 character string,
890 // `Math.random().toString(36).substring(2, 15)`
891 rand::thread_rng()
892 .sample_iter(&rand::distributions::Alphanumeric)
893 .take(12)
894 .map(char::from)
895 .collect()
896}
897
/// An MCP server entry reported in the [`SdkMessage::System`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
903
/// Permission mode reported in the [`SdkMessage::System`] message
/// (camelCase on the wire).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
912
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");

    /// Command used by the e2e tests: `claude` resolved from `PATH`, with no
    /// extra arguments or environment.
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: Vec::new(),
            env: None,
        }
    }

    /// Has Claude create a todo plan, then update the first entry's status
    /// twice, checking each update lands without changing the plan's length.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        let entries_len = thread.read_with(cx, |thread, _| {
            let len = thread.plan().entries.len();
            assert!(len > 0, "Empty plan");
            len
        });

        let steps: [(&str, fn(&acp::PlanEntryStatus) -> bool); 2] = [
            (
                "Mark the first entry status as in progress without acting on it.",
                |status| matches!(status, acp::PlanEntryStatus::InProgress),
            ),
            (
                "Now mark the first entry as completed without acting on it.",
                |status| matches!(status, acp::PlanEntryStatus::Completed),
            ),
        ];

        for (prompt, has_expected_status) in steps {
            thread
                .update(cx, |thread, cx| thread.send_raw(prompt, cx))
                .await
                .unwrap();

            thread.read_with(cx, |thread, _| {
                assert!(has_expected_status(&thread.plan().entries[0].status));
                assert_eq!(thread.plan().entries.len(), entries_len);
            });
        }
    }

    #[test]
    fn test_deserialize_content_untagged_text() {
        let content: Content = serde_json::from_value(json!("Hello, world!")).unwrap();
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText variant");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_deserialize_content_chunks() {
        let raw = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(raw).unwrap();
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks variant");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Hello");

        let ContentChunk::ToolUse { id, name, input } = &chunks[1] else {
            panic!("Expected ToolUse chunk");
        };
        assert_eq!(id, "tool_123");
        assert_eq!(name, "calculator");
        assert_eq!(input["operation"], "add");
        assert_eq!(input["a"], 1);
        assert_eq!(input["b"], 2);
    }

    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let raw = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(raw).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText content");
        };
        assert_eq!(text, "Result content");
        assert_eq!(tool_use_id, "tool_456");
    }

    #[test]
    fn test_deserialize_tool_result_chunks() {
        let raw = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(raw).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks content");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Processing complete");

        let ContentChunk::Text { text } = &chunks[1] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Result: 42");

        assert_eq!(tool_use_id, "tool_789");
    }
}