1use std::rc::Rc;
  2use std::sync::Arc;
  3use std::{borrow::Cow, cell::RefCell};
  4
  5use agent_client_protocol as acp;
  6use anyhow::{Context as _, Result, bail};
  7use futures::AsyncReadExt as _;
  8use gpui::{App, AppContext as _, Task};
  9use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
 10use http_client::{AsyncBody, HttpClientWithUrl};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use ui::SharedString;
 14use util::markdown::MarkdownEscaped;
 15
 16use crate::{AgentTool, ToolCallEventStream};
 17
 18#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 19enum ContentType {
 20    Html,
 21    Plaintext,
 22    Json,
 23}
 24
// NOTE(review): this type derives `JsonSchema`; schemars normally copies `///` doc
// comments into the generated schema as descriptions, so the doc text below is
// presumably surfaced to whoever consumes the tool's input schema — confirm before
// rewording it.
/// Fetches a URL and returns the content as Markdown.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
    /// The URL to fetch.
    url: String,
}
 31
/// Agent tool, registered under the name `"fetch"`, that retrieves a URL over HTTP
/// and returns the response content as text.
pub struct FetchTool {
    /// Client used to issue the request for the fetched URL.
    http_client: Arc<HttpClientWithUrl>,
}
 35
 36impl FetchTool {
 37    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 38        Self { http_client }
 39    }
 40
 41    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 42        let url = if !url.starts_with("https://") && !url.starts_with("http://") {
 43            Cow::Owned(format!("https://{url}"))
 44        } else {
 45            Cow::Borrowed(url)
 46        };
 47
 48        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 49
 50        let mut body = Vec::new();
 51        response
 52            .body_mut()
 53            .read_to_end(&mut body)
 54            .await
 55            .context("error reading response body")?;
 56
 57        if response.status().is_client_error() {
 58            let text = String::from_utf8_lossy(body.as_slice());
 59            bail!(
 60                "status error {}, response: {text:?}",
 61                response.status().as_u16()
 62            );
 63        }
 64
 65        let Some(content_type) = response.headers().get("content-type") else {
 66            bail!("missing Content-Type header");
 67        };
 68        let content_type = content_type
 69            .to_str()
 70            .context("invalid Content-Type header")?;
 71
 72        let content_type = if content_type.starts_with("text/plain") {
 73            ContentType::Plaintext
 74        } else if content_type.starts_with("application/json") {
 75            ContentType::Json
 76        } else {
 77            ContentType::Html
 78        };
 79
 80        match content_type {
 81            ContentType::Html => {
 82                let mut handlers: Vec<TagHandler> = vec![
 83                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 84                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 85                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 86                    Rc::new(RefCell::new(markdown::ListHandler)),
 87                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 88                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 89                ];
 90                if url.contains("wikipedia.org") {
 91                    use html_to_markdown::structure::wikipedia;
 92
 93                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 94                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 95                    handlers.push(Rc::new(
 96                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
 97                    ));
 98                } else {
 99                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
100                }
101
102                convert_html_to_markdown(&body[..], &mut handlers)
103            }
104            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
105            ContentType::Json => {
106                let json: serde_json::Value = serde_json::from_slice(&body)?;
107
108                Ok(format!(
109                    "```json\n{}\n```",
110                    serde_json::to_string_pretty(&json)?
111                ))
112            }
113        }
114    }
115}
116
117impl AgentTool for FetchTool {
118    type Input = FetchToolInput;
119    type Output = String;
120
121    fn name() -> &'static str {
122        "fetch"
123    }
124
125    fn kind() -> acp::ToolKind {
126        acp::ToolKind::Fetch
127    }
128
129    fn initial_title(
130        &self,
131        input: Result<Self::Input, serde_json::Value>,
132        _cx: &mut App,
133    ) -> SharedString {
134        match input {
135            Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
136            Err(_) => "Fetch URL".into(),
137        }
138    }
139
140    fn run(
141        self: Arc<Self>,
142        input: Self::Input,
143        event_stream: ToolCallEventStream,
144        cx: &mut App,
145    ) -> Task<Result<Self::Output>> {
146        let authorize = event_stream.authorize(input.url.clone(), cx);
147
148        let text = cx.background_spawn({
149            let http_client = self.http_client.clone();
150            async move {
151                authorize.await?;
152                Self::build_message(http_client, &input.url).await
153            }
154        });
155
156        cx.foreground_executor().spawn(async move {
157            let text = text.await?;
158            if text.trim().is_empty() {
159                bail!("no textual content found");
160            }
161            Ok(text)
162        })
163    }
164}