1mod mcp_server;
2pub mod tools;
3
4use collections::HashMap;
5use context_server::listener::McpServerTool;
6use language_models::provider::anthropic::AnthropicLanguageModelProvider;
7use project::Project;
8use settings::SettingsStore;
9use smol::process::Child;
10use std::any::Any;
11use std::cell::RefCell;
12use std::fmt::Display;
13use std::path::Path;
14use std::rc::Rc;
15use uuid::Uuid;
16
17use agent_client_protocol as acp;
18use anyhow::{Context as _, Result, anyhow};
19use futures::channel::oneshot;
20use futures::{AsyncBufReadExt, AsyncWriteExt};
21use futures::{
22 AsyncRead, AsyncWrite, FutureExt, StreamExt,
23 channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
24 io::BufReader,
25 select_biased,
26};
27use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
28use serde::{Deserialize, Serialize};
29use util::{ResultExt, debug_panic};
30
31use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
32use crate::claude::tools::ClaudeTool;
33use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
34use acp_thread::{AcpThread, AgentConnection, AuthRequired};
35
36#[derive(Clone)]
37pub struct ClaudeCode;
38
39impl AgentServer for ClaudeCode {
40 fn name(&self) -> &'static str {
41 "Claude Code"
42 }
43
44 fn empty_state_headline(&self) -> &'static str {
45 self.name()
46 }
47
48 fn empty_state_message(&self) -> &'static str {
49 "How can I help you today?"
50 }
51
52 fn logo(&self) -> ui::IconName {
53 ui::IconName::AiClaude
54 }
55
56 fn connect(
57 &self,
58 _root_dir: &Path,
59 _project: &Entity<Project>,
60 _cx: &mut App,
61 ) -> Task<Result<Rc<dyn AgentConnection>>> {
62 let connection = ClaudeAgentConnection {
63 sessions: Default::default(),
64 };
65
66 Task::ready(Ok(Rc::new(connection) as _))
67 }
68
69 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
70 self
71 }
72}
73
74struct ClaudeAgentConnection {
75 sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
76}
77
78impl AgentConnection for ClaudeAgentConnection {
79 fn new_thread(
80 self: Rc<Self>,
81 project: Entity<Project>,
82 cwd: &Path,
83 cx: &mut App,
84 ) -> Task<Result<Entity<AcpThread>>> {
85 let cwd = cwd.to_owned();
86 cx.spawn(async move |cx| {
87 let settings = cx.read_global(|settings: &SettingsStore, _| {
88 settings.get::<AllAgentServersSettings>(None).claude.clone()
89 })?;
90
91 let Some(command) = AgentServerCommand::resolve(
92 "claude",
93 &[],
94 Some(&util::paths::home_dir().join(".claude/local/claude")),
95 settings,
96 &project,
97 cx,
98 )
99 .await
100 else {
101 anyhow::bail!("Failed to find claude binary");
102 };
103
104 let api_key =
105 cx.update(AnthropicLanguageModelProvider::api_key)?
106 .await
107 .map_err(|err| {
108 if err.is::<language_model::AuthenticateError>() {
109 anyhow!(AuthRequired::new().with_language_model_provider(
110 language_model::ANTHROPIC_PROVIDER_ID
111 ))
112 } else {
113 anyhow!(err)
114 }
115 })?;
116
117 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
118 let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
119 let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
120
121 let mut mcp_servers = HashMap::default();
122 mcp_servers.insert(
123 mcp_server::SERVER_NAME.to_string(),
124 permission_mcp_server.server_config()?,
125 );
126 let mcp_config = McpConfig { mcp_servers };
127
128 let mcp_config_file = tempfile::NamedTempFile::new()?;
129 let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
130
131 let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
132 mcp_config_file
133 .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
134 .await?;
135 mcp_config_file.flush().await?;
136
137 let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
138 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
139
140 let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
141
142 log::trace!("Starting session with id: {}", session_id);
143
144 let mut child = spawn_claude(
145 &command,
146 ClaudeSessionMode::Start,
147 session_id.clone(),
148 api_key,
149 &mcp_config_path,
150 &cwd,
151 )?;
152
153 let stdout = child.stdout.take().context("Failed to take stdout")?;
154 let stdin = child.stdin.take().context("Failed to take stdin")?;
155 let stderr = child.stderr.take().context("Failed to take stderr")?;
156
157 let pid = child.id();
158 log::trace!("Spawned (pid: {})", pid);
159
160 cx.background_spawn(async move {
161 let mut stderr = BufReader::new(stderr);
162 let mut line = String::new();
163 while let Ok(n) = stderr.read_line(&mut line).await
164 && n > 0
165 {
166 log::warn!("agent stderr: {}", &line);
167 line.clear();
168 }
169 })
170 .detach();
171
172 cx.background_spawn(async move {
173 let mut outgoing_rx = Some(outgoing_rx);
174
175 ClaudeAgentSession::handle_io(
176 outgoing_rx.take().unwrap(),
177 incoming_message_tx.clone(),
178 stdin,
179 stdout,
180 )
181 .await?;
182
183 log::trace!("Stopped (pid: {})", pid);
184
185 drop(mcp_config_path);
186 anyhow::Ok(())
187 })
188 .detach();
189
190 let turn_state = Rc::new(RefCell::new(TurnState::None));
191
192 let handler_task = cx.spawn({
193 let turn_state = turn_state.clone();
194 let mut thread_rx = thread_rx.clone();
195 async move |cx| {
196 while let Some(message) = incoming_message_rx.next().await {
197 ClaudeAgentSession::handle_message(
198 thread_rx.clone(),
199 message,
200 turn_state.clone(),
201 cx,
202 )
203 .await
204 }
205
206 if let Some(status) = child.status().await.log_err() {
207 if let Some(thread) = thread_rx.recv().await.ok() {
208 thread
209 .update(cx, |thread, cx| {
210 thread.emit_server_exited(status, cx);
211 })
212 .ok();
213 }
214 }
215 }
216 });
217
218 let thread = cx.new(|cx| {
219 AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
220 })?;
221
222 thread_tx.send(thread.downgrade())?;
223
224 let session = ClaudeAgentSession {
225 outgoing_tx,
226 turn_state,
227 _handler_task: handler_task,
228 _mcp_server: Some(permission_mcp_server),
229 };
230
231 self.sessions.borrow_mut().insert(session_id, session);
232
233 Ok(thread)
234 })
235 }
236
237 fn auth_methods(&self) -> &[acp::AuthMethod] {
238 &[]
239 }
240
241 fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
242 Task::ready(Err(anyhow!("Authentication not supported")))
243 }
244
245 fn prompt(
246 &self,
247 _id: Option<acp_thread::UserMessageId>,
248 params: acp::PromptRequest,
249 cx: &mut App,
250 ) -> Task<Result<acp::PromptResponse>> {
251 let sessions = self.sessions.borrow();
252 let Some(session) = sessions.get(¶ms.session_id) else {
253 return Task::ready(Err(anyhow!(
254 "Attempted to send message to nonexistent session {}",
255 params.session_id
256 )));
257 };
258
259 let (end_tx, end_rx) = oneshot::channel();
260 session.turn_state.replace(TurnState::InProgress { end_tx });
261
262 let mut content = String::new();
263 for chunk in params.prompt {
264 match chunk {
265 acp::ContentBlock::Text(text_content) => {
266 content.push_str(&text_content.text);
267 }
268 acp::ContentBlock::ResourceLink(resource_link) => {
269 content.push_str(&format!("@{}", resource_link.uri));
270 }
271 acp::ContentBlock::Audio(_)
272 | acp::ContentBlock::Image(_)
273 | acp::ContentBlock::Resource(_) => {
274 // TODO
275 }
276 }
277 }
278
279 if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
280 message: Message {
281 role: Role::User,
282 content: Content::UntaggedText(content),
283 id: None,
284 model: None,
285 stop_reason: None,
286 stop_sequence: None,
287 usage: None,
288 },
289 session_id: Some(params.session_id.to_string()),
290 }) {
291 return Task::ready(Err(anyhow!(err)));
292 }
293
294 cx.foreground_executor().spawn(async move { end_rx.await? })
295 }
296
297 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
298 let sessions = self.sessions.borrow();
299 let Some(session) = sessions.get(session_id) else {
300 log::warn!("Attempted to cancel nonexistent session {}", session_id);
301 return;
302 };
303
304 let request_id = new_request_id();
305
306 let turn_state = session.turn_state.take();
307 let TurnState::InProgress { end_tx } = turn_state else {
308 // Already canceled or idle, put it back
309 session.turn_state.replace(turn_state);
310 return;
311 };
312
313 session.turn_state.replace(TurnState::CancelRequested {
314 end_tx,
315 request_id: request_id.clone(),
316 });
317
318 session
319 .outgoing_tx
320 .unbounded_send(SdkMessage::ControlRequest {
321 request_id,
322 request: ControlRequest::Interrupt,
323 })
324 .log_err();
325 }
326
327 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
328 self
329 }
330}
331
332#[derive(Clone, Copy)]
333enum ClaudeSessionMode {
334 Start,
335 #[expect(dead_code)]
336 Resume,
337}
338
339fn spawn_claude(
340 command: &AgentServerCommand,
341 mode: ClaudeSessionMode,
342 session_id: acp::SessionId,
343 api_key: language_models::provider::anthropic::ApiKey,
344 mcp_config_path: &Path,
345 root_dir: &Path,
346) -> Result<Child> {
347 let child = util::command::new_smol_command(&command.path)
348 .args([
349 "--input-format",
350 "stream-json",
351 "--output-format",
352 "stream-json",
353 "--print",
354 "--verbose",
355 "--mcp-config",
356 mcp_config_path.to_string_lossy().as_ref(),
357 "--permission-prompt-tool",
358 &format!(
359 "mcp__{}__{}",
360 mcp_server::SERVER_NAME,
361 mcp_server::PermissionTool::NAME,
362 ),
363 "--allowedTools",
364 &format!(
365 "mcp__{}__{},mcp__{}__{}",
366 mcp_server::SERVER_NAME,
367 mcp_server::EditTool::NAME,
368 mcp_server::SERVER_NAME,
369 mcp_server::ReadTool::NAME
370 ),
371 "--disallowedTools",
372 "Read,Edit",
373 ])
374 .args(match mode {
375 ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
376 ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
377 })
378 .args(command.args.iter().map(|arg| arg.as_str()))
379 .envs(command.env.iter().flatten())
380 .env("ANTHROPIC_API_KEY", api_key.key)
381 .current_dir(root_dir)
382 .stdin(std::process::Stdio::piped())
383 .stdout(std::process::Stdio::piped())
384 .stderr(std::process::Stdio::piped())
385 .kill_on_drop(true)
386 .spawn()?;
387
388 Ok(child)
389}
390
/// Per-session state stored in `ClaudeAgentConnection::sessions`.
///
/// Struct fields are dropped in declaration order, so reordering them changes
/// teardown order.
struct ClaudeAgentSession {
    /// Queues messages for the `claude` process; `handle_io` consumes the receiving end.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current prompt turn, shared with the message-handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// Held only to keep the Zed MCP server alive for the session's lifetime.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// The message-handler task; `new_thread` moves the `claude` child
    /// process into it, so this field also owns the process.
    _handler_task: Task<()>,
}
397
398#[derive(Debug, Default)]
399enum TurnState {
400 #[default]
401 None,
402 InProgress {
403 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
404 },
405 CancelRequested {
406 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
407 request_id: String,
408 },
409 CancelConfirmed {
410 end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
411 },
412}
413
414impl TurnState {
415 fn is_canceled(&self) -> bool {
416 matches!(self, TurnState::CancelConfirmed { .. })
417 }
418
419 fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
420 match self {
421 TurnState::None => None,
422 TurnState::InProgress { end_tx, .. } => Some(end_tx),
423 TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
424 TurnState::CancelConfirmed { end_tx } => Some(end_tx),
425 }
426 }
427
428 fn confirm_cancellation(self, id: &str) -> Self {
429 match self {
430 TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
431 TurnState::CancelConfirmed { end_tx }
432 }
433 _ => self,
434 }
435 }
436}
437
438impl ClaudeAgentSession {
439 async fn handle_message(
440 mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
441 message: SdkMessage,
442 turn_state: Rc<RefCell<TurnState>>,
443 cx: &mut AsyncApp,
444 ) {
445 match message {
446 // we should only be sending these out, they don't need to be in the thread
447 SdkMessage::ControlRequest { .. } => {}
448 SdkMessage::User {
449 message,
450 session_id: _,
451 } => {
452 let Some(thread) = thread_rx
453 .recv()
454 .await
455 .log_err()
456 .and_then(|entity| entity.upgrade())
457 else {
458 log::error!("Received an SDK message but thread is gone");
459 return;
460 };
461
462 for chunk in message.content.chunks() {
463 match chunk {
464 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
465 if !turn_state.borrow().is_canceled() {
466 thread
467 .update(cx, |thread, cx| {
468 thread.push_user_content_block(None, text.into(), cx)
469 })
470 .log_err();
471 }
472 }
473 ContentChunk::ToolResult {
474 content,
475 tool_use_id,
476 } => {
477 let content = content.to_string();
478 thread
479 .update(cx, |thread, cx| {
480 thread.update_tool_call(
481 acp::ToolCallUpdate {
482 id: acp::ToolCallId(tool_use_id.into()),
483 fields: acp::ToolCallUpdateFields {
484 status: if turn_state.borrow().is_canceled() {
485 // Do not set to completed if turn was canceled
486 None
487 } else {
488 Some(acp::ToolCallStatus::Completed)
489 },
490 content: (!content.is_empty())
491 .then(|| vec![content.into()]),
492 ..Default::default()
493 },
494 },
495 cx,
496 )
497 })
498 .log_err();
499 }
500 ContentChunk::Thinking { .. }
501 | ContentChunk::RedactedThinking
502 | ContentChunk::ToolUse { .. } => {
503 debug_panic!(
504 "Should not get {:?} with role: assistant. should we handle this?",
505 chunk
506 );
507 }
508
509 ContentChunk::Image
510 | ContentChunk::Document
511 | ContentChunk::WebSearchToolResult => {
512 thread
513 .update(cx, |thread, cx| {
514 thread.push_assistant_content_block(
515 format!("Unsupported content: {:?}", chunk).into(),
516 false,
517 cx,
518 )
519 })
520 .log_err();
521 }
522 }
523 }
524 }
525 SdkMessage::Assistant {
526 message,
527 session_id: _,
528 } => {
529 let Some(thread) = thread_rx
530 .recv()
531 .await
532 .log_err()
533 .and_then(|entity| entity.upgrade())
534 else {
535 log::error!("Received an SDK message but thread is gone");
536 return;
537 };
538
539 for chunk in message.content.chunks() {
540 match chunk {
541 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
542 thread
543 .update(cx, |thread, cx| {
544 thread.push_assistant_content_block(text.into(), false, cx)
545 })
546 .log_err();
547 }
548 ContentChunk::Thinking { thinking } => {
549 thread
550 .update(cx, |thread, cx| {
551 thread.push_assistant_content_block(thinking.into(), true, cx)
552 })
553 .log_err();
554 }
555 ContentChunk::RedactedThinking => {
556 thread
557 .update(cx, |thread, cx| {
558 thread.push_assistant_content_block(
559 "[REDACTED]".into(),
560 true,
561 cx,
562 )
563 })
564 .log_err();
565 }
566 ContentChunk::ToolUse { id, name, input } => {
567 let claude_tool = ClaudeTool::infer(&name, input);
568
569 thread
570 .update(cx, |thread, cx| {
571 if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
572 thread.update_plan(
573 acp::Plan {
574 entries: params
575 .todos
576 .into_iter()
577 .map(Into::into)
578 .collect(),
579 },
580 cx,
581 )
582 } else {
583 thread.upsert_tool_call(
584 claude_tool.as_acp(acp::ToolCallId(id.into())),
585 cx,
586 )?;
587 }
588 anyhow::Ok(())
589 })
590 .log_err();
591 }
592 ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
593 debug_panic!(
594 "Should not get tool results with role: assistant. should we handle this?"
595 );
596 }
597 ContentChunk::Image | ContentChunk::Document => {
598 thread
599 .update(cx, |thread, cx| {
600 thread.push_assistant_content_block(
601 format!("Unsupported content: {:?}", chunk).into(),
602 false,
603 cx,
604 )
605 })
606 .log_err();
607 }
608 }
609 }
610 }
611 SdkMessage::Result {
612 is_error,
613 subtype,
614 result,
615 ..
616 } => {
617 let turn_state = turn_state.take();
618 let was_canceled = turn_state.is_canceled();
619 let Some(end_turn_tx) = turn_state.end_tx() else {
620 debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
621 return;
622 };
623
624 if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
625 end_turn_tx
626 .send(Err(anyhow!(
627 "Error: {}",
628 result.unwrap_or_else(|| subtype.to_string())
629 )))
630 .ok();
631 } else {
632 let stop_reason = match subtype {
633 ResultErrorType::Success => acp::StopReason::EndTurn,
634 ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
635 ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
636 };
637 end_turn_tx
638 .send(Ok(acp::PromptResponse { stop_reason }))
639 .ok();
640 }
641 }
642 SdkMessage::ControlResponse { response } => {
643 if matches!(response.subtype, ResultErrorType::Success) {
644 let new_state = turn_state.take().confirm_cancellation(&response.request_id);
645 turn_state.replace(new_state);
646 } else {
647 log::error!("Control response error: {:?}", response);
648 }
649 }
650 SdkMessage::System { .. } => {}
651 }
652 }
653
654 async fn handle_io(
655 mut outgoing_rx: UnboundedReceiver<SdkMessage>,
656 incoming_tx: UnboundedSender<SdkMessage>,
657 mut outgoing_bytes: impl Unpin + AsyncWrite,
658 incoming_bytes: impl Unpin + AsyncRead,
659 ) -> Result<UnboundedReceiver<SdkMessage>> {
660 let mut output_reader = BufReader::new(incoming_bytes);
661 let mut outgoing_line = Vec::new();
662 let mut incoming_line = String::new();
663 loop {
664 select_biased! {
665 message = outgoing_rx.next() => {
666 if let Some(message) = message {
667 outgoing_line.clear();
668 serde_json::to_writer(&mut outgoing_line, &message)?;
669 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
670 outgoing_line.push(b'\n');
671 outgoing_bytes.write_all(&outgoing_line).await.ok();
672 } else {
673 break;
674 }
675 }
676 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
677 if bytes_read? == 0 {
678 break
679 }
680 log::trace!("recv: {}", &incoming_line);
681 match serde_json::from_str::<SdkMessage>(&incoming_line) {
682 Ok(message) => {
683 incoming_tx.unbounded_send(message).log_err();
684 }
685 Err(error) => {
686 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
687 }
688 }
689 incoming_line.clear();
690 }
691 }
692 }
693
694 Ok(outgoing_rx)
695 }
696}
697
/// A chat message in Anthropic SDK shape, as exchanged with the `claude` CLI.
///
/// Optional fields are omitted from the serialized JSON when `None`; `prompt`
/// sends only `role` and `content`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
713
714#[derive(Debug, Clone, Serialize, Deserialize)]
715#[serde(untagged)]
716enum Content {
717 UntaggedText(String),
718 Chunks(Vec<ContentChunk>),
719}
720
721impl Content {
722 pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
723 match self {
724 Self::Chunks(chunks) => chunks.into_iter(),
725 Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
726 }
727 }
728}
729
730impl Display for Content {
731 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
732 match self {
733 Content::UntaggedText(txt) => write!(f, "{}", txt),
734 Content::Chunks(chunks) => {
735 for chunk in chunks {
736 write!(f, "{}", chunk)?;
737 }
738 Ok(())
739 }
740 }
741 }
742}
743
744#[derive(Debug, Clone, Serialize, Deserialize)]
745#[serde(tag = "type", rename_all = "snake_case")]
746enum ContentChunk {
747 Text {
748 text: String,
749 },
750 ToolUse {
751 id: String,
752 name: String,
753 input: serde_json::Value,
754 },
755 ToolResult {
756 content: Content,
757 tool_use_id: String,
758 },
759 Thinking {
760 thinking: String,
761 },
762 RedactedThinking,
763 // TODO
764 Image,
765 Document,
766 WebSearchToolResult,
767 #[serde(untagged)]
768 UntaggedText(String),
769}
770
771impl Display for ContentChunk {
772 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
773 match self {
774 ContentChunk::Text { text } => write!(f, "{}", text),
775 ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
776 ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
777 ContentChunk::UntaggedText(text) => write!(f, "{}", text),
778 ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
779 ContentChunk::Image
780 | ContentChunk::Document
781 | ContentChunk::ToolUse { .. }
782 | ContentChunk::WebSearchToolResult => {
783 write!(f, "\n{:?}\n", &self)
784 }
785 }
786 }
787}
788
/// Token usage attached to a [`Message`].
///
/// NOTE(review): every field is required and non-optional. If the CLI omits
/// one, or sends `null` for `service_tier`, the enclosing message fails to
/// deserialize and is dropped by `handle_io`. Confirm against the CLI output.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}

/// Author of a [`Message`]; serialized in snake_case (`"system"`, `"assistant"`, `"user"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}

/// A plain role + text message.
///
/// NOTE(review): not referenced elsewhere in this file. It may be used by a
/// child module or be dead code; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
811
/// One newline-delimited JSON message exchanged with the `claude` CLI in
/// `stream-json` mode.
///
/// The `type` field selects the variant, in snake_case: `"assistant"`,
/// `"user"`, `"result"`, `"system"`, `"control_request"`,
/// `"control_response"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message. `prompt` sends these; incoming ones may also carry
    /// tool results.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message in a conversation; `handle_message` uses
    /// it to finish the current prompt turn.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
859
/// Control requests sent to the CLI, tagged by a snake_case `subtype`
/// field (e.g. `{"subtype": "interrupt"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation
    Interrupt,
}

/// The CLI's reply to a control request; `request_id` echoes the request it answers.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}

/// Outcome subtype shared by `Result` messages and control responses.
///
/// Serialized in snake_case: `"success"`, `"error_max_turns"`,
/// `"error_during_execution"`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
880
881impl Display for ResultErrorType {
882 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
883 match self {
884 ResultErrorType::Success => write!(f, "success"),
885 ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
886 ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
887 }
888 }
889}
890
891fn new_request_id() -> String {
892 use rand::Rng;
893 // In the Claude Code TS SDK they just generate a random 12 character string,
894 // `Math.random().toString(36).substring(2, 15)`
895 rand::thread_rng()
896 .sample_iter(&rand::distributions::Alphanumeric)
897 .take(12)
898 .map(char::from)
899 .collect()
900}
901
/// An MCP server entry listed in the `System` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    /// Status as reported by the CLI (free-form string).
    status: String,
}

/// Permission mode reported in the `System` message.
///
/// Serialized in camelCase: `"default"`, `"acceptEdits"`, `"bypassPermissions"`, `"plan"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
916
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");

    /// A command that runs `claude` by name, with no extra args or environment.
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: vec![],
            env: None,
        }
    }

    /// End-to-end check that a todo plan is created and then updated in place
    /// across turns. Ignored unless the `e2e` feature is enabled.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        let entries_len = thread.read_with(cx, |thread, _| {
            let entries_len = thread.plan().entries.len();
            assert!(entries_len > 0, "Empty plan");
            entries_len
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Mark the first entry status as in progress without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::InProgress
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Now mark the first entry as completed without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::Completed
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });
    }

    #[test]
    fn test_deserialize_content_untagged_text() {
        let json = json!("Hello, world!");
        let content: Content = serde_json::from_value(json).unwrap();
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText variant");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_deserialize_content_chunks() {
        let json = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(json).unwrap();
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks variant");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Hello");

        let ContentChunk::ToolUse { id, name, input } = &chunks[1] else {
            panic!("Expected ToolUse chunk");
        };
        assert_eq!(id, "tool_123");
        assert_eq!(name, "calculator");
        assert_eq!(input["operation"], "add");
        assert_eq!(input["a"], 1);
        assert_eq!(input["b"], 2);
    }

    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let json = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::UntaggedText(text) = content else {
            panic!("Expected UntaggedText content");
        };
        assert_eq!(text, "Result content");
        assert_eq!(tool_use_id, "tool_456");
    }

    #[test]
    fn test_deserialize_tool_result_chunks() {
        let json = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        let ContentChunk::ToolResult {
            content,
            tool_use_id,
        } = chunk
        else {
            panic!("Expected ToolResult variant");
        };
        let Content::Chunks(chunks) = content else {
            panic!("Expected Chunks content");
        };
        assert_eq!(chunks.len(), 2);

        let ContentChunk::Text { text } = &chunks[0] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Processing complete");

        let ContentChunk::Text { text } = &chunks[1] else {
            panic!("Expected Text chunk");
        };
        assert_eq!(text, "Result: 42");

        assert_eq!(tool_use_id, "tool_789");
    }
}