1use std::cell::RefCell;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _, Result};
6use assistant_tool::{ActionLog, Tool};
7use futures::AsyncReadExt as _;
8use gpui::{App, AppContext as _, Entity, Task};
9use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
10use http_client::{AsyncBody, HttpClientWithUrl};
11use language_model::LanguageModelRequestMessage;
12use project::Project;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use ui::IconName;
16use util::markdown::MarkdownString;
17
/// Rendering strategy for a fetched response body, selected from the
/// response's `Content-Type` header.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` follow it.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
enum ContentType {
    /// The body is converted from HTML to Markdown. Also the fallback for
    /// content types that are not otherwise recognized.
    Html,
    /// The body is returned as UTF-8 text.
    Plaintext,
    /// The body is parsed as JSON and pretty-printed in a fenced `json` block.
    Json,
}
24
// Arguments accepted by the fetch tool. The tool deserializes this struct from
// the JSON value it is handed, and its JSON schema is what `input_schema`
// returns.
//
// NOTE(review): plain `//` comments are used here on purpose. schemars
// presumably copies `///` doc comments into the generated schema as
// `description` entries, so the field doc below would be part of the schema
// output; confirm before rewording it or adding a struct-level `///`.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
    /// The URL to fetch.
    url: String,
}
30
/// Tool that fetches a URL and returns the response body rendered as text.
pub struct FetchTool {
    /// HTTP client used to issue the fetch requests.
    http_client: Arc<HttpClientWithUrl>,
}
34
35impl FetchTool {
36 pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
37 Self { http_client }
38 }
39
40 async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
41 let mut url = url.to_owned();
42 if !url.starts_with("https://") && !url.starts_with("http://") {
43 url = format!("https://{url}");
44 }
45
46 let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
47
48 let mut body = Vec::new();
49 response
50 .body_mut()
51 .read_to_end(&mut body)
52 .await
53 .context("error reading response body")?;
54
55 if response.status().is_client_error() {
56 let text = String::from_utf8_lossy(body.as_slice());
57 bail!(
58 "status error {}, response: {text:?}",
59 response.status().as_u16()
60 );
61 }
62
63 let Some(content_type) = response.headers().get("content-type") else {
64 bail!("missing Content-Type header");
65 };
66 let content_type = content_type
67 .to_str()
68 .context("invalid Content-Type header")?;
69 let content_type = match content_type {
70 "text/html" => ContentType::Html,
71 "text/plain" => ContentType::Plaintext,
72 "application/json" => ContentType::Json,
73 _ => ContentType::Html,
74 };
75
76 match content_type {
77 ContentType::Html => {
78 let mut handlers: Vec<TagHandler> = vec![
79 Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
80 Rc::new(RefCell::new(markdown::ParagraphHandler)),
81 Rc::new(RefCell::new(markdown::HeadingHandler)),
82 Rc::new(RefCell::new(markdown::ListHandler)),
83 Rc::new(RefCell::new(markdown::TableHandler::new())),
84 Rc::new(RefCell::new(markdown::StyledTextHandler)),
85 ];
86 if url.contains("wikipedia.org") {
87 use html_to_markdown::structure::wikipedia;
88
89 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
90 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
91 handlers.push(Rc::new(
92 RefCell::new(wikipedia::WikipediaCodeHandler::new()),
93 ));
94 } else {
95 handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
96 }
97
98 convert_html_to_markdown(&body[..], &mut handlers)
99 }
100 ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
101 ContentType::Json => {
102 let json: serde_json::Value = serde_json::from_slice(&body)?;
103
104 Ok(format!(
105 "```json\n{}\n```",
106 serde_json::to_string_pretty(&json)?
107 ))
108 }
109 }
110 }
111}
112
113impl Tool for FetchTool {
114 fn name(&self) -> String {
115 "fetch".to_string()
116 }
117
118 fn needs_confirmation(&self) -> bool {
119 true
120 }
121
122 fn description(&self) -> String {
123 include_str!("./fetch_tool/description.md").to_string()
124 }
125
126 fn icon(&self) -> IconName {
127 IconName::Globe
128 }
129
130 fn input_schema(&self) -> serde_json::Value {
131 let schema = schemars::schema_for!(FetchToolInput);
132 serde_json::to_value(&schema).unwrap()
133 }
134
135 fn ui_text(&self, input: &serde_json::Value) -> String {
136 match serde_json::from_value::<FetchToolInput>(input.clone()) {
137 Ok(input) => format!("Fetch {}", MarkdownString::escape(&input.url)),
138 Err(_) => "Fetch URL".to_string(),
139 }
140 }
141
142 fn run(
143 self: Arc<Self>,
144 input: serde_json::Value,
145 _messages: &[LanguageModelRequestMessage],
146 _project: Entity<Project>,
147 _action_log: Entity<ActionLog>,
148 cx: &mut App,
149 ) -> Task<Result<String>> {
150 let input = match serde_json::from_value::<FetchToolInput>(input) {
151 Ok(input) => input,
152 Err(err) => return Task::ready(Err(anyhow!(err))),
153 };
154
155 let text = cx.background_spawn({
156 let http_client = self.http_client.clone();
157 let url = input.url.clone();
158 async move { Self::build_message(http_client, &url).await }
159 });
160
161 cx.foreground_executor().spawn(async move {
162 let text = text.await?;
163 if text.trim().is_empty() {
164 bail!("no textual content found");
165 }
166
167 Ok(text)
168 })
169 }
170}