1use std::cell::RefCell;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _, Result};
6use assistant_tool::{ActionLog, Tool};
7use futures::AsyncReadExt as _;
8use gpui::{App, AppContext as _, Entity, Task};
9use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
10use http_client::{AsyncBody, HttpClientWithUrl};
11use language_model::LanguageModelRequestMessage;
12use project::Project;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use ui::IconName;
16
/// The kinds of response body the fetch tool knows how to render.
///
/// The variant order is significant for the derived `PartialOrd`/`Ord`
/// implementations: `Html < Plaintext < Json`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ContentType {
    /// An HTML document.
    Html,
    /// Plain text.
    Plaintext,
    /// A JSON document.
    Json,
}
23
// Arguments accepted by the fetch tool. Values of this type are deserialized
// from the JSON value handed to the tool, and its JSON schema is what
// `input_schema` returns.
//
// NOTE(review): plain `//` comments are used here on purpose. `schemars`
// presumably copies `///` doc comments into the generated JSON schema as
// descriptions, so adding or rewording doc comments on this type could change
// the schema — confirm before editing them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
    /// The URL to fetch.
    url: String,
}
29
/// Tool that fetches a URL and returns its contents as text.
///
/// HTML responses are converted to Markdown, JSON responses are pretty-printed
/// inside a fenced code block, and plain-text responses are returned as-is.
pub struct FetchTool {
    /// Shared HTTP client used to issue the request.
    http_client: Arc<HttpClientWithUrl>,
}
33
34impl FetchTool {
35 pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
36 Self { http_client }
37 }
38
39 async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
40 let mut url = url.to_owned();
41 if !url.starts_with("https://") && !url.starts_with("http://") {
42 url = format!("https://{url}");
43 }
44
45 let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
46
47 let mut body = Vec::new();
48 response
49 .body_mut()
50 .read_to_end(&mut body)
51 .await
52 .context("error reading response body")?;
53
54 if response.status().is_client_error() {
55 let text = String::from_utf8_lossy(body.as_slice());
56 bail!(
57 "status error {}, response: {text:?}",
58 response.status().as_u16()
59 );
60 }
61
62 let Some(content_type) = response.headers().get("content-type") else {
63 bail!("missing Content-Type header");
64 };
65 let content_type = content_type
66 .to_str()
67 .context("invalid Content-Type header")?;
68 let content_type = match content_type {
69 "text/html" => ContentType::Html,
70 "text/plain" => ContentType::Plaintext,
71 "application/json" => ContentType::Json,
72 _ => ContentType::Html,
73 };
74
75 match content_type {
76 ContentType::Html => {
77 let mut handlers: Vec<TagHandler> = vec![
78 Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
79 Rc::new(RefCell::new(markdown::ParagraphHandler)),
80 Rc::new(RefCell::new(markdown::HeadingHandler)),
81 Rc::new(RefCell::new(markdown::ListHandler)),
82 Rc::new(RefCell::new(markdown::TableHandler::new())),
83 Rc::new(RefCell::new(markdown::StyledTextHandler)),
84 ];
85 if url.contains("wikipedia.org") {
86 use html_to_markdown::structure::wikipedia;
87
88 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
89 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
90 handlers.push(Rc::new(
91 RefCell::new(wikipedia::WikipediaCodeHandler::new()),
92 ));
93 } else {
94 handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
95 }
96
97 convert_html_to_markdown(&body[..], &mut handlers)
98 }
99 ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
100 ContentType::Json => {
101 let json: serde_json::Value = serde_json::from_slice(&body)?;
102
103 Ok(format!(
104 "```json\n{}\n```",
105 serde_json::to_string_pretty(&json)?
106 ))
107 }
108 }
109 }
110}
111
112impl Tool for FetchTool {
113 fn name(&self) -> String {
114 "fetch".to_string()
115 }
116
117 fn needs_confirmation(&self) -> bool {
118 true
119 }
120
121 fn description(&self) -> String {
122 include_str!("./fetch_tool/description.md").to_string()
123 }
124
125 fn icon(&self) -> IconName {
126 IconName::Globe
127 }
128
129 fn input_schema(&self) -> serde_json::Value {
130 let schema = schemars::schema_for!(FetchToolInput);
131 serde_json::to_value(&schema).unwrap()
132 }
133
134 fn ui_text(&self, input: &serde_json::Value) -> String {
135 match serde_json::from_value::<FetchToolInput>(input.clone()) {
136 Ok(input) => format!("Fetch `{}`", input.url),
137 Err(_) => "Fetch URL".to_string(),
138 }
139 }
140
141 fn run(
142 self: Arc<Self>,
143 input: serde_json::Value,
144 _messages: &[LanguageModelRequestMessage],
145 _project: Entity<Project>,
146 _action_log: Entity<ActionLog>,
147 cx: &mut App,
148 ) -> Task<Result<String>> {
149 let input = match serde_json::from_value::<FetchToolInput>(input) {
150 Ok(input) => input,
151 Err(err) => return Task::ready(Err(anyhow!(err))),
152 };
153
154 let text = cx.background_spawn({
155 let http_client = self.http_client.clone();
156 let url = input.url.clone();
157 async move { Self::build_message(http_client, &url).await }
158 });
159
160 cx.foreground_executor().spawn(async move {
161 let text = text.await?;
162 if text.trim().is_empty() {
163 bail!("no textual content found");
164 }
165
166 Ok(text)
167 })
168 }
169}