fetch_tool.rs

  1use std::rc::Rc;
  2use std::sync::Arc;
  3use std::{borrow::Cow, cell::RefCell};
  4
  5use agent_client_protocol as acp;
  6use agent_settings::AgentSettings;
  7use anyhow::{Context as _, Result, bail};
  8use futures::{AsyncReadExt as _, FutureExt as _};
  9use gpui::{App, AppContext as _, Task};
 10use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
 11use http_client::{AsyncBody, HttpClientWithUrl};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::Settings;
 15use ui::SharedString;
 16use util::markdown::{MarkdownEscaped, MarkdownInlineCode};
 17
 18use crate::{
 19    AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision,
 20    decide_permission_from_settings,
 21};
 22
 /// How a fetched response body is turned into the text handed back by the tool.
 ///
 /// The variant is chosen from the response's `Content-Type` header.
 #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
 enum ContentType {
     /// Anything that is neither plain text nor JSON; the body is converted from HTML to
     /// Markdown.
     Html,
     /// A `text/plain` body; returned as-is.
     Plaintext,
     /// An `application/json` body; pretty-printed inside a fenced `json` code block.
     Json,
 }
 29
 30/// Fetches a URL and returns the content as Markdown.
 31#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 32pub struct FetchToolInput {
 33    /// The URL to fetch.
 34    url: String,
 35}
 36
 37pub struct FetchTool {
 38    http_client: Arc<HttpClientWithUrl>,
 39}
 40
 41impl FetchTool {
 42    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 43        Self { http_client }
 44    }
 45
 46    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 47        let url = if !url.starts_with("https://") && !url.starts_with("http://") {
 48            Cow::Owned(format!("https://{url}"))
 49        } else {
 50            Cow::Borrowed(url)
 51        };
 52
 53        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 54
 55        let mut body = Vec::new();
 56        response
 57            .body_mut()
 58            .read_to_end(&mut body)
 59            .await
 60            .context("error reading response body")?;
 61
 62        if response.status().is_client_error() {
 63            let text = String::from_utf8_lossy(body.as_slice());
 64            bail!(
 65                "status error {}, response: {text:?}",
 66                response.status().as_u16()
 67            );
 68        }
 69
 70        let Some(content_type) = response.headers().get("content-type") else {
 71            bail!("missing Content-Type header");
 72        };
 73        let content_type = content_type
 74            .to_str()
 75            .context("invalid Content-Type header")?;
 76
 77        let content_type = if content_type.starts_with("text/plain") {
 78            ContentType::Plaintext
 79        } else if content_type.starts_with("application/json") {
 80            ContentType::Json
 81        } else {
 82            ContentType::Html
 83        };
 84
 85        match content_type {
 86            ContentType::Html => {
 87                let mut handlers: Vec<TagHandler> = vec![
 88                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 89                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 90                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 91                    Rc::new(RefCell::new(markdown::ListHandler)),
 92                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 93                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 94                ];
 95                if url.contains("wikipedia.org") {
 96                    use html_to_markdown::structure::wikipedia;
 97
 98                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 99                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
100                    handlers.push(Rc::new(
101                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
102                    ));
103                } else {
104                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
105                }
106
107                convert_html_to_markdown(&body[..], &mut handlers)
108            }
109            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
110            ContentType::Json => {
111                let json: serde_json::Value = serde_json::from_slice(&body)?;
112
113                Ok(format!(
114                    "```json\n{}\n```",
115                    serde_json::to_string_pretty(&json)?
116                ))
117            }
118        }
119    }
120}
121
122impl AgentTool for FetchTool {
123    type Input = FetchToolInput;
124    type Output = String;
125
126    const NAME: &'static str = "fetch";
127
128    fn kind() -> acp::ToolKind {
129        acp::ToolKind::Fetch
130    }
131
132    fn initial_title(
133        &self,
134        input: Result<Self::Input, serde_json::Value>,
135        _cx: &mut App,
136    ) -> SharedString {
137        match input {
138            Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
139            Err(_) => "Fetch URL".into(),
140        }
141    }
142
143    fn run(
144        self: Arc<Self>,
145        input: ToolInput<Self::Input>,
146        event_stream: ToolCallEventStream,
147        cx: &mut App,
148    ) -> Task<Result<Self::Output, Self::Output>> {
149        let http_client = self.http_client.clone();
150        cx.spawn(async move |cx| {
151            let input: FetchToolInput = input
152                .recv()
153                .await
154                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
155
156            let decision = cx.update(|cx| {
157                decide_permission_from_settings(
158                    Self::NAME,
159                    std::slice::from_ref(&input.url),
160                    AgentSettings::get_global(cx),
161                )
162            });
163
164            let authorize = match decision {
165                ToolPermissionDecision::Allow => None,
166                ToolPermissionDecision::Deny(reason) => {
167                    return Err(reason);
168                }
169                ToolPermissionDecision::Confirm => Some(cx.update(|cx| {
170                    let context =
171                        crate::ToolPermissionContext::new(Self::NAME, vec![input.url.clone()]);
172                    event_stream.authorize(
173                        format!("Fetch {}", MarkdownInlineCode(&input.url)),
174                        context,
175                        cx,
176                    )
177                })),
178            };
179
180            let fetch_task = cx.background_spawn({
181                let http_client = http_client.clone();
182                let url = input.url.clone();
183                async move {
184                    if let Some(authorize) = authorize {
185                        authorize.await?;
186                    }
187                    Self::build_message(http_client, &url).await
188                }
189            });
190
191            let text = futures::select! {
192                result = fetch_task.fuse() => result.map_err(|e| e.to_string())?,
193                _ = event_stream.cancelled_by_user().fuse() => {
194                    return Err("Fetch cancelled by user".to_string());
195                }
196            };
197            if text.trim().is_empty() {
198                return Err("no textual content found".to_string());
199            }
200            Ok(text)
201        })
202    }
203}