fetch_tool.rs

  1use std::rc::Rc;
  2use std::sync::Arc;
  3use std::{borrow::Cow, cell::RefCell};
  4
  5use agent_client_protocol as acp;
  6use agent_settings::AgentSettings;
  7use anyhow::{Context as _, Result, bail};
  8use futures::{AsyncReadExt as _, FutureExt as _};
  9use gpui::{App, AppContext as _, Task};
 10use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
 11use http_client::{AsyncBody, HttpClientWithUrl};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::Settings;
 15use ui::SharedString;
 16use util::markdown::{MarkdownEscaped, MarkdownInlineCode};
 17
 18use crate::{
 19    AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
 20};
 21
 22#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 23enum ContentType {
 24    Html,
 25    Plaintext,
 26    Json,
 27}
 28
 29/// Fetches a URL and returns the content as Markdown.
 30#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 31pub struct FetchToolInput {
 32    /// The URL to fetch.
 33    url: String,
 34}
 35
 36pub struct FetchTool {
 37    http_client: Arc<HttpClientWithUrl>,
 38}
 39
 40impl FetchTool {
 41    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 42        Self { http_client }
 43    }
 44
 45    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 46        let url = if !url.starts_with("https://") && !url.starts_with("http://") {
 47            Cow::Owned(format!("https://{url}"))
 48        } else {
 49            Cow::Borrowed(url)
 50        };
 51
 52        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 53
 54        let mut body = Vec::new();
 55        response
 56            .body_mut()
 57            .read_to_end(&mut body)
 58            .await
 59            .context("error reading response body")?;
 60
 61        if response.status().is_client_error() {
 62            let text = String::from_utf8_lossy(body.as_slice());
 63            bail!(
 64                "status error {}, response: {text:?}",
 65                response.status().as_u16()
 66            );
 67        }
 68
 69        let Some(content_type) = response.headers().get("content-type") else {
 70            bail!("missing Content-Type header");
 71        };
 72        let content_type = content_type
 73            .to_str()
 74            .context("invalid Content-Type header")?;
 75
 76        let content_type = if content_type.starts_with("text/plain") {
 77            ContentType::Plaintext
 78        } else if content_type.starts_with("application/json") {
 79            ContentType::Json
 80        } else {
 81            ContentType::Html
 82        };
 83
 84        match content_type {
 85            ContentType::Html => {
 86                let mut handlers: Vec<TagHandler> = vec![
 87                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 88                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 89                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 90                    Rc::new(RefCell::new(markdown::ListHandler)),
 91                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 92                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 93                ];
 94                if url.contains("wikipedia.org") {
 95                    use html_to_markdown::structure::wikipedia;
 96
 97                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 98                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 99                    handlers.push(Rc::new(
100                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
101                    ));
102                } else {
103                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
104                }
105
106                convert_html_to_markdown(&body[..], &mut handlers)
107            }
108            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
109            ContentType::Json => {
110                let json: serde_json::Value = serde_json::from_slice(&body)?;
111
112                Ok(format!(
113                    "```json\n{}\n```",
114                    serde_json::to_string_pretty(&json)?
115                ))
116            }
117        }
118    }
119}
120
121impl AgentTool for FetchTool {
122    type Input = FetchToolInput;
123    type Output = String;
124
125    const NAME: &'static str = "fetch";
126
127    fn kind() -> acp::ToolKind {
128        acp::ToolKind::Fetch
129    }
130
131    fn initial_title(
132        &self,
133        input: Result<Self::Input, serde_json::Value>,
134        _cx: &mut App,
135    ) -> SharedString {
136        match input {
137            Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
138            Err(_) => "Fetch URL".into(),
139        }
140    }
141
142    fn run(
143        self: Arc<Self>,
144        input: Self::Input,
145        event_stream: ToolCallEventStream,
146        cx: &mut App,
147    ) -> Task<Result<Self::Output>> {
148        let settings = AgentSettings::get_global(cx);
149        let decision =
150            decide_permission_from_settings(Self::NAME, std::slice::from_ref(&input.url), settings);
151
152        let authorize = match decision {
153            ToolPermissionDecision::Allow => None,
154            ToolPermissionDecision::Deny(reason) => {
155                return Task::ready(Err(anyhow::anyhow!("{}", reason)));
156            }
157            ToolPermissionDecision::Confirm => {
158                let context =
159                    crate::ToolPermissionContext::new(Self::NAME, vec![input.url.clone()]);
160                Some(event_stream.authorize(
161                    format!("Fetch {}", MarkdownInlineCode(&input.url)),
162                    context,
163                    cx,
164                ))
165            }
166        };
167
168        let fetch_task = cx.background_spawn({
169            let http_client = self.http_client.clone();
170            async move {
171                if let Some(authorize) = authorize {
172                    authorize.await?;
173                }
174                Self::build_message(http_client, &input.url).await
175            }
176        });
177
178        cx.foreground_executor().spawn(async move {
179            let text = futures::select! {
180                result = fetch_task.fuse() => result?,
181                _ = event_stream.cancelled_by_user().fuse() => {
182                    anyhow::bail!("Fetch cancelled by user");
183                }
184            };
185            if text.trim().is_empty() {
186                bail!("no textual content found");
187            }
188            Ok(text)
189        })
190    }
191}