1mod mcp_server;
2pub mod tools;
3
4use collections::HashMap;
5use context_server::listener::McpServerTool;
6use language_models::provider::anthropic::AnthropicLanguageModelProvider;
7use project::Project;
8use settings::SettingsStore;
9use smol::process::Child;
10use std::any::Any;
11use std::cell::RefCell;
12use std::fmt::Display;
13use std::path::Path;
14use std::rc::Rc;
15use uuid::Uuid;
16
17use agent_client_protocol as acp;
18use anyhow::{Context as _, Result, anyhow};
19use futures::channel::oneshot;
20use futures::{AsyncBufReadExt, AsyncWriteExt};
21use futures::{
22 AsyncRead, AsyncWrite, FutureExt, StreamExt,
23 channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
24 io::BufReader,
25 select_biased,
26};
27use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
28use serde::{Deserialize, Serialize};
29use util::{ResultExt, debug_panic};
30
31use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
32use crate::claude::tools::ClaudeTool;
33use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
34use acp_thread::{AcpThread, AgentConnection, AuthRequired};
35
36#[derive(Clone)]
37pub struct ClaudeCode;
38
39impl AgentServer for ClaudeCode {
40 fn name(&self) -> &'static str {
41 "Claude Code"
42 }
43
44 fn empty_state_headline(&self) -> &'static str {
45 self.name()
46 }
47
48 fn empty_state_message(&self) -> &'static str {
49 "How can I help you today?"
50 }
51
52 fn logo(&self) -> ui::IconName {
53 ui::IconName::AiClaude
54 }
55
56 fn connect(
57 &self,
58 _root_dir: &Path,
59 _project: &Entity<Project>,
60 _cx: &mut App,
61 ) -> Task<Result<Rc<dyn AgentConnection>>> {
62 let connection = ClaudeAgentConnection {
63 sessions: Default::default(),
64 };
65
66 Task::ready(Ok(Rc::new(connection) as _))
67 }
68
69 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
70 self
71 }
72}
73
74struct ClaudeAgentConnection {
75 sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
76}
77
78impl AgentConnection for ClaudeAgentConnection {
79 fn new_thread(
80 self: Rc<Self>,
81 project: Entity<Project>,
82 cwd: &Path,
83 cx: &mut App,
84 ) -> Task<Result<Entity<AcpThread>>> {
85 let cwd = cwd.to_owned();
86 cx.spawn(async move |cx| {
87 let settings = cx.read_global(|settings: &SettingsStore, _| {
88 settings.get::<AllAgentServersSettings>(None).claude.clone()
89 })?;
90
91 let Some(command) = AgentServerCommand::resolve(
92 "claude",
93 &[],
94 Some(&util::paths::home_dir().join(".claude/local/claude")),
95 settings,
96 &project,
97 cx,
98 )
99 .await
100 else {
101 anyhow::bail!("Failed to find claude binary");
102 };
103
104 let api_key =
105 cx.update(AnthropicLanguageModelProvider::api_key)?
106 .await
107 .map_err(|err| {
108 if err.is::<language_model::AuthenticateError>() {
109 anyhow!(AuthRequired::new().with_language_model_provider(
110 language_model::ANTHROPIC_PROVIDER_ID
111 ))
112 } else {
113 anyhow!(err)
114 }
115 })?;
116
117 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
118 let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
119 let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
120
121 let mut mcp_servers = HashMap::default();
122 mcp_servers.insert(
123 mcp_server::SERVER_NAME.to_string(),
124 permission_mcp_server.server_config()?,
125 );
126 let mcp_config = McpConfig { mcp_servers };
127
128 let mcp_config_file = tempfile::NamedTempFile::new()?;
129 let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
130
131 let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
132 mcp_config_file
133 .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
134 .await?;
135 mcp_config_file.flush().await?;
136
137 let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
138 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
139
140 let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
141
142 log::trace!("Starting session with id: {}", session_id);
143
144 let mut child = spawn_claude(
145 &command,
146 ClaudeSessionMode::Start,
147 session_id.clone(),
148 api_key,
149 &mcp_config_path,
150 &cwd,
151 )?;
152
153 let stdout = child.stdout.take().context("Failed to take stdout")?;
154 let stdin = child.stdin.take().context("Failed to take stdin")?;
155 let stderr = child.stderr.take().context("Failed to take stderr")?;
156
157 let pid = child.id();
158 log::trace!("Spawned (pid: {})", pid);
159
160 cx.background_spawn(async move {
161 let mut stderr = BufReader::new(stderr);
162 let mut line = String::new();
163 while let Ok(n) = stderr.read_line(&mut line).await
164 && n > 0
165 {
166 log::warn!("agent stderr: {}", &line);
167 line.clear();
168 }
169 })
170 .detach();
171
172 cx.background_spawn(async move {
173 let mut outgoing_rx = Some(outgoing_rx);
174
175 ClaudeAgentSession::handle_io(
176 outgoing_rx.take().unwrap(),
177 incoming_message_tx.clone(),
178 stdin,
179 stdout,
180 )
181 .await?;
182
183 log::trace!("Stopped (pid: {})", pid);
184
185 drop(mcp_config_path);
186 anyhow::Ok(())
187 })
188 .detach();
189
190 let turn_state = Rc::new(RefCell::new(TurnState::None));
191
192 let handler_task = cx.spawn({
193 let turn_state = turn_state.clone();
194 let mut thread_rx = thread_rx.clone();
195 async move |cx| {
196 while let Some(message) = incoming_message_rx.next().await {
197 ClaudeAgentSession::handle_message(
198 thread_rx.clone(),
199 message,
200 turn_state.clone(),
201 cx,
202 )
203 .await
204 }
205
206 if let Some(status) = child.status().await.log_err()
207 && let Some(thread) = thread_rx.recv().await.ok() {
208 thread
209 .update(cx, |thread, cx| {
210 thread.emit_server_exited(status, cx);
211 })
212 .ok();
213 }
214 }
215 });
216
217 let thread = cx.new(|cx| {
218 AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
219 })?;
220
221 thread_tx.send(thread.downgrade())?;
222
223 let session = ClaudeAgentSession {
224 outgoing_tx,
225 turn_state,
226 _handler_task: handler_task,
227 _mcp_server: Some(permission_mcp_server),
228 };
229
230 self.sessions.borrow_mut().insert(session_id, session);
231
232 Ok(thread)
233 })
234 }
235
236 fn auth_methods(&self) -> &[acp::AuthMethod] {
237 &[]
238 }
239
240 fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
241 Task::ready(Err(anyhow!("Authentication not supported")))
242 }
243
244 fn prompt(
245 &self,
246 _id: Option<acp_thread::UserMessageId>,
247 params: acp::PromptRequest,
248 cx: &mut App,
249 ) -> Task<Result<acp::PromptResponse>> {
250 let sessions = self.sessions.borrow();
251 let Some(session) = sessions.get(¶ms.session_id) else {
252 return Task::ready(Err(anyhow!(
253 "Attempted to send message to nonexistent session {}",
254 params.session_id
255 )));
256 };
257
258 let (end_tx, end_rx) = oneshot::channel();
259 session.turn_state.replace(TurnState::InProgress { end_tx });
260
261 let mut content = String::new();
262 for chunk in params.prompt {
263 match chunk {
264 acp::ContentBlock::Text(text_content) => {
265 content.push_str(&text_content.text);
266 }
267 acp::ContentBlock::ResourceLink(resource_link) => {
268 content.push_str(&format!("@{}", resource_link.uri));
269 }
270 acp::ContentBlock::Audio(_)
271 | acp::ContentBlock::Image(_)
272 | acp::ContentBlock::Resource(_) => {
273 // TODO
274 }
275 }
276 }
277
278 if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
279 message: Message {
280 role: Role::User,
281 content: Content::UntaggedText(content),
282 id: None,
283 model: None,
284 stop_reason: None,
285 stop_sequence: None,
286 usage: None,
287 },
288 session_id: Some(params.session_id.to_string()),
289 }) {
290 return Task::ready(Err(anyhow!(err)));
291 }
292
293 cx.foreground_executor().spawn(async move { end_rx.await? })
294 }
295
296 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
297 let sessions = self.sessions.borrow();
298 let Some(session) = sessions.get(session_id) else {
299 log::warn!("Attempted to cancel nonexistent session {}", session_id);
300 return;
301 };
302
303 let request_id = new_request_id();
304
305 let turn_state = session.turn_state.take();
306 let TurnState::InProgress { end_tx } = turn_state else {
307 // Already canceled or idle, put it back
308 session.turn_state.replace(turn_state);
309 return;
310 };
311
312 session.turn_state.replace(TurnState::CancelRequested {
313 end_tx,
314 request_id: request_id.clone(),
315 });
316
317 session
318 .outgoing_tx
319 .unbounded_send(SdkMessage::ControlRequest {
320 request_id,
321 request: ControlRequest::Interrupt,
322 })
323 .log_err();
324 }
325
326 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
327 self
328 }
329}
330
331#[derive(Clone, Copy)]
332enum ClaudeSessionMode {
333 Start,
334 #[expect(dead_code)]
335 Resume,
336}
337
338fn spawn_claude(
339 command: &AgentServerCommand,
340 mode: ClaudeSessionMode,
341 session_id: acp::SessionId,
342 api_key: language_models::provider::anthropic::ApiKey,
343 mcp_config_path: &Path,
344 root_dir: &Path,
345) -> Result<Child> {
346 let child = util::command::new_smol_command(&command.path)
347 .args([
348 "--input-format",
349 "stream-json",
350 "--output-format",
351 "stream-json",
352 "--print",
353 "--verbose",
354 "--mcp-config",
355 mcp_config_path.to_string_lossy().as_ref(),
356 "--permission-prompt-tool",
357 &format!(
358 "mcp__{}__{}",
359 mcp_server::SERVER_NAME,
360 mcp_server::PermissionTool::NAME,
361 ),
362 "--allowedTools",
363 &format!(
364 "mcp__{}__{},mcp__{}__{}",
365 mcp_server::SERVER_NAME,
366 mcp_server::EditTool::NAME,
367 mcp_server::SERVER_NAME,
368 mcp_server::ReadTool::NAME
369 ),
370 "--disallowedTools",
371 "Read,Edit",
372 ])
373 .args(match mode {
374 ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
375 ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
376 })
377 .args(command.args.iter().map(|arg| arg.as_str()))
378 .envs(command.env.iter().flatten())
379 .env("ANTHROPIC_API_KEY", api_key.key)
380 .current_dir(root_dir)
381 .stdin(std::process::Stdio::piped())
382 .stdout(std::process::Stdio::piped())
383 .stderr(std::process::Stdio::piped())
384 .kill_on_drop(true)
385 .spawn()?;
386
387 Ok(child)
388}
389
390struct ClaudeAgentSession {
391 outgoing_tx: UnboundedSender<SdkMessage>,
392 turn_state: Rc<RefCell<TurnState>>,
393 _mcp_server: Option<ClaudeZedMcpServer>,
394 _handler_task: Task<()>,
395}
396
397#[derive(Debug, Default)]
398enum TurnState {
399 #[default]
400 None,
401 InProgress {
402 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
403 },
404 CancelRequested {
405 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
406 request_id: String,
407 },
408 CancelConfirmed {
409 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
410 },
411}
412
413impl TurnState {
414 fn is_canceled(&self) -> bool {
415 matches!(self, TurnState::CancelConfirmed { .. })
416 }
417
418 fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
419 match self {
420 TurnState::None => None,
421 TurnState::InProgress { end_tx, .. } => Some(end_tx),
422 TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
423 TurnState::CancelConfirmed { end_tx } => Some(end_tx),
424 }
425 }
426
427 fn confirm_cancellation(self, id: &str) -> Self {
428 match self {
429 TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
430 TurnState::CancelConfirmed { end_tx }
431 }
432 _ => self,
433 }
434 }
435}
436
437impl ClaudeAgentSession {
438 async fn handle_message(
439 mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
440 message: SdkMessage,
441 turn_state: Rc<RefCell<TurnState>>,
442 cx: &mut AsyncApp,
443 ) {
444 match message {
445 // we should only be sending these out, they don't need to be in the thread
446 SdkMessage::ControlRequest { .. } => {}
447 SdkMessage::User {
448 message,
449 session_id: _,
450 } => {
451 let Some(thread) = thread_rx
452 .recv()
453 .await
454 .log_err()
455 .and_then(|entity| entity.upgrade())
456 else {
457 log::error!("Received an SDK message but thread is gone");
458 return;
459 };
460
461 for chunk in message.content.chunks() {
462 match chunk {
463 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
464 if !turn_state.borrow().is_canceled() {
465 thread
466 .update(cx, |thread, cx| {
467 thread.push_user_content_block(None, text.into(), cx)
468 })
469 .log_err();
470 }
471 }
472 ContentChunk::ToolResult {
473 content,
474 tool_use_id,
475 } => {
476 let content = content.to_string();
477 thread
478 .update(cx, |thread, cx| {
479 thread.update_tool_call(
480 acp::ToolCallUpdate {
481 id: acp::ToolCallId(tool_use_id.into()),
482 fields: acp::ToolCallUpdateFields {
483 status: if turn_state.borrow().is_canceled() {
484 // Do not set to completed if turn was canceled
485 None
486 } else {
487 Some(acp::ToolCallStatus::Completed)
488 },
489 content: (!content.is_empty())
490 .then(|| vec![content.into()]),
491 ..Default::default()
492 },
493 },
494 cx,
495 )
496 })
497 .log_err();
498 }
499 ContentChunk::Thinking { .. }
500 | ContentChunk::RedactedThinking
501 | ContentChunk::ToolUse { .. } => {
502 debug_panic!(
503 "Should not get {:?} with role: assistant. should we handle this?",
504 chunk
505 );
506 }
507
508 ContentChunk::Image
509 | ContentChunk::Document
510 | ContentChunk::WebSearchToolResult => {
511 thread
512 .update(cx, |thread, cx| {
513 thread.push_assistant_content_block(
514 format!("Unsupported content: {:?}", chunk).into(),
515 false,
516 cx,
517 )
518 })
519 .log_err();
520 }
521 }
522 }
523 }
524 SdkMessage::Assistant {
525 message,
526 session_id: _,
527 } => {
528 let Some(thread) = thread_rx
529 .recv()
530 .await
531 .log_err()
532 .and_then(|entity| entity.upgrade())
533 else {
534 log::error!("Received an SDK message but thread is gone");
535 return;
536 };
537
538 for chunk in message.content.chunks() {
539 match chunk {
540 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
541 thread
542 .update(cx, |thread, cx| {
543 thread.push_assistant_content_block(text.into(), false, cx)
544 })
545 .log_err();
546 }
547 ContentChunk::Thinking { thinking } => {
548 thread
549 .update(cx, |thread, cx| {
550 thread.push_assistant_content_block(thinking.into(), true, cx)
551 })
552 .log_err();
553 }
554 ContentChunk::RedactedThinking => {
555 thread
556 .update(cx, |thread, cx| {
557 thread.push_assistant_content_block(
558 "[REDACTED]".into(),
559 true,
560 cx,
561 )
562 })
563 .log_err();
564 }
565 ContentChunk::ToolUse { id, name, input } => {
566 let claude_tool = ClaudeTool::infer(&name, input);
567
568 thread
569 .update(cx, |thread, cx| {
570 if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
571 thread.update_plan(
572 acp::Plan {
573 entries: params
574 .todos
575 .into_iter()
576 .map(Into::into)
577 .collect(),
578 },
579 cx,
580 )
581 } else {
582 thread.upsert_tool_call(
583 claude_tool.as_acp(acp::ToolCallId(id.into())),
584 cx,
585 )?;
586 }
587 anyhow::Ok(())
588 })
589 .log_err();
590 }
591 ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
592 debug_panic!(
593 "Should not get tool results with role: assistant. should we handle this?"
594 );
595 }
596 ContentChunk::Image | ContentChunk::Document => {
597 thread
598 .update(cx, |thread, cx| {
599 thread.push_assistant_content_block(
600 format!("Unsupported content: {:?}", chunk).into(),
601 false,
602 cx,
603 )
604 })
605 .log_err();
606 }
607 }
608 }
609 }
610 SdkMessage::Result {
611 is_error,
612 subtype,
613 result,
614 ..
615 } => {
616 let turn_state = turn_state.take();
617 let was_canceled = turn_state.is_canceled();
618 let Some(end_turn_tx) = turn_state.end_tx() else {
619 debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
620 return;
621 };
622
623 if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
624 end_turn_tx
625 .send(Err(anyhow!(
626 "Error: {}",
627 result.unwrap_or_else(|| subtype.to_string())
628 )))
629 .ok();
630 } else {
631 let stop_reason = match subtype {
632 ResultErrorType::Success => acp::StopReason::EndTurn,
633 ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
634 ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
635 };
636 end_turn_tx
637 .send(Ok(acp::PromptResponse { stop_reason }))
638 .ok();
639 }
640 }
641 SdkMessage::ControlResponse { response } => {
642 if matches!(response.subtype, ResultErrorType::Success) {
643 let new_state = turn_state.take().confirm_cancellation(&response.request_id);
644 turn_state.replace(new_state);
645 } else {
646 log::error!("Control response error: {:?}", response);
647 }
648 }
649 SdkMessage::System { .. } => {}
650 }
651 }
652
653 async fn handle_io(
654 mut outgoing_rx: UnboundedReceiver<SdkMessage>,
655 incoming_tx: UnboundedSender<SdkMessage>,
656 mut outgoing_bytes: impl Unpin + AsyncWrite,
657 incoming_bytes: impl Unpin + AsyncRead,
658 ) -> Result<UnboundedReceiver<SdkMessage>> {
659 let mut output_reader = BufReader::new(incoming_bytes);
660 let mut outgoing_line = Vec::new();
661 let mut incoming_line = String::new();
662 loop {
663 select_biased! {
664 message = outgoing_rx.next() => {
665 if let Some(message) = message {
666 outgoing_line.clear();
667 serde_json::to_writer(&mut outgoing_line, &message)?;
668 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
669 outgoing_line.push(b'\n');
670 outgoing_bytes.write_all(&outgoing_line).await.ok();
671 } else {
672 break;
673 }
674 }
675 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
676 if bytes_read? == 0 {
677 break
678 }
679 log::trace!("recv: {}", &incoming_line);
680 match serde_json::from_str::<SdkMessage>(&incoming_line) {
681 Ok(message) => {
682 incoming_tx.unbounded_send(message).log_err();
683 }
684 Err(error) => {
685 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
686 }
687 }
688 incoming_line.clear();
689 }
690 }
691 }
692
693 Ok(outgoing_rx)
694 }
695}
696
/// A chat message as carried inside `user` / `assistant` SDK messages
/// (shaped after the Anthropic SDK message type).
///
/// Optional metadata fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
712
713#[derive(Debug, Clone, Serialize, Deserialize)]
714#[serde(untagged)]
715enum Content {
716 UntaggedText(String),
717 Chunks(Vec<ContentChunk>),
718}
719
720impl Content {
721 pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
722 match self {
723 Self::Chunks(chunks) => chunks.into_iter(),
724 Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
725 }
726 }
727}
728
729impl Display for Content {
730 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
731 match self {
732 Content::UntaggedText(txt) => write!(f, "{}", txt),
733 Content::Chunks(chunks) => {
734 for chunk in chunks {
735 write!(f, "{}", chunk)?;
736 }
737 Ok(())
738 }
739 }
740 }
741}
742
743#[derive(Debug, Clone, Serialize, Deserialize)]
744#[serde(tag = "type", rename_all = "snake_case")]
745enum ContentChunk {
746 Text {
747 text: String,
748 },
749 ToolUse {
750 id: String,
751 name: String,
752 input: serde_json::Value,
753 },
754 ToolResult {
755 content: Content,
756 tool_use_id: String,
757 },
758 Thinking {
759 thinking: String,
760 },
761 RedactedThinking,
762 // TODO
763 Image,
764 Document,
765 WebSearchToolResult,
766 #[serde(untagged)]
767 UntaggedText(String),
768}
769
770impl Display for ContentChunk {
771 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
772 match self {
773 ContentChunk::Text { text } => write!(f, "{}", text),
774 ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
775 ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
776 ContentChunk::UntaggedText(text) => write!(f, "{}", text),
777 ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
778 ContentChunk::Image
779 | ContentChunk::Document
780 | ContentChunk::ToolUse { .. }
781 | ContentChunk::WebSearchToolResult => {
782 write!(f, "\n{:?}\n", &self)
783 }
784 }
785 }
786}
787
/// Token usage attached to a [`Message`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}

/// Author of a message; serialized as `system`, `assistant` or `user`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}

/// A plain role + text message.
// NOTE(review): not referenced elsewhere in this file — confirm it is still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
810
/// One line of the `stream-json` protocol exchanged with the `claude`
/// process, discriminated by its snake_case `type` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message: prompts we send, and tool results reported back.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation; ends the current turn.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}

/// Payload of a `control_request`, discriminated by its `subtype` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}

/// Outcome of a control request; `request_id` is compared against the
/// pending cancel request's id when confirming an interrupt.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}

/// Outcome subtype shared by `result` messages and control responses,
/// serialized in snake_case (e.g. `error_max_turns`).
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
879
880impl Display for ResultErrorType {
881 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
882 match self {
883 ResultErrorType::Success => write!(f, "success"),
884 ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
885 ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
886 }
887 }
888}
889
890fn new_request_id() -> String {
891 use rand::Rng;
892 // In the Claude Code TS SDK they just generate a random 12 character string,
893 // `Math.random().toString(36).substring(2, 15)`
894 rand::thread_rng()
895 .sample_iter(&rand::distributions::Alphanumeric)
896 .take(12)
897 .map(char::from)
898 .collect()
899}
900
/// An MCP server entry listed in the `system` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    // NOTE(review): possible status values are not modeled; kept as a raw string.
    status: String,
}

/// Permission mode reported in the `system` message; serialized in camelCase
/// (`default`, `acceptEdits`, `bypassPermissions`, `plan`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
915
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");

    /// Command that runs a bare `claude` (presumably resolved via `PATH`).
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: vec![],
            env: None,
        }
    }

    /// Ends-to-end: a `TodoWrite`-driven plan is created, then updated in place.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        let entries_len = thread.read_with(cx, |thread, _| thread.plan().entries.len());
        assert_ne!(entries_len, 0, "Empty plan");

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Mark the first entry status as in progress without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::InProgress
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Now mark the first entry as completed without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::Completed
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });
    }

    #[test]
    fn test_deserialize_content_untagged_text() {
        let content: Content = serde_json::from_value(json!("Hello, world!")).unwrap();
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText variant");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_deserialize_content_chunks() {
        let json = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(json).unwrap();
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks variant");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Hello");

        let ContentChunk::ToolUse { id, name, input } = &chunks[1] else {
            panic!("Expected ToolUse chunk");
        };
        assert_eq!(id, "tool_123");
        assert_eq!(name, "calculator");
        assert_eq!(input["operation"], "add");
        assert_eq!(input["a"], 1);
        assert_eq!(input["b"], 2);
    }

    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let json = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText content");
        };
        assert_eq!(text, "Result content");
        assert_eq!(tool_use_id, "tool_456");
    }

    #[test]
    fn test_deserialize_tool_result_chunks() {
        let json = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks content");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Processing complete");

        let ContentChunk::Text { text } = &chunks[1] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Result: 42");

        assert_eq!(tool_use_id, "tool_789");
    }
}