fetch_tool.rs

  1use std::rc::Rc;
  2use std::sync::Arc;
  3use std::{borrow::Cow, cell::RefCell};
  4
  5use crate::schema::json_schema_for;
  6use action_log::ActionLog;
  7use anyhow::{Context as _, Result, anyhow, bail};
  8use assistant_tool::{Tool, ToolResult};
  9use futures::AsyncReadExt as _;
 10use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
 11use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
 12use http_client::{AsyncBody, HttpClientWithUrl};
 13use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 14use project::Project;
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use ui::IconName;
 18use util::markdown::MarkdownEscaped;
 19
 20#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
 21enum ContentType {
 22    Html,
 23    Plaintext,
 24    Json,
 25}
 26
 27#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 28pub struct FetchToolInput {
 29    /// The URL to fetch.
 30    url: String,
 31}
 32
 33pub struct FetchTool {
 34    http_client: Arc<HttpClientWithUrl>,
 35}
 36
 37impl FetchTool {
 38    pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
 39        Self { http_client }
 40    }
 41
 42    async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
 43        let url = if !url.starts_with("https://") && !url.starts_with("http://") {
 44            Cow::Owned(format!("https://{url}"))
 45        } else {
 46            Cow::Borrowed(url)
 47        };
 48
 49        let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
 50
 51        let mut body = Vec::new();
 52        response
 53            .body_mut()
 54            .read_to_end(&mut body)
 55            .await
 56            .context("error reading response body")?;
 57
 58        if response.status().is_client_error() {
 59            let text = String::from_utf8_lossy(body.as_slice());
 60            bail!(
 61                "status error {}, response: {text:?}",
 62                response.status().as_u16()
 63            );
 64        }
 65
 66        let Some(content_type) = response.headers().get("content-type") else {
 67            bail!("missing Content-Type header");
 68        };
 69        let content_type = content_type
 70            .to_str()
 71            .context("invalid Content-Type header")?;
 72        let content_type = match content_type {
 73            "text/html" | "application/xhtml+xml" => ContentType::Html,
 74            "application/json" => ContentType::Json,
 75            _ => ContentType::Plaintext,
 76        };
 77
 78        match content_type {
 79            ContentType::Html => {
 80                let mut handlers: Vec<TagHandler> = vec![
 81                    Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
 82                    Rc::new(RefCell::new(markdown::ParagraphHandler)),
 83                    Rc::new(RefCell::new(markdown::HeadingHandler)),
 84                    Rc::new(RefCell::new(markdown::ListHandler)),
 85                    Rc::new(RefCell::new(markdown::TableHandler::new())),
 86                    Rc::new(RefCell::new(markdown::StyledTextHandler)),
 87                ];
 88                if url.contains("wikipedia.org") {
 89                    use html_to_markdown::structure::wikipedia;
 90
 91                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
 92                    handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
 93                    handlers.push(Rc::new(
 94                        RefCell::new(wikipedia::WikipediaCodeHandler::new()),
 95                    ));
 96                } else {
 97                    handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
 98                }
 99
100                convert_html_to_markdown(&body[..], &mut handlers)
101            }
102            ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
103            ContentType::Json => {
104                let json: serde_json::Value = serde_json::from_slice(&body)?;
105
106                Ok(format!(
107                    "```json\n{}\n```",
108                    serde_json::to_string_pretty(&json)?
109                ))
110            }
111        }
112    }
113}
114
115impl Tool for FetchTool {
116    fn name(&self) -> String {
117        "fetch".to_string()
118    }
119
120    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
121        true
122    }
123
124    fn may_perform_edits(&self) -> bool {
125        false
126    }
127
128    fn description(&self) -> String {
129        include_str!("./fetch_tool/description.md").to_string()
130    }
131
132    fn icon(&self) -> IconName {
133        IconName::ToolWeb
134    }
135
136    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
137        json_schema_for::<FetchToolInput>(format)
138    }
139
140    fn ui_text(&self, input: &serde_json::Value) -> String {
141        match serde_json::from_value::<FetchToolInput>(input.clone()) {
142            Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
143            Err(_) => "Fetch URL".to_string(),
144        }
145    }
146
147    fn run(
148        self: Arc<Self>,
149        input: serde_json::Value,
150        _request: Arc<LanguageModelRequest>,
151        _project: Entity<Project>,
152        _action_log: Entity<ActionLog>,
153        _model: Arc<dyn LanguageModel>,
154        _window: Option<AnyWindowHandle>,
155        cx: &mut App,
156    ) -> ToolResult {
157        let input = match serde_json::from_value::<FetchToolInput>(input) {
158            Ok(input) => input,
159            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
160        };
161
162        let text = cx.background_spawn({
163            let http_client = self.http_client.clone();
164            async move { Self::build_message(http_client, &input.url).await }
165        });
166
167        cx.foreground_executor()
168            .spawn(async move {
169                let text = text.await?;
170                if text.trim().is_empty() {
171                    bail!("no textual content found");
172                }
173
174                Ok(text.into())
175            })
176            .into()
177    }
178}