1mod mcp_server;
2pub mod tools;
3
4use action_log::ActionLog;
5use collections::HashMap;
6use context_server::listener::McpServerTool;
7use language_models::provider::anthropic::AnthropicLanguageModelProvider;
8use project::Project;
9use settings::SettingsStore;
10use smol::process::Child;
11use std::any::Any;
12use std::cell::RefCell;
13use std::fmt::Display;
14use std::path::Path;
15use std::rc::Rc;
16use uuid::Uuid;
17
18use agent_client_protocol as acp;
19use anyhow::{Context as _, Result, anyhow};
20use futures::channel::oneshot;
21use futures::{AsyncBufReadExt, AsyncWriteExt};
22use futures::{
23 AsyncRead, AsyncWrite, FutureExt, StreamExt,
24 channel::mpsc::{self, UnboundedReceiver, UnboundedSender},
25 io::BufReader,
26 select_biased,
27};
28use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
29use serde::{Deserialize, Serialize};
30use util::{ResultExt, debug_panic};
31
32use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
33use crate::claude::tools::ClaudeTool;
34use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
35use acp_thread::{AcpThread, AgentConnection, AuthRequired};
36
/// The Claude Code agent server.
///
/// A stateless marker type: `connect` hands out a fresh `ClaudeAgentConnection`,
/// and each `new_thread` call on that connection spawns its own `claude` process.
#[derive(Clone)]
pub struct ClaudeCode;
39
40impl AgentServer for ClaudeCode {
41 fn name(&self) -> &'static str {
42 "Claude Code"
43 }
44
45 fn empty_state_headline(&self) -> &'static str {
46 self.name()
47 }
48
49 fn empty_state_message(&self) -> &'static str {
50 "How can I help you today?"
51 }
52
53 fn logo(&self) -> ui::IconName {
54 ui::IconName::AiClaude
55 }
56
57 fn connect(
58 &self,
59 _root_dir: &Path,
60 _project: &Entity<Project>,
61 _cx: &mut App,
62 ) -> Task<Result<Rc<dyn AgentConnection>>> {
63 let connection = ClaudeAgentConnection {
64 sessions: Default::default(),
65 };
66
67 Task::ready(Ok(Rc::new(connection) as _))
68 }
69
70 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
71 self
72 }
73}
74
/// A connection to Claude Code that tracks every live session by its ACP session id.
///
/// Each `ClaudeAgentSession` owns the channels and tasks of one spawned `claude` process.
struct ClaudeAgentConnection {
    sessions: Rc<RefCell<HashMap<acp::SessionId, ClaudeAgentSession>>>,
}
78
79impl AgentConnection for ClaudeAgentConnection {
80 fn new_thread(
81 self: Rc<Self>,
82 project: Entity<Project>,
83 cwd: &Path,
84 cx: &mut App,
85 ) -> Task<Result<Entity<AcpThread>>> {
86 let cwd = cwd.to_owned();
87 cx.spawn(async move |cx| {
88 let settings = cx.read_global(|settings: &SettingsStore, _| {
89 settings.get::<AllAgentServersSettings>(None).claude.clone()
90 })?;
91
92 let Some(command) = AgentServerCommand::resolve(
93 "claude",
94 &[],
95 Some(&util::paths::home_dir().join(".claude/local/claude")),
96 settings,
97 &project,
98 cx,
99 )
100 .await
101 else {
102 anyhow::bail!("Failed to find claude binary");
103 };
104
105 let api_key =
106 cx.update(AnthropicLanguageModelProvider::api_key)?
107 .await
108 .map_err(|err| {
109 if err.is::<language_model::AuthenticateError>() {
110 anyhow!(AuthRequired::new().with_language_model_provider(
111 language_model::ANTHROPIC_PROVIDER_ID
112 ))
113 } else {
114 anyhow!(err)
115 }
116 })?;
117
118 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
119 let fs = project.read_with(cx, |project, _cx| project.fs().clone())?;
120 let permission_mcp_server = ClaudeZedMcpServer::new(thread_rx.clone(), fs, cx).await?;
121
122 let mut mcp_servers = HashMap::default();
123 mcp_servers.insert(
124 mcp_server::SERVER_NAME.to_string(),
125 permission_mcp_server.server_config()?,
126 );
127 let mcp_config = McpConfig { mcp_servers };
128
129 let mcp_config_file = tempfile::NamedTempFile::new()?;
130 let (mcp_config_file, mcp_config_path) = mcp_config_file.into_parts();
131
132 let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
133 mcp_config_file
134 .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
135 .await?;
136 mcp_config_file.flush().await?;
137
138 let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded();
139 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
140
141 let session_id = acp::SessionId(Uuid::new_v4().to_string().into());
142
143 log::trace!("Starting session with id: {}", session_id);
144
145 let mut child = spawn_claude(
146 &command,
147 ClaudeSessionMode::Start,
148 session_id.clone(),
149 api_key,
150 &mcp_config_path,
151 &cwd,
152 )?;
153
154 let stdout = child.stdout.take().context("Failed to take stdout")?;
155 let stdin = child.stdin.take().context("Failed to take stdin")?;
156 let stderr = child.stderr.take().context("Failed to take stderr")?;
157
158 let pid = child.id();
159 log::trace!("Spawned (pid: {})", pid);
160
161 cx.background_spawn(async move {
162 let mut stderr = BufReader::new(stderr);
163 let mut line = String::new();
164 while let Ok(n) = stderr.read_line(&mut line).await
165 && n > 0
166 {
167 log::warn!("agent stderr: {}", &line);
168 line.clear();
169 }
170 })
171 .detach();
172
173 cx.background_spawn(async move {
174 let mut outgoing_rx = Some(outgoing_rx);
175
176 ClaudeAgentSession::handle_io(
177 outgoing_rx.take().unwrap(),
178 incoming_message_tx.clone(),
179 stdin,
180 stdout,
181 )
182 .await?;
183
184 log::trace!("Stopped (pid: {})", pid);
185
186 drop(mcp_config_path);
187 anyhow::Ok(())
188 })
189 .detach();
190
191 let turn_state = Rc::new(RefCell::new(TurnState::None));
192
193 let handler_task = cx.spawn({
194 let turn_state = turn_state.clone();
195 let mut thread_rx = thread_rx.clone();
196 async move |cx| {
197 while let Some(message) = incoming_message_rx.next().await {
198 ClaudeAgentSession::handle_message(
199 thread_rx.clone(),
200 message,
201 turn_state.clone(),
202 cx,
203 )
204 .await
205 }
206
207 if let Some(status) = child.status().await.log_err()
208 && let Some(thread) = thread_rx.recv().await.ok()
209 {
210 thread
211 .update(cx, |thread, cx| {
212 thread.emit_server_exited(status, cx);
213 })
214 .ok();
215 }
216 }
217 });
218
219 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
220 let thread = cx.new(|_cx| {
221 AcpThread::new(
222 "Claude Code",
223 self.clone(),
224 project,
225 action_log,
226 session_id.clone(),
227 )
228 })?;
229
230 thread_tx.send(thread.downgrade())?;
231
232 let session = ClaudeAgentSession {
233 outgoing_tx,
234 turn_state,
235 _handler_task: handler_task,
236 _mcp_server: Some(permission_mcp_server),
237 };
238
239 self.sessions.borrow_mut().insert(session_id, session);
240
241 Ok(thread)
242 })
243 }
244
245 fn auth_methods(&self) -> &[acp::AuthMethod] {
246 &[]
247 }
248
249 fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
250 Task::ready(Err(anyhow!("Authentication not supported")))
251 }
252
253 fn prompt(
254 &self,
255 _id: Option<acp_thread::UserMessageId>,
256 params: acp::PromptRequest,
257 cx: &mut App,
258 ) -> Task<Result<acp::PromptResponse>> {
259 let sessions = self.sessions.borrow();
260 let Some(session) = sessions.get(¶ms.session_id) else {
261 return Task::ready(Err(anyhow!(
262 "Attempted to send message to nonexistent session {}",
263 params.session_id
264 )));
265 };
266
267 let (end_tx, end_rx) = oneshot::channel();
268 session.turn_state.replace(TurnState::InProgress { end_tx });
269
270 let mut content = String::new();
271 for chunk in params.prompt {
272 match chunk {
273 acp::ContentBlock::Text(text_content) => {
274 content.push_str(&text_content.text);
275 }
276 acp::ContentBlock::ResourceLink(resource_link) => {
277 content.push_str(&format!("@{}", resource_link.uri));
278 }
279 acp::ContentBlock::Audio(_)
280 | acp::ContentBlock::Image(_)
281 | acp::ContentBlock::Resource(_) => {
282 // TODO
283 }
284 }
285 }
286
287 if let Err(err) = session.outgoing_tx.unbounded_send(SdkMessage::User {
288 message: Message {
289 role: Role::User,
290 content: Content::UntaggedText(content),
291 id: None,
292 model: None,
293 stop_reason: None,
294 stop_sequence: None,
295 usage: None,
296 },
297 session_id: Some(params.session_id.to_string()),
298 }) {
299 return Task::ready(Err(anyhow!(err)));
300 }
301
302 cx.foreground_executor().spawn(async move { end_rx.await? })
303 }
304
305 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
306 let sessions = self.sessions.borrow();
307 let Some(session) = sessions.get(session_id) else {
308 log::warn!("Attempted to cancel nonexistent session {}", session_id);
309 return;
310 };
311
312 let request_id = new_request_id();
313
314 let turn_state = session.turn_state.take();
315 let TurnState::InProgress { end_tx } = turn_state else {
316 // Already canceled or idle, put it back
317 session.turn_state.replace(turn_state);
318 return;
319 };
320
321 session.turn_state.replace(TurnState::CancelRequested {
322 end_tx,
323 request_id: request_id.clone(),
324 });
325
326 session
327 .outgoing_tx
328 .unbounded_send(SdkMessage::ControlRequest {
329 request_id,
330 request: ControlRequest::Interrupt,
331 })
332 .log_err();
333 }
334
335 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
336 self
337 }
338}
339
/// How `spawn_claude` attaches the CLI to a session id.
#[derive(Copy, Clone)]
enum ClaudeSessionMode {
    /// Begin a brand-new session, passed as `--session-id <id>`.
    Start,
    /// Continue an existing session, passed as `--resume <id>`. Not used yet.
    #[expect(dead_code)]
    Resume,
}
346
347fn spawn_claude(
348 command: &AgentServerCommand,
349 mode: ClaudeSessionMode,
350 session_id: acp::SessionId,
351 api_key: language_models::provider::anthropic::ApiKey,
352 mcp_config_path: &Path,
353 root_dir: &Path,
354) -> Result<Child> {
355 let child = util::command::new_smol_command(&command.path)
356 .args([
357 "--input-format",
358 "stream-json",
359 "--output-format",
360 "stream-json",
361 "--print",
362 "--verbose",
363 "--mcp-config",
364 mcp_config_path.to_string_lossy().as_ref(),
365 "--permission-prompt-tool",
366 &format!(
367 "mcp__{}__{}",
368 mcp_server::SERVER_NAME,
369 mcp_server::PermissionTool::NAME,
370 ),
371 "--allowedTools",
372 &format!(
373 "mcp__{}__{},mcp__{}__{}",
374 mcp_server::SERVER_NAME,
375 mcp_server::EditTool::NAME,
376 mcp_server::SERVER_NAME,
377 mcp_server::ReadTool::NAME
378 ),
379 "--disallowedTools",
380 "Read,Edit",
381 ])
382 .args(match mode {
383 ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()],
384 ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()],
385 })
386 .args(command.args.iter().map(|arg| arg.as_str()))
387 .envs(command.env.iter().flatten())
388 .env("ANTHROPIC_API_KEY", api_key.key)
389 .current_dir(root_dir)
390 .stdin(std::process::Stdio::piped())
391 .stdout(std::process::Stdio::piped())
392 .stderr(std::process::Stdio::piped())
393 .kill_on_drop(true)
394 .spawn()?;
395
396 Ok(child)
397}
398
/// Per-thread state for one running `claude` process.
// NOTE: fields drop in declaration order, so `_mcp_server` is dropped before
// `_handler_task`; keep the order unless that is intentionally changed.
struct ClaudeAgentSession {
    /// Messages sent here are serialized to the process's stdin by the IO task.
    outgoing_tx: UnboundedSender<SdkMessage>,
    /// State of the current prompt turn, shared with the message-handler task.
    turn_state: Rc<RefCell<TurnState>>,
    /// Held only to keep Zed's MCP server alive for the session; never read.
    _mcp_server: Option<ClaudeZedMcpServer>,
    /// Handles incoming messages and owns the child process (spawned with
    /// `kill_on_drop(true)`). gpui tasks are cancelled when dropped.
    _handler_task: Task<()>,
}
405
/// Lifecycle of a single prompt turn.
///
/// `end_tx` resolves the future returned by `prompt`. It is consumed when the
/// turn's `result` message arrives.
#[derive(Debug, Default)]
enum TurnState {
    /// No turn is active.
    #[default]
    None,
    /// A prompt has been sent and its result is pending.
    InProgress {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
    /// An interrupt control request with `request_id` was sent, but Claude hasn't
    /// acknowledged it yet.
    CancelRequested {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
        request_id: String,
    },
    /// Claude acknowledged the interrupt, and the turn still ends on its `result`
    /// message. In this state, incoming user text is not pushed to the thread and
    /// tool results are not marked completed.
    CancelConfirmed {
        end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
    },
}
421
422impl TurnState {
423 fn is_canceled(&self) -> bool {
424 matches!(self, TurnState::CancelConfirmed { .. })
425 }
426
427 fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
428 match self {
429 TurnState::None => None,
430 TurnState::InProgress { end_tx, .. } => Some(end_tx),
431 TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
432 TurnState::CancelConfirmed { end_tx } => Some(end_tx),
433 }
434 }
435
436 fn confirm_cancellation(self, id: &str) -> Self {
437 match self {
438 TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
439 TurnState::CancelConfirmed { end_tx }
440 }
441 _ => self,
442 }
443 }
444}
445
446impl ClaudeAgentSession {
447 async fn handle_message(
448 mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
449 message: SdkMessage,
450 turn_state: Rc<RefCell<TurnState>>,
451 cx: &mut AsyncApp,
452 ) {
453 match message {
454 // we should only be sending these out, they don't need to be in the thread
455 SdkMessage::ControlRequest { .. } => {}
456 SdkMessage::User {
457 message,
458 session_id: _,
459 } => {
460 let Some(thread) = thread_rx
461 .recv()
462 .await
463 .log_err()
464 .and_then(|entity| entity.upgrade())
465 else {
466 log::error!("Received an SDK message but thread is gone");
467 return;
468 };
469
470 for chunk in message.content.chunks() {
471 match chunk {
472 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
473 if !turn_state.borrow().is_canceled() {
474 thread
475 .update(cx, |thread, cx| {
476 thread.push_user_content_block(None, text.into(), cx)
477 })
478 .log_err();
479 }
480 }
481 ContentChunk::ToolResult {
482 content,
483 tool_use_id,
484 } => {
485 let content = content.to_string();
486 thread
487 .update(cx, |thread, cx| {
488 thread.update_tool_call(
489 acp::ToolCallUpdate {
490 id: acp::ToolCallId(tool_use_id.into()),
491 fields: acp::ToolCallUpdateFields {
492 status: if turn_state.borrow().is_canceled() {
493 // Do not set to completed if turn was canceled
494 None
495 } else {
496 Some(acp::ToolCallStatus::Completed)
497 },
498 content: (!content.is_empty())
499 .then(|| vec![content.into()]),
500 ..Default::default()
501 },
502 },
503 cx,
504 )
505 })
506 .log_err();
507 }
508 ContentChunk::Thinking { .. }
509 | ContentChunk::RedactedThinking
510 | ContentChunk::ToolUse { .. } => {
511 debug_panic!(
512 "Should not get {:?} with role: assistant. should we handle this?",
513 chunk
514 );
515 }
516
517 ContentChunk::Image
518 | ContentChunk::Document
519 | ContentChunk::WebSearchToolResult => {
520 thread
521 .update(cx, |thread, cx| {
522 thread.push_assistant_content_block(
523 format!("Unsupported content: {:?}", chunk).into(),
524 false,
525 cx,
526 )
527 })
528 .log_err();
529 }
530 }
531 }
532 }
533 SdkMessage::Assistant {
534 message,
535 session_id: _,
536 } => {
537 let Some(thread) = thread_rx
538 .recv()
539 .await
540 .log_err()
541 .and_then(|entity| entity.upgrade())
542 else {
543 log::error!("Received an SDK message but thread is gone");
544 return;
545 };
546
547 for chunk in message.content.chunks() {
548 match chunk {
549 ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
550 thread
551 .update(cx, |thread, cx| {
552 thread.push_assistant_content_block(text.into(), false, cx)
553 })
554 .log_err();
555 }
556 ContentChunk::Thinking { thinking } => {
557 thread
558 .update(cx, |thread, cx| {
559 thread.push_assistant_content_block(thinking.into(), true, cx)
560 })
561 .log_err();
562 }
563 ContentChunk::RedactedThinking => {
564 thread
565 .update(cx, |thread, cx| {
566 thread.push_assistant_content_block(
567 "[REDACTED]".into(),
568 true,
569 cx,
570 )
571 })
572 .log_err();
573 }
574 ContentChunk::ToolUse { id, name, input } => {
575 let claude_tool = ClaudeTool::infer(&name, input);
576
577 thread
578 .update(cx, |thread, cx| {
579 if let ClaudeTool::TodoWrite(Some(params)) = claude_tool {
580 thread.update_plan(
581 acp::Plan {
582 entries: params
583 .todos
584 .into_iter()
585 .map(Into::into)
586 .collect(),
587 },
588 cx,
589 )
590 } else {
591 thread.upsert_tool_call(
592 claude_tool.as_acp(acp::ToolCallId(id.into())),
593 cx,
594 )?;
595 }
596 anyhow::Ok(())
597 })
598 .log_err();
599 }
600 ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
601 debug_panic!(
602 "Should not get tool results with role: assistant. should we handle this?"
603 );
604 }
605 ContentChunk::Image | ContentChunk::Document => {
606 thread
607 .update(cx, |thread, cx| {
608 thread.push_assistant_content_block(
609 format!("Unsupported content: {:?}", chunk).into(),
610 false,
611 cx,
612 )
613 })
614 .log_err();
615 }
616 }
617 }
618 }
619 SdkMessage::Result {
620 is_error,
621 subtype,
622 result,
623 ..
624 } => {
625 let turn_state = turn_state.take();
626 let was_canceled = turn_state.is_canceled();
627 let Some(end_turn_tx) = turn_state.end_tx() else {
628 debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
629 return;
630 };
631
632 if is_error || (!was_canceled && subtype == ResultErrorType::ErrorDuringExecution) {
633 end_turn_tx
634 .send(Err(anyhow!(
635 "Error: {}",
636 result.unwrap_or_else(|| subtype.to_string())
637 )))
638 .ok();
639 } else {
640 let stop_reason = match subtype {
641 ResultErrorType::Success => acp::StopReason::EndTurn,
642 ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
643 ResultErrorType::ErrorDuringExecution => acp::StopReason::Canceled,
644 };
645 end_turn_tx
646 .send(Ok(acp::PromptResponse { stop_reason }))
647 .ok();
648 }
649 }
650 SdkMessage::ControlResponse { response } => {
651 if matches!(response.subtype, ResultErrorType::Success) {
652 let new_state = turn_state.take().confirm_cancellation(&response.request_id);
653 turn_state.replace(new_state);
654 } else {
655 log::error!("Control response error: {:?}", response);
656 }
657 }
658 SdkMessage::System { .. } => {}
659 }
660 }
661
662 async fn handle_io(
663 mut outgoing_rx: UnboundedReceiver<SdkMessage>,
664 incoming_tx: UnboundedSender<SdkMessage>,
665 mut outgoing_bytes: impl Unpin + AsyncWrite,
666 incoming_bytes: impl Unpin + AsyncRead,
667 ) -> Result<UnboundedReceiver<SdkMessage>> {
668 let mut output_reader = BufReader::new(incoming_bytes);
669 let mut outgoing_line = Vec::new();
670 let mut incoming_line = String::new();
671 loop {
672 select_biased! {
673 message = outgoing_rx.next() => {
674 if let Some(message) = message {
675 outgoing_line.clear();
676 serde_json::to_writer(&mut outgoing_line, &message)?;
677 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
678 outgoing_line.push(b'\n');
679 outgoing_bytes.write_all(&outgoing_line).await.ok();
680 } else {
681 break;
682 }
683 }
684 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
685 if bytes_read? == 0 {
686 break
687 }
688 log::trace!("recv: {}", &incoming_line);
689 match serde_json::from_str::<SdkMessage>(&incoming_line) {
690 Ok(message) => {
691 incoming_tx.unbounded_send(message).log_err();
692 }
693 Err(error) => {
694 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
695 }
696 }
697 incoming_line.clear();
698 }
699 }
700 }
701
702 Ok(outgoing_rx)
703 }
704}
705
/// A message in the Anthropic SDK shape, carried by `user` and `assistant` stream-json lines.
///
/// `None` optional fields are omitted when serializing; outgoing prompts leave them all unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    role: Role,
    content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequence: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<Usage>,
}
721
/// Message content. It is either a bare JSON string or an array of typed chunks.
///
/// Untagged: serde tries the plain-string form first, then the chunk array.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Content {
    UntaggedText(String),
    Chunks(Vec<ContentChunk>),
}
728
729impl Content {
730 pub fn chunks(self) -> impl Iterator<Item = ContentChunk> {
731 match self {
732 Self::Chunks(chunks) => chunks.into_iter(),
733 Self::UntaggedText(text) => vec![ContentChunk::Text { text: text.clone() }].into_iter(),
734 }
735 }
736}
737
738impl Display for Content {
739 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
740 match self {
741 Content::UntaggedText(txt) => write!(f, "{}", txt),
742 Content::Chunks(chunks) => {
743 for chunk in chunks {
744 write!(f, "{}", chunk)?;
745 }
746 Ok(())
747 }
748 }
749 }
750}
751
/// One block of message content.
///
/// Tagged by its `type` field: `text`, `tool_use`, `tool_result`, `thinking`,
/// `redacted_thinking`, `image`, `document` or `web_search_tool_result`. A bare
/// string with no tag deserializes as `UntaggedText`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ContentChunk {
    Text {
        text: String,
    },
    /// A tool invocation requested by the assistant; `input` is kept as raw JSON.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of the tool call identified by `tool_use_id`.
    ToolResult {
        content: Content,
        tool_use_id: String,
    },
    Thinking {
        thinking: String,
    },
    RedactedThinking,
    // TODO: the payload fields of these chunk types are not captured yet.
    Image,
    Document,
    WebSearchToolResult,
    /// Fallback for a plain string chunk that has no `type` tag.
    #[serde(untagged)]
    UntaggedText(String),
}
778
779impl Display for ContentChunk {
780 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
781 match self {
782 ContentChunk::Text { text } => write!(f, "{}", text),
783 ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
784 ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
785 ContentChunk::UntaggedText(text) => write!(f, "{}", text),
786 ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
787 ContentChunk::Image
788 | ContentChunk::Document
789 | ContentChunk::ToolUse { .. }
790 | ContentChunk::WebSearchToolResult => {
791 write!(f, "\n{:?}\n", &self)
792 }
793 }
794 }
795}
796
/// Token usage attached to a message.
///
/// Every field is required when deserializing; none is optional or defaulted.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Usage {
    input_tokens: u32,
    cache_creation_input_tokens: u32,
    cache_read_input_tokens: u32,
    output_tokens: u32,
    service_tier: String,
}
805
/// Author of a message, serialized as `"system"`, `"assistant"` or `"user"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    Assistant,
    User,
}
813
/// A minimal role + plain-text message.
// NOTE(review): not referenced anywhere in this file; confirm whether it is still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageParam {
    role: Role,
    content: String,
}
819
/// One line of Claude Code's stream-json protocol, in either direction.
///
/// Tagged by its `type` field: `assistant`, `user`, `result`, `system`,
/// `control_request` or `control_response`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum SdkMessage {
    /// An assistant message.
    Assistant {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// A user message. Used for outgoing prompts; incoming ones may carry tool results.
    User {
        message: Message, // from Anthropic SDK
        #[serde(skip_serializing_if = "Option::is_none")]
        session_id: Option<String>,
    },
    /// Emitted as the last message of a turn; it resolves the pending prompt.
    Result {
        subtype: ResultErrorType,
        duration_ms: f64,
        duration_api_ms: f64,
        is_error: bool,
        num_turns: i32,
        #[serde(skip_serializing_if = "Option::is_none")]
        result: Option<String>,
        session_id: String,
        total_cost_usd: f64,
    },
    /// Emitted as the first message at the start of a conversation. Its contents
    /// are currently ignored.
    System {
        cwd: String,
        session_id: String,
        tools: Vec<String>,
        model: String,
        mcp_servers: Vec<McpServer>,
        #[serde(rename = "apiKeySource")]
        api_key_source: String,
        #[serde(rename = "permissionMode")]
        permission_mode: PermissionMode,
    },
    /// Messages used to control the conversation, outside of chat messages to the model
    ControlRequest {
        request_id: String,
        request: ControlRequest,
    },
    /// Response to a control request
    ControlResponse { response: ControlResponse },
}
867
/// The body of a control request, tagged by its `subtype` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
enum ControlRequest {
    /// Cancel the current conversation; serialized as `{"subtype":"interrupt"}`.
    Interrupt,
}
874
/// Claude's reply to a control request, matched to the request by `request_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ControlResponse {
    request_id: String,
    subtype: ResultErrorType,
}
880
/// Outcome subtype shared by `result` messages and control responses.
///
/// Serialized as `"success"`, `"error_max_turns"` or `"error_during_execution"`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum ResultErrorType {
    Success,
    ErrorMaxTurns,
    ErrorDuringExecution,
}
888
889impl Display for ResultErrorType {
890 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
891 match self {
892 ResultErrorType::Success => write!(f, "success"),
893 ResultErrorType::ErrorMaxTurns => write!(f, "error_max_turns"),
894 ResultErrorType::ErrorDuringExecution => write!(f, "error_during_execution"),
895 }
896 }
897}
898
899fn new_request_id() -> String {
900 use rand::Rng;
901 // In the Claude Code TS SDK they just generate a random 12 character string,
902 // `Math.random().toString(36).substring(2, 15)`
903 rand::thread_rng()
904 .sample_iter(&rand::distributions::Alphanumeric)
905 .take(12)
906 .map(char::from)
907 .collect()
908}
909
/// An MCP server entry (name and status) listed in the `system` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct McpServer {
    name: String,
    status: String,
}
915
/// Permission mode reported in the `system` message.
///
/// Serialized in camelCase: `"default"`, `"acceptEdits"`, `"bypassPermissions"`, `"plan"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum PermissionMode {
    Default,
    AcceptEdits,
    BypassPermissions,
    Plan,
}
924
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::e2e_tests;
    use gpui::TestAppContext;
    use serde_json::json;

    // Instantiates the shared end-to-end test suite for Claude Code. "allow" is
    // presumably the permission option id the suite picks when a tool asks for
    // permission — TODO confirm against the macro definition.
    crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");

    /// A command that runs a bare `claude` binary with no extra args or env.
    /// The bare path is presumably resolved through `PATH`.
    pub fn local_command() -> AgentServerCommand {
        AgentServerCommand {
            path: "claude".into(),
            args: vec![],
            env: None,
        }
    }

    /// Drives a real Claude Code session (only with the `e2e` feature).
    ///
    /// Creates a todo plan, then updates the status of its first entry twice. It
    /// checks the status each time and that the number of entries never changes.
    #[gpui::test]
    #[cfg_attr(not(feature = "e2e"), ignore)]
    async fn test_todo_plan(cx: &mut TestAppContext) {
        let fs = e2e_tests::init_test(cx).await;
        let project = Project::test(fs, [], cx).await;
        let thread =
            e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        // Remember the initial plan size so later updates can be checked against it.
        let mut entries_len = 0;

        thread.read_with(cx, |thread, _| {
            entries_len = thread.plan().entries.len();
            assert!(thread.plan().entries.len() > 0, "Empty plan");
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Mark the first entry status as in progress without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::InProgress
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });

        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Now mark the first entry as completed without acting on it.",
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.plan().entries[0].status,
                acp::PlanEntryStatus::Completed
            ));
            assert_eq!(thread.plan().entries.len(), entries_len);
        });
    }

    /// A bare JSON string deserializes as `Content::UntaggedText`.
    #[test]
    fn test_deserialize_content_untagged_text() {
        let json = json!("Hello, world!");
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::UntaggedText(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected UntaggedText variant"),
        }
    }

    /// An array of `type`-tagged objects deserializes as `Content::Chunks`, in order.
    #[test]
    fn test_deserialize_content_chunks() {
        let json = json!([
            {
                "type": "text",
                "text": "Hello"
            },
            {
                "type": "tool_use",
                "id": "tool_123",
                "name": "calculator",
                "input": {"operation": "add", "a": 1, "b": 2}
            }
        ]);
        let content: Content = serde_json::from_value(json).unwrap();
        match content {
            Content::Chunks(chunks) => {
                assert_eq!(chunks.len(), 2);
                match &chunks[0] {
                    ContentChunk::Text { text } => assert_eq!(text, "Hello"),
                    _ => panic!("Expected Text chunk"),
                }
                match &chunks[1] {
                    ContentChunk::ToolUse { id, name, input } => {
                        assert_eq!(id, "tool_123");
                        assert_eq!(name, "calculator");
                        assert_eq!(input["operation"], "add");
                        assert_eq!(input["a"], 1);
                        assert_eq!(input["b"], 2);
                    }
                    _ => panic!("Expected ToolUse chunk"),
                }
            }
            _ => panic!("Expected Chunks variant"),
        }
    }

    /// A `tool_result` whose `content` is a plain string keeps it as `UntaggedText`.
    #[test]
    fn test_deserialize_tool_result_untagged_text() {
        let json = json!({
            "type": "tool_result",
            "content": "Result content",
            "tool_use_id": "tool_456"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::UntaggedText(text) => assert_eq!(text, "Result content"),
                    _ => panic!("Expected UntaggedText content"),
                }
                assert_eq!(tool_use_id, "tool_456");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }

    /// A `tool_result` whose `content` is an array of chunks is parsed recursively.
    #[test]
    fn test_deserialize_tool_result_chunks() {
        let json = json!({
            "type": "tool_result",
            "content": [
                {
                    "type": "text",
                    "text": "Processing complete"
                },
                {
                    "type": "text",
                    "text": "Result: 42"
                }
            ],
            "tool_use_id": "tool_789"
        });
        let chunk: ContentChunk = serde_json::from_value(json).unwrap();
        match chunk {
            ContentChunk::ToolResult {
                content,
                tool_use_id,
            } => {
                match content {
                    Content::Chunks(chunks) => {
                        assert_eq!(chunks.len(), 2);
                        match &chunks[0] {
                            ContentChunk::Text { text } => assert_eq!(text, "Processing complete"),
                            _ => panic!("Expected Text chunk"),
                        }
                        match &chunks[1] {
                            ContentChunk::Text { text } => assert_eq!(text, "Result: 42"),
                            _ => panic!("Expected Text chunk"),
                        }
                    }
                    _ => panic!("Expected Chunks content"),
                }
                assert_eq!(tool_use_id, "tool_789");
            }
            _ => panic!("Expected ToolResult variant"),
        }
    }
}