1mod mcp_server;
2pub mod tools;
3
4use collections::HashMap;
5use context_server::listener::McpServerTool;
6use language_model::LanguageModelRegistry;
7use language_models::provider::anthropic::AnthropicLanguageModelProvider;
8use project::Project;
9use settings::SettingsStore;
10use smol::process::Child;
11use std::any::Any;
12use std::cell::RefCell;
13use std::fmt::Display;
14use std::path::Path;
15use std::rc::Rc;
16use std::sync::Arc;
17use uuid::Uuid;
18
19use agent_client_protocol as acp;
20use anyhow::{Context as _, Result, anyhow};
21use futures::channel::oneshot;
22use futures::{AsyncBufReadExt, AsyncWriteExt};
23use futures::{
24 AsyncRead, AsyncWrite, FutureExt, StreamExt,
25 channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
26 io::BufReader,
27 select_biased,
28};
29use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
30use serde::{Deserialize, Serialize};
31use util::{ResultExt, debug_panic};
32
33use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
34use crate::claude::tools::ClaudeTool;
35use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
36use acp_thread::{AcpThread, AgentConnection, AuthRequired};
37
/// Agent server that drives the Claude Code CLI (`claude`).
///
/// Stateless: all per-thread state lives in the connection created by
/// [`AgentServer::connect`].
#[derive(Clone)]
pub struct ClaudeCode;
40
41impl AgentServer for ClaudeCode {
42 fn name(&self) -> &'static str {
43 "Claude Code"
44 }
45
46 fn empty_state_headline(&self) -> &'static str {
47 self.name()
48 }
49
50 fn empty_state_message(&self) -> &'static str {
51 "How can I help you today?"
52 }
53
54 fn logo(&self) -> ui::IconName {
55 ui::IconName::AiClaude
56 }
57
58 fn connect(
59 &self,
60 _root_dir: &Path,
61 _project: &Entity<Project>,
62 _cx: &mut App,
63 ) -> Task<Result<Rc<dyn AgentConnection>>> {
64 let connection = ClaudeAgentConnection {
65 sessions: Default::default(),
66 };
67
68 Task::ready(Ok(Rc::new(connection) as _))
69 }
70}
71
72struct ClaudeAgentConnection {
73 sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
74}
75
76impl AgentConnection for ClaudeAgentConnection {
77 fn new_thread(
78 self: Rc<Self>,
79 project: Entity<Project>,
80 cwd: &Path,
81 cx: &mut App,
82 ) -> Task<Result<Entity<AcpThread>>> {
83 let cwd = cwd.to_owned();
84 cx.spawn(async move |cx| {
85 let settings = cx.read_global(|settings: &SettingsStore, _| {
86 settings.get::<AllAgentServersSettings>(None).claude.clone()
87 })?;
88
89 let Some(command) = AgentServerCommand::resolve(
90 "claude",
91 &[],
92 Some(&util::paths::home_dir().join(".claude/local/claude")),
93 settings,
94 &project,
95 cx,
96 )
97 .await
98 else {
99 anyhow::bail!("Failed to find claude binary");
100 };
101
102 let anthropic: Arc<AnthropicLanguageModelProvider> = cx.update(|cx| {
103 let registry = LanguageModelRegistry::global(cx);
104 let provider: Arc<dyn Any + Send + Sync> = registry
105 .read(cx)
106 .provider(&language_model::ANTHROPIC_PROVIDER_ID)
107 .context("Failed to get Anthropic provider")?;
108
109 Arc::downcast::<AnthropicLanguageModelProvider>(provider)
110 .map_err(|_| anyhow!("Failed to downcast provider"))
111 })??;
112
113 let api_key = cx
114 .update(|cx| AnthropicLanguageModelProvider::api_key(cx))?
115 .await
116 .map_err(|err| {
117 if err.is::<language_model::AuthenticateError>() {
118 let (update_tx, update_rx) = oneshot::channel();
119 let mut update_tx = Some(update_tx);
120
121 let sub = cx
122 .update(|cx| {
123 anthropic.observe(
124 move |_cx| {
125 if let Some(update_tx) = update_tx.take() {
126 update_tx.send(()).ok();
127 }
128 },
129 cx,
130 )
131 })
132 .ok();
133
134 let update_task = cx.foreground_executor().spawn(async move {
135 update_rx.await.ok();
136 drop(sub)
137 });
138
139 anyhow!(
140 AuthRequired::new()
141 .with_description(
142 "To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into())
143 .with_update(update_task)
144 )
145 } else {
146 anyhow!(err)
147 }
148 })?;
149
150 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
151 let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), cx).await?;
152
153 let mut mcp_servers = HashMap::default();
154 mcp_servers.insert(
155 mcp_server::SERVER_NAME.to_string(),
156 permission_mcp_server.server_config()?,
157 );
158 let mcp_config = McpConfig { mcp_servers };
159
160 let mcp_config_file = tempfile::NamedTempFile::new()?;
161 let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
162
163 let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
164 mcp_config_file
165 .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
166 .await?;
167 mcp_config_file.flush().await?;
168
169 let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
170 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
171
172 let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
173
174 log::trace!("Starting session with id: {}", session_id);
175
176 let mut child = spawn_claude(
177 &command,
178 ClaudeSessionMode::Start,
179 session_id.clone(),
180 api_key,
181 &mcp_config_path,
182 &cwd,
183 )?;
184
185 let stdout = child.stdout.take().context("Failed to take stdout")?;
186 let stdin = child.stdin.take().context("Failed to take stdin")?;
187 let stderr = child.stderr.take().context("Failed to take stderr")?;
188
189 let pid = child.id();
190 log::trace!("Spawned (pid: {})", pid);
191
192 cx.background_spawn(async move {
193 let mut stderr = BufReader::new(stderr);
194 let mut line = String::new();
195 while let Ok(n) = stderr.read_line(&mut line).await
196 && n > 0
197 {
198 log::warn!("agent stderr: {}", &line);
199 line.clear();
200 }
201 })
202 .detach();
203
204 cx.background_spawn(async move {
205 let mut outgoing_rx = Some(outgoing_rx);
206
207 ClaudeAgentSession::handle_io(
208 outgoing_rx.take().unwrap(),
209 incoming_message_tx.clone(),
210 stdin,
211 stdout,
212 )
213 .await?;
214
215 log::trace!("Stopped (pid: {})", pid);
216
217 drop(mcp_config_path);
218 anyhow::Ok(())
219 })
220 .detach();
221
222 let turn_state = Rc::new(RefCell::new(TurnState::None));
223
224 let handler_task = cx.spawn({
225 let turn_state = turn_state.clone();
226 let mut thread_rx = thread_rx.clone();
227 async move |cx| {
228 while let Some(message) = incoming_message_rx.next().await {
229 ClaudeAgentSession::handle_message(
230 thread_rx.clone(),
231 message,
232 turn_state.clone(),
233 cx,
234 )
235 .await
236 }
237
238 if let Some(status) = child.status().await.log_err() {
239 if let Some(thread) = thread_rx.recv().await.ok() {
240 thread
241 .update(cx, |thread, cx| {
242 thread.emit_server_exited(status, cx);
243 })
244 .ok();
245 }
246 }
247 }
248 });
249
250 let thread = cx.new(|cx| {
251 AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
252 })?;
253
254 thread_tx.send(thread.downgrade())?;
255
256 let session = ClaudeAgentSession {
257 outgoing_tx,
258 turn_state,
259 _handler_task: handler_task,
260 _mcp_server: Some(permission_mcp_server),
261 };
262
263 self.sessions.borrow_mut().insert(session_id, session);
264
265 Ok(thread)
266 })
267 }
268
269 fn auth_methods(&self) -> &[acp::AuthMethod] {
270 &[]
271 }
272
273 fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
274 Task::ready(Err(anyhow!("Authentication not supported")))
275 }
276
277 fn prompt(
278 &self,
279 _id: Option<acp_thread::UserMessageId>,
280 params: acp::PromptRequest,
281 cx: &mut App,
282 ) -> Task<Result<acp::PromptResponse>> {
283 let sessions = self.sessions.borrow();
284 let Some(session) = sessions.get(¶ms.session_id) else {
285 return Task::ready(Err(anyhow!(
286 "Attempted to send message to nonexistent session {}",
287 params.session_id
288 )));
289 };
290
291 let (end_tx, end_rx) = oneshot::channel();
292 session.turn_state.replace(TurnState::InProgress { end_tx });
293
294 let mut content = String::new();
295 for chunk in params.prompt {
296 match chunk {
297 acp::ContentBlock::Text(text_content) => {
298 content.push_str(&text_content.text);
299 }
300 acp::ContentBlock::ResourceLink(resource_link) => {
301 content.push_str(&format!("@{}", resource_link.uri));
302 }
303 acp::ContentBlock::Audio(_)
304 | acp::ContentBlock::Image(_)
305 | acp::ContentBlock::Resource(_) => {
306 // TODO
307 }
308 }
309 }
310
311 if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
312 message: Message {
313 role: Role::User,
314 content: Content::UntaggedText(content),
315 id: None,
316 model: None,
317 stop_reason: None,
318 stop_sequence: None,
319 usage: None,
320 },
321 session_id: Some(params.session_id.to_string()),
322 }) {
323 return Task::ready(Err(anyhow!(err)));
324 }
325
326 cx.foreground_executor().spawn(async move { end_rx.await? })
327 }
328
329 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
330 let sessions = self.sessions.borrow();
331 let Some(session) = sessions.get(&session_id) else {
332 log::warn!("Attempted to cancel nonexistent session {}", session_id);
333 return;
334 };
335
336 let request_id = new_request_id();
337
338 let turn_state = session.turn_state.take();
339 let TurnState::InProgress { end_tx } = turn_state else {
340 // Already canceled or idle, put it back
341 session.turn_state.replace(turn_state);
342 return;
343 };
344
345 session.turn_state.replace(TurnState::CancelRequested {
346 end_tx,
347 request_id: request_id.clone(),
348 });
349
350 session
351 .outgoing_tx
352 .unbounded_send(SdkMessage::ControlRequest {
353 request_id,
354 request: ControlRequest::Interrupt,
355 })
356 .log_err();
357 }
358
359 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
360 self
361 }
362}
363
/// How `spawn_claude` attaches the CLI to a session id.
#[derive(Copy, Clone)]
enum ClaudeSessionMode {
    /// Begin a new session (`--session-id <id>`).
    Start,
    /// Continue an existing session (`--resume <id>`); not used yet.
    #[expect(dead_code)]
    Resume,
}
370
371fn spawn_claude(
372 command: &AgentServerCommand,
373 mode: ClaudeSessionMode,
374 session_id: acp::SessionId,
375 api_key: language_models::provider::anthropic::ApiKey,
376 mcp_config_path: &Path,
377 root_dir: &Path,
378) -> Result<Child> {
379 let child = util::command::new_smol_command(&command.path)
380 .args([
381 "--input-format",
382 "stream-json",
383 "--output-format",
384 "stream-json",
385 "--print",
386 "--verbose",
387 "--mcp-config",
388 mcp_config_path.to_string_lossy().as_ref(),
389 "--permission-prompt-tool",
390 &format!(
391 "mcp__{}__{}",
392 mcp_server::SERVER_NAME,
393 mcp_server::PermissionTool::NAME,
394 ),
395 "--allowedTools",
396 &format!(
397 "mcp__{}__{},mcp__{}__{}",
398 mcp_server::SERVER_NAME,
399 mcp_server::EditTool::NAME,
400 mcp_server::SERVER_NAME,
401 mcp_server::ReadTool::NAME
402 ),
403 "--disallowedTools",
404 "Read,Edit",
405 ])
406 .args(match mode {
407 ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
408 ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
409 })
410 .args(command.args.iter().map(|arg| arg.as_str()))
411 .envs(command.env.iter().flatten())
412 .env("ANTHROPIC_API_KEY", api_key.key)
413 .current_dir(root_dir)
414 .stdin(std::process::Stdio::piped())
415 .stdout(std::process::Stdio::piped())
416 .stderr(std::process::Stdio::piped())
417 .kill_on_drop(true)
418 .spawn()?;
419
420 Ok(child)
421}
422
/// Per-thread state kept by `ClaudeAgentConnection` for one `claude` process.
///
/// NOTE: Rust drops fields in declaration order, so `_mcp_server` is dropped
/// before `_handler_task`. Keep that in mind before reordering.
struct ClaudeAgentSession {
    /// Queue of messages for the I/O task, which writes them to `claude`'s stdin.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current prompt turn, shared with the message handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// Held for the session's lifetime; never read directly.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Task applying incoming messages to the thread. It is held, not read;
    /// presumably dropping a gpui `Task` cancels it. TODO confirm.
    _handler_task: Task<()>,
}
429
430#[derive(Debug, Default)]
431enum TurnState {
432 #[default]
433 None,
434 InProgress {
435 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
436 },
437 CancelRequested {
438 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
439 request_id: String,
440 },
441 CancelConfirmed {
442 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
443 },
444}
445
446impl TurnState {
447 fn is_canceled(&self) -> bool {
448 matches!(self, TurnState::CancelConfirmed { .. })
449 }
450
451 fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
452 match self {
453 TurnState::None => None,
454 TurnState::InProgress { end_tx, .. } => Some(end_tx),
455 TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
456 TurnState::CancelConfirmed { end_tx } => Some(end_tx),
457 }
458 }
459
460 fn confirm_cancellation(self, id: &str) -> Self {
461 match self {
462 TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
463 TurnState::CancelConfirmed { end_tx }
464 }
465 _ => self,
466 }
467 }
468}
469
470impl ClaudeAgentSession {
471 async fn handle_message(
472 mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
473 message: SdkMessage,
474 turn_state: Rc<RefCell<TurnState>>,
475 cx: &mut AsyncApp,
476 ) {
477 match message {
478 // we should only be sending these out, they don't need to be in the thread
479 SdkMessage::ControlRequest { .. } => {}
480 SdkMessage::User {
481 message,
482 session_id: _,
483 } => {
484 let Some(thread) = thread_rx
485 .recv()
486 .await
487 .log_err()
488 .and_then(|entity| entity.upgrade())
489 else {
490 log::error!("Received an SDK message but thread is gone");
491 return;
492 };
493
494 for chunk in message.content.chunks() {
495 match chunk {
496 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
497 if !turn_state.borrow().is_canceled() {
498 thread
499 .update(cx, |thread, cx| {
500 thread.push_user_content_block(None, text.into(), cx)
501 })
502 .log_err();
503 }
504 }
505 ContentChunk::ToolResult {
506 content,
507 tool_use_id,
508 } => {
509 let content = content.to_string();
510 thread
511 .update(cx, |thread, cx| {
512 thread.update_tool_call(
513 acp::ToolCallUpdate {
514 id: acp::ToolCallId(tool_use_id.into()),
515 fields: acp::ToolCallUpdateFields {
516 status: if turn_state.borrow().is_canceled() {
517 // Do not set to completed if turn was canceled
518 None
519 } else {
520 Some(acp::ToolCallStatus::Completed)
521 },
522 content: (!content.is_empty())
523 .then(|| vec![content.into()]),
524 ..Default::default()
525 },
526 },
527 cx,
528 )
529 })
530 .log_err();
531 }
532 ContentChunk::Thinking { .. }
533 | ContentChunk::RedactedThinking
534 | ContentChunk::ToolUse { .. } => {
535 debug_panic!(
536 "Should not get {:?} with role: assistant. should we handle this?",
537 chunk
538 );
539 }
540
541 ContentChunk::Image
542 | ContentChunk::Document
543 | ContentChunk::WebSearchToolResult => {
544 thread
545 .update(cx, |thread, cx| {
546 thread.push_assistant_content_block(
547 format!("Unsupported content: {:?}", chunk).into(),
548 false,
549 cx,
550 )
551 })
552 .log_err();
553 }
554 }
555 }
556 }
557 SdkMessage::Assistant {
558 message,
559 session_id: _,
560 } => {
561 let Some(thread) = thread_rx
562 .recv()
563 .await
564 .log_err()
565 .and_then(|entity| entity.upgrade())
566 else {
567 log::error!("Received an SDK message but thread is gone");
568 return;
569 };
570
571 for chunk in message.content.chunks() {
572 match chunk {
573 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
574 thread
575 .update(cx, |thread, cx| {
576 thread.push_assistant_content_block(text.into(), false, cx)
577 })
578 .log_err();
579 }
580 ContentChunk::Thinking { thinking } => {
581 thread
582 .update(cx, |thread, cx| {
583 thread.push_assistant_content_block(thinking.into(), true, cx)
584 })
585 .log_err();
586 }
587 ContentChunk::RedactedThinking => {
588 thread
589 .update(cx, |thread, cx| {
590 thread.push_assistant_content_block(
591 "[REDACTED]".into(),
592 true,
593 cx,
594 )
595 })
596 .log_err();
597 }
598 ContentChunk::ToolUse { id, name, input } => {
599 let claude_tool = ClaudeTool::infer(&name, input);
600
601 thread
602 .update(cx, |thread, cx| {
603 if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
604 thread.update_plan(
605 acp::Plan {
606 entries: params
607 .todos
608 .into_iter()
609 .map(Into::into)
610 .collect(),
611 },
612 cx,
613 )
614 } else {
615 thread.upsert_tool_call(
616 claude_tool.as_acp(acp::ToolCallId(id.into())),
617 cx,
618 )?;
619 }
620 anyhow::Ok(())
621 })
622 .log_err();
623 }
624 ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
625 debug_panic!(
626 "Should not get tool results with role: assistant. should we handle this?"
627 );
628 }
629 ContentChunk::Image | ContentChunk::Document => {
630 thread
631 .update(cx, |thread, cx| {
632 thread.push_assistant_content_block(
633 format!("Unsupported content: {:?}", chunk).into(),
634 false,
635 cx,
636 )
637 })
638 .log_err();
639 }
640 }
641 }
642 }
643 SdkMessage::Result {
644 is_error,
645 subtype,
646 result,
647 ..
648 } => {
649 let turn_state = turn_state.take();
650 let was_canceled = turn_state.is_canceled();
651 let Some(end_turn_tx) = turn_state.end_tx() else {
652 debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
653 return;
654 };
655
656 if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
657 end_turn_tx
658 .send(Err(anyhow!(
659 "Error: {}",
660 result.unwrap_or_else(|| subtype.to_string())
661 )))
662 .ok();
663 } else {
664 let stop_reason = match subtype {
665 ResultErrorType::Success => acp::StopReason::EndTurn,
666 ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
667 ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
668 };
669 end_turn_tx
670 .send(Ok(acp::PromptResponse { stop_reason }))
671 .ok();
672 }
673 }
674 SdkMessage::ControlResponse { response } => {
675 if matches!(response.subtype, ResultErrorType::Success) {
676 let new_state = turn_state.take().confirm_cancellation(&response.request_id);
677 turn_state.replace(new_state);
678 } else {
679 log::error!("Control response error: {:?}", response);
680 }
681 }
682 SdkMessage::System { .. } => {}
683 }
684 }
685
686 async fn handle_io(
687 mut outgoing_rx: UnboundedReceiver<SdkMessage>,
688 incoming_tx: UnboundedSender<SdkMessage>,
689 mut outgoing_bytes: impl Unpin + AsyncWrite,
690 incoming_bytes: impl Unpin + AsyncRead,
691 ) -> Result<UnboundedReceiver<SdkMessage>> {
692 let mut output_reader = BufReader::new(incoming_bytes);
693 let mut outgoing_line = Vec::new();
694 let mut incoming_line = String::new();
695 loop {
696 select_biased! {
697 message = outgoing_rx.next() => {
698 if let Some(message) = message {
699 outgoing_line.clear();
700 serde_json::to_writer(&mut outgoing_line, &message)?;
701 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
702 outgoing_line.push(b'\n');
703 outgoing_bytes.write_all(&outgoing_line).await.ok();
704 } else {
705 break;
706 }
707 }
708 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
709 if bytes_read? == 0 {
710 break
711 }
712 log::trace!("recv: {}", &incoming_line);
713 match serde_json::from_str::<SdkMessage>(&incoming_line) {
714 Ok(message) => {
715 incoming_tx.unbounded_send(message).log_err();
716 }
717 Err(error) => {
718 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
719 }
720 }
721 incoming_line.clear();
722 }
723 }
724 }
725
726 Ok(outgoing_rx)
727 }
728}
729
/// Chat message payload carried by `SdkMessage::User` and `SdkMessage::Assistant`.
///
/// Optional fields are omitted when serializing (`skip_serializing_if`).
/// Outgoing user messages leave all of them `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    /// Either a bare string or a list of typed chunks.
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
745
/// Message content: a bare JSON string or an array of [`ContentChunk`]s.
///
/// `untagged`: serde tries the variants in declaration order, so a JSON
/// string deserializes as `UntaggedText` and an array as `Chunks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
752
753impl Content {
754 pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
755 match self {
756 Self::Chunks(chunks) => chunks.into_iter(),
757 Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
758 }
759 }
760}
761
762impl Display for Content {
763 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
764 match self {
765 Content::UntaggedText(txt) => write!(f, "{}", txt),
766 Content::Chunks(chunks) => {
767 for chunk in chunks {
768 write!(f, "{}", chunk)?;
769 }
770 Ok(())
771 }
772 }
773 }
774}
775
/// One typed block of message content, internally tagged by its `"type"` field
/// (snake_case, e.g. `"tool_use"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    Text {
        text: String,
    },
    /// A tool invocation requested by the assistant.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of the tool call identified by `tool_use_id`.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    Thinking {
        thinking: String,
    },
    RedactedThinking,
    // TODO: the payloads of these variants are not modeled yet.
    Image,
    Document,
    WebSearchToolResult,
    /// Fallback for a bare string with no `"type"` tag. Kept last so the
    /// tagged variants are tried first.
    #[serde(untagged)]
    UntaggedText(String),
}
802
803impl Display for ContentChunk {
804 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
805 match self {
806 ContentChunk::Text { text } => write!(f, "{}", text),
807 ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
808 ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
809 ContentChunk::UntaggedText(text) => write!(f, "{}", text),
810 ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
811 ContentChunk::Image
812 | ContentChunk::Document
813 | ContentChunk::ToolUse { .. }
814 | ContentChunk::WebSearchToolResult => {
815 write!(f, "\n{:?}\n", &self)
816 }
817 }
818 }
819}
820
821#[derive(Debug, Clone, Serialize, Deserialize)]
822struct Usage {
823 input_tokens: u32,
824 cache_creation_input_tokens: u32,
825 cache_read_input_tokens: u32,
826 output_tokens: u32,
827 service_tier: String,
828}
829
/// Author of a [`Message`]; serialized in snake_case (`"system"`,
/// `"assistant"`, `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
837
/// A role plus plain-text content pair.
///
/// NOTE(review): not referenced anywhere else in this file. Confirm whether
/// it is still needed or can be removed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
843
/// One line of the `claude` CLI's `stream-json` protocol, internally tagged by
/// `"type"` (snake_case).
///
/// `User` and `ControlRequest` are written to the child. Every variant may be
/// parsed from its output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    // An assistant message
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    // A user message (prompts we send, and tool results echoed back)
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    // Emitted as the last message in a conversation; ends the current turn
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    // Emitted as the first message at the start of a conversation
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request; `response.request_id` echoes the request's id
    ControlResponse { response: ControlResponse },
}
891
/// Body of an `SdkMessage::ControlRequest`, tagged by `"subtype"` (snake_case).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation (serialized as `"interrupt"`)
    Interrupt,
}
898
/// Reply to a control request.
///
/// `request_id` matches the id of the `ControlRequest` being answered, and
/// `subtype` reports success or the error kind.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
904
/// Outcome subtype shared by `Result` messages and control responses
/// (serialized as `"success"`, `"error_max_turns"`, `"error_during_execution"`).
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
912
913impl Display for ResultErrorType {
914 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
915 match self {
916 ResultErrorType::Success => write!(f, "success"),
917 ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
918 ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
919 }
920 }
921}
922
923fn new_request_id() -> String {
924 use rand::Rng;
925 // In the Claude Code TS SDK they just generate a random 12 character string,
926 // `Math.random().toString(36).substring(2, 15)`
927 rand::thread_rng()
928 .sample_iter(&rand::distributions::Alphanumeric)
929 .take(12)
930 .map(char::from)
931 .collect()
932}
933
/// Entry in the `System` message's `mcp_servers` list: a server name and its
/// reported status string.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
939
/// Permission mode reported in the `System` message.
///
/// Serialized in camelCase: `"default"`, `"acceptEdits"`,
/// `"bypassPermissions"`, `"plan"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
948
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");

    /// Command used by the shared e2e tests to launch a locally installed `claude`.
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: vec![],
            env: None,
        }
    }

    /// Drives a real `claude` through creating a TodoWrite plan and updating
    /// its first entry, checking that the plan length stays stable.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        let entries_len = thread.read_with(cx, |thread, _| {
            let entries = &thread.plan().entries;
            assert!(!entries.is_empty(), "Empty plan");
            entries.len()
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Mark the first entry status as in progress without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::InProgress
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Now mark the first entry as completed without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::Completed
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });
    }

    #[test]
    fn test_deserialize_content_untagged_text() {
        let json = json!("Hello, world!");
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected UntaggedText variant"),
        }
    }

    #[test]
    fn test_deserialize_content_chunks() {
        let json = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::Chunks(chunks) => {
                assert_eq!(chunks.len(), 2);
                match &chunks[0] {
                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
                    _ => panic!("Expected Text chunk"),
                }
                match &chunks[1] {
                    ContentChunk::ToolUse { id, name, input } => {
                        assert_eq!(id, "tool_123");
                        assert_eq!(name, "calculator");
                        assert_eq!(input["operation"], "add");
                        assert_eq!(input["a"], 1);
                        assert_eq!(input["b"], 2);
                    }
                    _ => panic!("Expected ToolUse chunk"),
                }
            }
            _ => panic!("Expected Chunks variant"),
        }
    }

    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let json = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
                    _ => panic!("Expected UntaggedText content"),
                }
                assert_eq!(tool_use_id, "tool_456");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    #[test]
    fn test_deserialize_tool_result_chunks() {
        let json = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::Chunks(chunks) => {
                        assert_eq!(chunks.len(), 2);
                        match &chunks[0] {
                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
                            _ => panic!("Expected Text chunk"),
                        }
                        match &chunks[1] {
                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
                            _ => panic!("Expected Text chunk"),
                        }
                    }
                    _ => panic!("Expected Chunks content"),
                }
                assert_eq!(tool_use_id, "tool_789");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }
}