fetch_tool.rs

  1use std::cell::RefCell;
  2use std::rc::Rc;
  3use std::sync::Arc;
  4
  5use anyhow::{anyhow, bail, Context as _, Result};
  6use assistant_tool::{ActionLog, Tool};
  7use futures::AsyncReadExt as _;
  8use gpui::{App, AppContext as _, Entity, Task};
  9use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
 10use http_client::{AsyncBody, HttpClientWithUrl};
 11use language_model::LanguageModelRequestMessage;
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15
 16#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 17enum ContentType {
 18    Html,
 19    Plaintext,
 20    Json,
 21}
 22
// Arguments of the `fetch` tool; `run` deserializes them from the tool's JSON input and
// `input_schema` derives the advertised schema from this type.
// NOTE(review): schemars' derive copies `///` doc comments into the schema's `description`
// fields, so the field doc below is model-facing text — edit it deliberately, and use `//`
// for internal notes like this one.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
    /// The URL to fetch.
    url: String,
}
 28
/// Assistant tool that downloads a URL and returns its body as text (Markdown for HTML,
/// the raw text for `text/plain`, a fenced pretty-printed block for JSON).
pub struct FetchTool {
    // Shared HTTP client used for every request this tool issues.
    http_client: Arc<HttpClientWithUrl>,
}
 32
 33impl FetchTool {
 34    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 35        Self { http_client }
 36    }
 37
 38    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 39        let mut url = url.to_owned();
 40        if !url.starts_with("https://") && !url.starts_with("http://") {
 41            url = format!("https://{url}");
 42        }
 43
 44        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 45
 46        let mut body = Vec::new();
 47        response
 48            .body_mut()
 49            .read_to_end(&mut body)
 50            .await
 51            .context("error reading response body")?;
 52
 53        if response.status().is_client_error() {
 54            let text = String::from_utf8_lossy(body.as_slice());
 55            bail!(
 56                "status error {}, response: {text:?}",
 57                response.status().as_u16()
 58            );
 59        }
 60
 61        let Some(content_type) = response.headers().get("content-type") else {
 62            bail!("missing Content-Type header");
 63        };
 64        let content_type = content_type
 65            .to_str()
 66            .context("invalid Content-Type header")?;
 67        let content_type = match content_type {
 68            "text/html" => ContentType::Html,
 69            "text/plain" => ContentType::Plaintext,
 70            "application/json" => ContentType::Json,
 71            _ => ContentType::Html,
 72        };
 73
 74        match content_type {
 75            ContentType::Html => {
 76                let mut handlers: Vec<TagHandler> = vec![
 77                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 78                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 79                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 80                    Rc::new(RefCell::new(markdown::ListHandler)),
 81                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 82                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 83                ];
 84                if url.contains("wikipedia.org") {
 85                    use html_to_markdown::structure::wikipedia;
 86
 87                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 88                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 89                    handlers.push(Rc::new(
 90                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
 91                    ));
 92                } else {
 93                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
 94                }
 95
 96                convert_html_to_markdown(&body[..], &mut handlers)
 97            }
 98            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
 99            ContentType::Json => {
100                let json: serde_json::Value = serde_json::from_slice(&body)?;
101
102                Ok(format!(
103                    "```json\n{}\n```",
104                    serde_json::to_string_pretty(&json)?
105                ))
106            }
107        }
108    }
109}
110
111impl Tool for FetchTool {
112    fn name(&self) -> String {
113        "fetch".to_string()
114    }
115
116    fn description(&self) -> String {
117        include_str!("./fetch_tool/description.md").to_string()
118    }
119
120    fn input_schema(&self) -> serde_json::Value {
121        let schema = schemars::schema_for!(FetchToolInput);
122        serde_json::to_value(&schema).unwrap()
123    }
124
125    fn ui_text(&self, input: &serde_json::Value) -> String {
126        match serde_json::from_value::<FetchToolInput>(input.clone()) {
127            Ok(input) => format!("Fetch `{}`", input.url),
128            Err(_) => "Fetch URL".to_string(),
129        }
130    }
131
132    fn run(
133        self: Arc<Self>,
134        input: serde_json::Value,
135        _messages: &[LanguageModelRequestMessage],
136        _project: Entity<Project>,
137        _action_log: Entity<ActionLog>,
138        cx: &mut App,
139    ) -> Task<Result<String>> {
140        let input = match serde_json::from_value::<FetchToolInput>(input) {
141            Ok(input) => input,
142            Err(err) => return Task::ready(Err(anyhow!(err))),
143        };
144
145        let text = cx.background_spawn({
146            let http_client = self.http_client.clone();
147            let url = input.url.clone();
148            async move { Self::build_message(http_client, &url).await }
149        });
150
151        cx.foreground_executor().spawn(async move {
152            let text = text.await?;
153            if text.trim().is_empty() {
154                bail!("no textual content found");
155            }
156
157            Ok(text)
158        })
159    }
160}