fetch_tool.rs

  1use std::cell::RefCell;
  2use std::rc::Rc;
  3use std::sync::Arc;
  4
  5use anyhow::{anyhow, bail, Context as _, Result};
  6use assistant_tool::{ActionLog, Tool};
  7use futures::AsyncReadExt as _;
  8use gpui::{App, AppContext as _, Entity, Task};
  9use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
 10use http_client::{AsyncBody, HttpClientWithUrl};
 11use language_model::LanguageModelRequestMessage;
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use ui::IconName;
 16
 17#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 18enum ContentType {
 19    Html,
 20    Plaintext,
 21    Json,
 22}
 23
// Arguments accepted by the `fetch` tool, deserialized from a JSON value in
// `FetchTool::run` and `FetchTool::ui_text`.
//
// NOTE: these are plain `//` comments on purpose. `schemars` uses `///` doc comments
// as `description` fields in the derived JSON schema, so adding or editing a doc
// comment here changes what `input_schema` returns.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
    /// The URL to fetch.
    // The scheme is optional: `build_message` prepends `https://` when the value
    // starts with neither `http://` nor `https://`.
    url: String,
}
 29
 30pub struct FetchTool {
 31    http_client: Arc<HttpClientWithUrl>,
 32}
 33
 34impl FetchTool {
 35    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 36        Self { http_client }
 37    }
 38
 39    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 40        let mut url = url.to_owned();
 41        if !url.starts_with("https://") && !url.starts_with("http://") {
 42            url = format!("https://{url}");
 43        }
 44
 45        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 46
 47        let mut body = Vec::new();
 48        response
 49            .body_mut()
 50            .read_to_end(&mut body)
 51            .await
 52            .context("error reading response body")?;
 53
 54        if response.status().is_client_error() {
 55            let text = String::from_utf8_lossy(body.as_slice());
 56            bail!(
 57                "status error {}, response: {text:?}",
 58                response.status().as_u16()
 59            );
 60        }
 61
 62        let Some(content_type) = response.headers().get("content-type") else {
 63            bail!("missing Content-Type header");
 64        };
 65        let content_type = content_type
 66            .to_str()
 67            .context("invalid Content-Type header")?;
 68        let content_type = match content_type {
 69            "text/html" => ContentType::Html,
 70            "text/plain" => ContentType::Plaintext,
 71            "application/json" => ContentType::Json,
 72            _ => ContentType::Html,
 73        };
 74
 75        match content_type {
 76            ContentType::Html => {
 77                let mut handlers: Vec<TagHandler> = vec![
 78                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 79                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 80                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 81                    Rc::new(RefCell::new(markdown::ListHandler)),
 82                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 83                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 84                ];
 85                if url.contains("wikipedia.org") {
 86                    use html_to_markdown::structure::wikipedia;
 87
 88                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 89                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 90                    handlers.push(Rc::new(
 91                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
 92                    ));
 93                } else {
 94                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
 95                }
 96
 97                convert_html_to_markdown(&body[..], &mut handlers)
 98            }
 99            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
100            ContentType::Json => {
101                let json: serde_json::Value = serde_json::from_slice(&body)?;
102
103                Ok(format!(
104                    "```json\n{}\n```",
105                    serde_json::to_string_pretty(&json)?
106                ))
107            }
108        }
109    }
110}
111
112impl Tool for FetchTool {
113    fn name(&self) -> String {
114        "fetch".to_string()
115    }
116
117    fn needs_confirmation(&self) -> bool {
118        true
119    }
120
121    fn description(&self) -> String {
122        include_str!("./fetch_tool/description.md").to_string()
123    }
124
125    fn icon(&self) -> IconName {
126        IconName::Globe
127    }
128
129    fn input_schema(&self) -> serde_json::Value {
130        let schema = schemars::schema_for!(FetchToolInput);
131        serde_json::to_value(&schema).unwrap()
132    }
133
134    fn ui_text(&self, input: &serde_json::Value) -> String {
135        match serde_json::from_value::<FetchToolInput>(input.clone()) {
136            Ok(input) => format!("Fetch `{}`", input.url),
137            Err(_) => "Fetch URL".to_string(),
138        }
139    }
140
141    fn run(
142        self: Arc<Self>,
143        input: serde_json::Value,
144        _messages: &[LanguageModelRequestMessage],
145        _project: Entity<Project>,
146        _action_log: Entity<ActionLog>,
147        cx: &mut App,
148    ) -> Task<Result<String>> {
149        let input = match serde_json::from_value::<FetchToolInput>(input) {
150            Ok(input) => input,
151            Err(err) => return Task::ready(Err(anyhow!(err))),
152        };
153
154        let text = cx.background_spawn({
155            let http_client = self.http_client.clone();
156            let url = input.url.clone();
157            async move { Self::build_message(http_client, &url).await }
158        });
159
160        cx.foreground_executor().spawn(async move {
161            let text = text.await?;
162            if text.trim().is_empty() {
163                bail!("no textual content found");
164            }
165
166            Ok(text)
167        })
168    }
169}