1use std::cell::RefCell;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _, Result};
6use assistant_tool::{ActionLog, Tool};
7use futures::AsyncReadExt as _;
8use gpui::{App, AppContext as _, Entity, Task};
9use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
10use http_client::{AsyncBody, HttpClientWithUrl};
11use language_model::LanguageModelRequestMessage;
12use project::Project;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15
/// The ways a fetched response body can be rendered into text.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare
/// variants in declaration order.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum ContentType {
    /// HTML markup; the body is converted to Markdown.
    Html,
    /// Plain text; the body is passed through unchanged.
    Plaintext,
    /// JSON; the body is pretty-printed inside a fenced code block.
    Json,
}
22
23#[derive(Debug, Serialize, Deserialize, JsonSchema)]
24pub struct FetchToolInput {
25 /// The URL to fetch.
26 url: String,
27}
28
29pub struct FetchTool {
30 http_client: Arc<HttpClientWithUrl>,
31}
32
33impl FetchTool {
34 pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
35 Self { http_client }
36 }
37
38 async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
39 let mut url = url.to_owned();
40 if !url.starts_with("https://") && !url.starts_with("http://") {
41 url = format!("https://{url}");
42 }
43
44 let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
45
46 let mut body = Vec::new();
47 response
48 .body_mut()
49 .read_to_end(&mut body)
50 .await
51 .context("error reading response body")?;
52
53 if response.status().is_client_error() {
54 let text = String::from_utf8_lossy(body.as_slice());
55 bail!(
56 "status error {}, response: {text:?}",
57 response.status().as_u16()
58 );
59 }
60
61 let Some(content_type) = response.headers().get("content-type") else {
62 bail!("missing Content-Type header");
63 };
64 let content_type = content_type
65 .to_str()
66 .context("invalid Content-Type header")?;
67 let content_type = match content_type {
68 "text/html" => ContentType::Html,
69 "text/plain" => ContentType::Plaintext,
70 "application/json" => ContentType::Json,
71 _ => ContentType::Html,
72 };
73
74 match content_type {
75 ContentType::Html => {
76 let mut handlers: Vec<TagHandler> = vec![
77 Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
78 Rc::new(RefCell::new(markdown::ParagraphHandler)),
79 Rc::new(RefCell::new(markdown::HeadingHandler)),
80 Rc::new(RefCell::new(markdown::ListHandler)),
81 Rc::new(RefCell::new(markdown::TableHandler::new())),
82 Rc::new(RefCell::new(markdown::StyledTextHandler)),
83 ];
84 if url.contains("wikipedia.org") {
85 use html_to_markdown::structure::wikipedia;
86
87 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
88 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
89 handlers.push(Rc::new(
90 RefCell::new(wikipedia::WikipediaCodeHandler::new()),
91 ));
92 } else {
93 handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
94 }
95
96 convert_html_to_markdown(&body[..], &mut handlers)
97 }
98 ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
99 ContentType::Json => {
100 let json: serde_json::Value = serde_json::from_slice(&body)?;
101
102 Ok(format!(
103 "```json\n{}\n```",
104 serde_json::to_string_pretty(&json)?
105 ))
106 }
107 }
108 }
109}
110
111impl Tool for FetchTool {
112 fn name(&self) -> String {
113 "fetch".to_string()
114 }
115
116 fn needs_confirmation(&self) -> bool {
117 true
118 }
119
120 fn description(&self) -> String {
121 include_str!("./fetch_tool/description.md").to_string()
122 }
123
124 fn input_schema(&self) -> serde_json::Value {
125 let schema = schemars::schema_for!(FetchToolInput);
126 serde_json::to_value(&schema).unwrap()
127 }
128
129 fn ui_text(&self, input: &serde_json::Value) -> String {
130 match serde_json::from_value::<FetchToolInput>(input.clone()) {
131 Ok(input) => format!("Fetch `{}`", input.url),
132 Err(_) => "Fetch URL".to_string(),
133 }
134 }
135
136 fn run(
137 self: Arc<Self>,
138 input: serde_json::Value,
139 _messages: &[LanguageModelRequestMessage],
140 _project: Entity<Project>,
141 _action_log: Entity<ActionLog>,
142 cx: &mut App,
143 ) -> Task<Result<String>> {
144 let input = match serde_json::from_value::<FetchToolInput>(input) {
145 Ok(input) => input,
146 Err(err) => return Task::ready(Err(anyhow!(err))),
147 };
148
149 let text = cx.background_spawn({
150 let http_client = self.http_client.clone();
151 let url = input.url.clone();
152 async move { Self::build_message(http_client, &url).await }
153 });
154
155 cx.foreground_executor().spawn(async move {
156 let text = text.await?;
157 if text.trim().is_empty() {
158 bail!("no textual content found");
159 }
160
161 Ok(text)
162 })
163 }
164}