fetch_tool.rs

  1use std::rc::Rc;
  2use std::sync::Arc;
  3use std::{borrow::Cow, cell::RefCell};
  4
  5use crate::schema::json_schema_for;
  6use anyhow::{Context as _, Result, anyhow, bail};
  7use assistant_tool::{ActionLog, Tool, ToolResult};
  8use futures::AsyncReadExt as _;
  9use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
 10use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
 11use http_client::{AsyncBody, HttpClientWithUrl};
 12use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 13use project::Project;
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use ui::IconName;
 17use util::markdown::MarkdownEscaped;
 18
 19#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 20enum ContentType {
 21    Html,
 22    Plaintext,
 23    Json,
 24}
 25
 26#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 27pub struct FetchToolInput {
 28    /// The URL to fetch.
 29    url: String,
 30}
 31
 32pub struct FetchTool {
 33    http_client: Arc<HttpClientWithUrl>,
 34}
 35
 36impl FetchTool {
 37    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 38        Self { http_client }
 39    }
 40
 41    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 42        let url = if !url.starts_with("https://") && !url.starts_with("http://") {
 43            Cow::Owned(format!("https://{url}"))
 44        } else {
 45            Cow::Borrowed(url)
 46        };
 47
 48        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 49
 50        let mut body = Vec::new();
 51        response
 52            .body_mut()
 53            .read_to_end(&mut body)
 54            .await
 55            .context("error reading response body")?;
 56
 57        if response.status().is_client_error() {
 58            let text = String::from_utf8_lossy(body.as_slice());
 59            bail!(
 60                "status error {}, response: {text:?}",
 61                response.status().as_u16()
 62            );
 63        }
 64
 65        let Some(content_type) = response.headers().get("content-type") else {
 66            bail!("missing Content-Type header");
 67        };
 68        let content_type = content_type
 69            .to_str()
 70            .context("invalid Content-Type header")?;
 71        let content_type = match content_type {
 72            "text/html" | "application/xhtml+xml" => ContentType::Html,
 73            "application/json" => ContentType::Json,
 74            _ => ContentType::Plaintext,
 75        };
 76
 77        match content_type {
 78            ContentType::Html => {
 79                let mut handlers: Vec<TagHandler> = vec![
 80                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 81                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 82                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 83                    Rc::new(RefCell::new(markdown::ListHandler)),
 84                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 85                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 86                ];
 87                if url.contains("wikipedia.org") {
 88                    use html_to_markdown::structure::wikipedia;
 89
 90                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 91                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 92                    handlers.push(Rc::new(
 93                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
 94                    ));
 95                } else {
 96                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
 97                }
 98
 99                convert_html_to_markdown(&body[..], &mut handlers)
100            }
101            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
102            ContentType::Json => {
103                let json: serde_json::Value = serde_json::from_slice(&body)?;
104
105                Ok(format!(
106                    "```json\n{}\n```",
107                    serde_json::to_string_pretty(&json)?
108                ))
109            }
110        }
111    }
112}
113
114impl Tool for FetchTool {
115    fn name(&self) -> String {
116        "fetch".to_string()
117    }
118
119    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
120        false
121    }
122
123    fn may_perform_edits(&self) -> bool {
124        false
125    }
126
127    fn description(&self) -> String {
128        include_str!("./fetch_tool/description.md").to_string()
129    }
130
131    fn icon(&self) -> IconName {
132        IconName::ToolWeb
133    }
134
135    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
136        json_schema_for::<FetchToolInput>(format)
137    }
138
139    fn ui_text(&self, input: &serde_json::Value) -> String {
140        match serde_json::from_value::<FetchToolInput>(input.clone()) {
141            Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
142            Err(_) => "Fetch URL".to_string(),
143        }
144    }
145
146    fn run(
147        self: Arc<Self>,
148        input: serde_json::Value,
149        _request: Arc<LanguageModelRequest>,
150        _project: Entity<Project>,
151        _action_log: Entity<ActionLog>,
152        _model: Arc<dyn LanguageModel>,
153        _window: Option<AnyWindowHandle>,
154        cx: &mut App,
155    ) -> ToolResult {
156        let input = match serde_json::from_value::<FetchToolInput>(input) {
157            Ok(input) => input,
158            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
159        };
160
161        let text = cx.background_spawn({
162            let http_client = self.http_client.clone();
163            async move { Self::build_message(http_client, &input.url).await }
164        });
165
166        cx.foreground_executor()
167            .spawn(async move {
168                let text = text.await?;
169                if text.trim().is_empty() {
170                    bail!("no textual content found");
171                }
172
173                Ok(text.into())
174            })
175            .into()
176    }
177}