fetch_tool.rs

  1use std::rc::Rc;
  2use std::sync::Arc;
  3use std::{borrow::Cow, cell::RefCell};
  4
  5use agent_client_protocol as acp;
  6use agent_settings::AgentSettings;
  7use anyhow::{Context as _, Result, bail};
  8use futures::{AsyncReadExt as _, FutureExt as _};
  9use gpui::{App, AppContext as _, Task};
 10use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
 11use http_client::{AsyncBody, HttpClientWithUrl};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::Settings;
 15use ui::SharedString;
 16use util::markdown::{MarkdownEscaped, MarkdownInlineCode};
 17
 18use crate::{
 19    AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
 20};
 21
 22#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 23enum ContentType {
 24    Html,
 25    Plaintext,
 26    Json,
 27}
 28
 29/// Fetches a URL and returns the content as Markdown.
 30#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 31pub struct FetchToolInput {
 32    /// The URL to fetch.
 33    url: String,
 34}
 35
 36pub struct FetchTool {
 37    http_client: Arc<HttpClientWithUrl>,
 38}
 39
 40impl FetchTool {
 41    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 42        Self { http_client }
 43    }
 44
 45    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 46        let url = if !url.starts_with("https://") && !url.starts_with("http://") {
 47            Cow::Owned(format!("https://{url}"))
 48        } else {
 49            Cow::Borrowed(url)
 50        };
 51
 52        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 53
 54        let mut body = Vec::new();
 55        response
 56            .body_mut()
 57            .read_to_end(&mut body)
 58            .await
 59            .context("error reading response body")?;
 60
 61        if response.status().is_client_error() {
 62            let text = String::from_utf8_lossy(body.as_slice());
 63            bail!(
 64                "status error {}, response: {text:?}",
 65                response.status().as_u16()
 66            );
 67        }
 68
 69        let Some(content_type) = response.headers().get("content-type") else {
 70            bail!("missing Content-Type header");
 71        };
 72        let content_type = content_type
 73            .to_str()
 74            .context("invalid Content-Type header")?;
 75
 76        let content_type = if content_type.starts_with("text/plain") {
 77            ContentType::Plaintext
 78        } else if content_type.starts_with("application/json") {
 79            ContentType::Json
 80        } else {
 81            ContentType::Html
 82        };
 83
 84        match content_type {
 85            ContentType::Html => {
 86                let mut handlers: Vec<TagHandler> = vec![
 87                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 88                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 89                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 90                    Rc::new(RefCell::new(markdown::ListHandler)),
 91                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 92                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 93                ];
 94                if url.contains("wikipedia.org") {
 95                    use html_to_markdown::structure::wikipedia;
 96
 97                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 98                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 99                    handlers.push(Rc::new(
100                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
101                    ));
102                } else {
103                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
104                }
105
106                convert_html_to_markdown(&body[..], &mut handlers)
107            }
108            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
109            ContentType::Json => {
110                let json: serde_json::Value = serde_json::from_slice(&body)?;
111
112                Ok(format!(
113                    "```json\n{}\n```",
114                    serde_json::to_string_pretty(&json)?
115                ))
116            }
117        }
118    }
119}
120
121impl AgentTool for FetchTool {
122    type Input = FetchToolInput;
123    type Output = String;
124
125    fn name() -> &'static str {
126        "fetch"
127    }
128
129    fn kind() -> acp::ToolKind {
130        acp::ToolKind::Fetch
131    }
132
133    fn initial_title(
134        &self,
135        input: Result<Self::Input, serde_json::Value>,
136        _cx: &mut App,
137    ) -> SharedString {
138        match input {
139            Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
140            Err(_) => "Fetch URL".into(),
141        }
142    }
143
144    fn run(
145        self: Arc<Self>,
146        input: Self::Input,
147        event_stream: ToolCallEventStream,
148        cx: &mut App,
149    ) -> Task<Result<Self::Output>> {
150        let settings = AgentSettings::get_global(cx);
151        let decision = decide_permission_from_settings(Self::name(), &input.url, settings);
152
153        let authorize = match decision {
154            ToolPermissionDecision::Allow => None,
155            ToolPermissionDecision::Deny(reason) => {
156                return Task::ready(Err(anyhow::anyhow!("{}", reason)));
157            }
158            ToolPermissionDecision::Confirm => Some(
159                event_stream.authorize(format!("Fetch {}", MarkdownInlineCode(&input.url)), cx),
160            ),
161        };
162
163        let fetch_task = cx.background_spawn({
164            let http_client = self.http_client.clone();
165            async move {
166                if let Some(authorize) = authorize {
167                    authorize.await?;
168                }
169                Self::build_message(http_client, &input.url).await
170            }
171        });
172
173        cx.foreground_executor().spawn(async move {
174            let text = futures::select! {
175                result = fetch_task.fuse() => result?,
176                _ = event_stream.cancelled_by_user().fuse() => {
177                    anyhow::bail!("Fetch cancelled by user");
178                }
179            };
180            if text.trim().is_empty() {
181                bail!("no textual content found");
182            }
183            Ok(text)
184        })
185    }
186}