connection.rs

  1use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
  2
  3use agent_client_protocol as acp;
  4use agentic_coding_protocol::{self as acp_old, AgentRequest};
  5use anyhow::Result;
  6use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
  7use project::Project;
  8use ui::App;
  9
 10use crate::AcpThread;
 11
 12pub trait AgentConnection {
 13    fn new_thread(
 14        &self,
 15        project: Entity<Project>,
 16        cwd: &Path,
 17        connection: Rc<dyn AgentConnection>,
 18        cx: &mut AsyncApp,
 19    ) -> Task<Result<Entity<AcpThread>>>;
 20
 21    fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
 22
 23    fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>>;
 24
 25    fn cancel(&self, cx: &mut App);
 26}
 27
 28#[derive(Debug)]
 29pub struct Unauthenticated;
 30
 31impl Error for Unauthenticated {}
 32impl fmt::Display for Unauthenticated {
 33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 34        write!(f, "Unauthenticated")
 35    }
 36}
 37
 38pub struct OldAcpAgentConnection {
 39    pub connection: acp_old::AgentConnection,
 40    pub child_status: Task<Result<()>>,
 41    pub thread: Rc<RefCell<WeakEntity<AcpThread>>>,
 42}
 43
 44impl AgentConnection for OldAcpAgentConnection {
 45    fn new_thread(
 46        &self,
 47        project: Entity<Project>,
 48        _cwd: &Path,
 49        connection: Rc<dyn AgentConnection>,
 50        cx: &mut AsyncApp,
 51    ) -> Task<Result<Entity<AcpThread>>> {
 52        let task = self.connection.request_any(
 53            acp_old::InitializeParams {
 54                protocol_version: acp_old::ProtocolVersion::latest(),
 55            }
 56            .into_any(),
 57        );
 58        let current_thread = self.thread.clone();
 59        cx.spawn(async move |cx| {
 60            let result = task.await?;
 61            let result = acp_old::InitializeParams::response_from_any(result)?;
 62
 63            if !result.is_authenticated {
 64                anyhow::bail!(Unauthenticated)
 65            }
 66
 67            cx.update(|cx| {
 68                let thread = cx.new(|cx| {
 69                    let session_id = acp::SessionId("acp-old-no-id".into());
 70                    AcpThread::new(connection, "Gemini".into(), None, project, session_id, cx)
 71                });
 72                current_thread.replace(thread.downgrade());
 73                thread
 74            })
 75        })
 76    }
 77
 78    fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
 79        let task = self
 80            .connection
 81            .request_any(acp_old::AuthenticateParams.into_any());
 82        cx.foreground_executor().spawn(async move {
 83            task.await?;
 84            Ok(())
 85        })
 86    }
 87
 88    fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task<Result<()>> {
 89        let chunks = params
 90            .prompt
 91            .into_iter()
 92            .filter_map(|block| match block {
 93                acp::ContentBlock::Text(text) => {
 94                    Some(acp_old::UserMessageChunk::Text { text: text.text })
 95                }
 96                acp::ContentBlock::ResourceLink(link) => Some(acp_old::UserMessageChunk::Path {
 97                    path: link.uri.into(),
 98                }),
 99                _ => None,
100            })
101            .collect();
102
103        let task = self
104            .connection
105            .request_any(acp_old::SendUserMessageParams { chunks }.into_any());
106        cx.foreground_executor().spawn(async move {
107            task.await?;
108            anyhow::Ok(())
109        })
110    }
111
112    fn cancel(&self, cx: &mut App) {
113        let task = self
114            .connection
115            .request_any(acp_old::CancelSendMessageParams.into_any());
116        cx.foreground_executor()
117            .spawn(async move {
118                task.await?;
119                anyhow::Ok(())
120            })
121            .detach_and_log_err(cx)
122    }
123}