fetch_tool.rs

  1use std::cell::RefCell;
  2use std::rc::Rc;
  3use std::sync::Arc;
  4
  5use anyhow::{anyhow, bail, Context as _, Result};
  6use assistant_tool::{ActionLog, Tool};
  7use futures::AsyncReadExt as _;
  8use gpui::{App, AppContext as _, Entity, Task};
  9use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
 10use http_client::{AsyncBody, HttpClientWithUrl};
 11use language_model::LanguageModelRequestMessage;
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use ui::IconName;
 16use util::markdown::MarkdownString;
 17
 18#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 19enum ContentType {
 20    Html,
 21    Plaintext,
 22    Json,
 23}
 24
 25#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 26pub struct FetchToolInput {
 27    /// The URL to fetch.
 28    url: String,
 29}
 30
 31pub struct FetchTool {
 32    http_client: Arc<HttpClientWithUrl>,
 33}
 34
 35impl FetchTool {
 36    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 37        Self { http_client }
 38    }
 39
 40    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 41        let mut url = url.to_owned();
 42        if !url.starts_with("https://") && !url.starts_with("http://") {
 43            url = format!("https://{url}");
 44        }
 45
 46        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 47
 48        let mut body = Vec::new();
 49        response
 50            .body_mut()
 51            .read_to_end(&mut body)
 52            .await
 53            .context("error reading response body")?;
 54
 55        if response.status().is_client_error() {
 56            let text = String::from_utf8_lossy(body.as_slice());
 57            bail!(
 58                "status error {}, response: {text:?}",
 59                response.status().as_u16()
 60            );
 61        }
 62
 63        let Some(content_type) = response.headers().get("content-type") else {
 64            bail!("missing Content-Type header");
 65        };
 66        let content_type = content_type
 67            .to_str()
 68            .context("invalid Content-Type header")?;
 69        let content_type = match content_type {
 70            "text/html" => ContentType::Html,
 71            "text/plain" => ContentType::Plaintext,
 72            "application/json" => ContentType::Json,
 73            _ => ContentType::Html,
 74        };
 75
 76        match content_type {
 77            ContentType::Html => {
 78                let mut handlers: Vec<TagHandler> = vec![
 79                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 80                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 81                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 82                    Rc::new(RefCell::new(markdown::ListHandler)),
 83                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 84                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 85                ];
 86                if url.contains("wikipedia.org") {
 87                    use html_to_markdown::structure::wikipedia;
 88
 89                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 90                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 91                    handlers.push(Rc::new(
 92                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
 93                    ));
 94                } else {
 95                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
 96                }
 97
 98                convert_html_to_markdown(&body[..], &mut handlers)
 99            }
100            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
101            ContentType::Json => {
102                let json: serde_json::Value = serde_json::from_slice(&body)?;
103
104                Ok(format!(
105                    "```json\n{}\n```",
106                    serde_json::to_string_pretty(&json)?
107                ))
108            }
109        }
110    }
111}
112
113impl Tool for FetchTool {
114    fn name(&self) -> String {
115        "fetch".to_string()
116    }
117
118    fn needs_confirmation(&self) -> bool {
119        true
120    }
121
122    fn description(&self) -> String {
123        include_str!("./fetch_tool/description.md").to_string()
124    }
125
126    fn icon(&self) -> IconName {
127        IconName::Globe
128    }
129
130    fn input_schema(&self) -> serde_json::Value {
131        let schema = schemars::schema_for!(FetchToolInput);
132        serde_json::to_value(&schema).unwrap()
133    }
134
135    fn ui_text(&self, input: &serde_json::Value) -> String {
136        match serde_json::from_value::<FetchToolInput>(input.clone()) {
137            Ok(input) => format!("Fetch {}", MarkdownString::escape(&input.url)),
138            Err(_) => "Fetch URL".to_string(),
139        }
140    }
141
142    fn run(
143        self: Arc<Self>,
144        input: serde_json::Value,
145        _messages: &[LanguageModelRequestMessage],
146        _project: Entity<Project>,
147        _action_log: Entity<ActionLog>,
148        cx: &mut App,
149    ) -> Task<Result<String>> {
150        let input = match serde_json::from_value::<FetchToolInput>(input) {
151            Ok(input) => input,
152            Err(err) => return Task::ready(Err(anyhow!(err))),
153        };
154
155        let text = cx.background_spawn({
156            let http_client = self.http_client.clone();
157            let url = input.url.clone();
158            async move { Self::build_message(http_client, &url).await }
159        });
160
161        cx.foreground_executor().spawn(async move {
162            let text = text.await?;
163            if text.trim().is_empty() {
164                bail!("no textual content found");
165            }
166
167            Ok(text)
168        })
169    }
170}