fetch_tool.rs

  1use std::cell::RefCell;
  2use std::rc::Rc;
  3use std::sync::Arc;
  4
  5use anyhow::{anyhow, bail, Context as _, Result};
  6use assistant_tool::{ActionLog, Tool};
  7use futures::AsyncReadExt as _;
  8use gpui::{App, AppContext as _, Entity, Task};
  9use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
 10use http_client::{AsyncBody, HttpClientWithUrl};
 11use language_model::LanguageModelRequestMessage;
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15
 16#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 17enum ContentType {
 18    Html,
 19    Plaintext,
 20    Json,
 21}
 22
// Arguments accepted by the fetch tool. Values of this type are deserialized
// from the tool-call JSON, and the advertised input schema is generated from
// it with `schemars::schema_for!`.
//
// NOTE(review): schemars copies `///` doc comments into the generated schema
// as descriptions, so the field doc below is presumably part of what callers
// of `input_schema` see — change its wording deliberately. Plain `//`
// comments (like this one) are used here so the schema stays untouched.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
    /// The URL to fetch.
    url: String,
}
 28
/// Tool that retrieves a URL over HTTP and returns the response body as text.
pub struct FetchTool {
    /// Client used to issue the request.
    http_client: Arc<HttpClientWithUrl>,
}
 32
 33impl FetchTool {
 34    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 35        Self { http_client }
 36    }
 37
 38    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 39        let mut url = url.to_owned();
 40        if !url.starts_with("https://") && !url.starts_with("http://") {
 41            url = format!("https://{url}");
 42        }
 43
 44        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 45
 46        let mut body = Vec::new();
 47        response
 48            .body_mut()
 49            .read_to_end(&mut body)
 50            .await
 51            .context("error reading response body")?;
 52
 53        if response.status().is_client_error() {
 54            let text = String::from_utf8_lossy(body.as_slice());
 55            bail!(
 56                "status error {}, response: {text:?}",
 57                response.status().as_u16()
 58            );
 59        }
 60
 61        let Some(content_type) = response.headers().get("content-type") else {
 62            bail!("missing Content-Type header");
 63        };
 64        let content_type = content_type
 65            .to_str()
 66            .context("invalid Content-Type header")?;
 67        let content_type = match content_type {
 68            "text/html" => ContentType::Html,
 69            "text/plain" => ContentType::Plaintext,
 70            "application/json" => ContentType::Json,
 71            _ => ContentType::Html,
 72        };
 73
 74        match content_type {
 75            ContentType::Html => {
 76                let mut handlers: Vec<TagHandler> = vec![
 77                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 78                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 79                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 80                    Rc::new(RefCell::new(markdown::ListHandler)),
 81                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 82                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 83                ];
 84                if url.contains("wikipedia.org") {
 85                    use html_to_markdown::structure::wikipedia;
 86
 87                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 88                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 89                    handlers.push(Rc::new(
 90                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
 91                    ));
 92                } else {
 93                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
 94                }
 95
 96                convert_html_to_markdown(&body[..], &mut handlers)
 97            }
 98            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
 99            ContentType::Json => {
100                let json: serde_json::Value = serde_json::from_slice(&body)?;
101
102                Ok(format!(
103                    "```json\n{}\n```",
104                    serde_json::to_string_pretty(&json)?
105                ))
106            }
107        }
108    }
109}
110
111impl Tool for FetchTool {
112    fn name(&self) -> String {
113        "fetch".to_string()
114    }
115
116    fn needs_confirmation(&self) -> bool {
117        true
118    }
119
120    fn description(&self) -> String {
121        include_str!("./fetch_tool/description.md").to_string()
122    }
123
124    fn input_schema(&self) -> serde_json::Value {
125        let schema = schemars::schema_for!(FetchToolInput);
126        serde_json::to_value(&schema).unwrap()
127    }
128
129    fn ui_text(&self, input: &serde_json::Value) -> String {
130        match serde_json::from_value::<FetchToolInput>(input.clone()) {
131            Ok(input) => format!("Fetch `{}`", input.url),
132            Err(_) => "Fetch URL".to_string(),
133        }
134    }
135
136    fn run(
137        self: Arc<Self>,
138        input: serde_json::Value,
139        _messages: &[LanguageModelRequestMessage],
140        _project: Entity<Project>,
141        _action_log: Entity<ActionLog>,
142        cx: &mut App,
143    ) -> Task<Result<String>> {
144        let input = match serde_json::from_value::<FetchToolInput>(input) {
145            Ok(input) => input,
146            Err(err) => return Task::ready(Err(anyhow!(err))),
147        };
148
149        let text = cx.background_spawn({
150            let http_client = self.http_client.clone();
151            let url = input.url.clone();
152            async move { Self::build_message(http_client, &url).await }
153        });
154
155        cx.foreground_executor().spawn(async move {
156            let text = text.await?;
157            if text.trim().is_empty() {
158                bail!("no textual content found");
159            }
160
161            Ok(text)
162        })
163    }
164}