1use std::cell::RefCell;
2use std::rc::Rc;
3use std::sync::Arc;
4
5use crate::schema::json_schema_for;
6use anyhow::{Context as _, Result, anyhow, bail};
7use assistant_tool::{ActionLog, Tool, ToolResult};
8use futures::AsyncReadExt as _;
9use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
10use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
11use http_client::{AsyncBody, HttpClientWithUrl};
12use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
13use project::Project;
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use ui::IconName;
17use util::markdown::MarkdownEscaped;
18
/// The ways a fetched response body can be rendered by `FetchTool::build_message`.
///
/// The variant order is significant: the derived `PartialOrd`/`Ord` implementations
/// compare variants by declaration order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum ContentType {
    /// The body is converted from HTML to Markdown.
    Html,
    /// The body is returned as UTF-8 text.
    Plaintext,
    /// The body is parsed as JSON and pretty-printed inside a fenced code block.
    Json,
}
25
// Input payload of the fetch tool. It is deserialized from the JSON value handed to
// `Tool::run` / `Tool::ui_text`, and its JSON schema is produced by `json_schema_for`.
//
// NOTE(review): plain `//` comments are used here on purpose. With `JsonSchema` derived,
// `///` doc comments are presumably emitted as schema descriptions, so adding or rewording
// them may change the schema that is generated — confirm before editing the field's
// doc comment.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FetchToolInput {
    /// The URL to fetch.
    url: String,
}
31
/// Tool that downloads a URL and returns its content as text.
///
/// See the `Tool` implementation for how input is parsed and how the request is run.
pub struct FetchTool {
    /// HTTP client used to issue the GET request for the requested URL.
    http_client: Arc<HttpClientWithUrl>,
}
35
36impl FetchTool {
37 pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
38 Self { http_client }
39 }
40
41 async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
42 let mut url = url.to_owned();
43 if !url.starts_with("https://") && !url.starts_with("http://") {
44 url = format!("https://{url}");
45 }
46
47 let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
48
49 let mut body = Vec::new();
50 response
51 .body_mut()
52 .read_to_end(&mut body)
53 .await
54 .context("error reading response body")?;
55
56 if response.status().is_client_error() {
57 let text = String::from_utf8_lossy(body.as_slice());
58 bail!(
59 "status error {}, response: {text:?}",
60 response.status().as_u16()
61 );
62 }
63
64 let Some(content_type) = response.headers().get("content-type") else {
65 bail!("missing Content-Type header");
66 };
67 let content_type = content_type
68 .to_str()
69 .context("invalid Content-Type header")?;
70 let content_type = match content_type {
71 "text/html" => ContentType::Html,
72 "text/plain" => ContentType::Plaintext,
73 "application/json" => ContentType::Json,
74 _ => ContentType::Html,
75 };
76
77 match content_type {
78 ContentType::Html => {
79 let mut handlers: Vec<TagHandler> = vec![
80 Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
81 Rc::new(RefCell::new(markdown::ParagraphHandler)),
82 Rc::new(RefCell::new(markdown::HeadingHandler)),
83 Rc::new(RefCell::new(markdown::ListHandler)),
84 Rc::new(RefCell::new(markdown::TableHandler::new())),
85 Rc::new(RefCell::new(markdown::StyledTextHandler)),
86 ];
87 if url.contains("wikipedia.org") {
88 use html_to_markdown::structure::wikipedia;
89
90 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
91 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
92 handlers.push(Rc::new(
93 RefCell::new(wikipedia::WikipediaCodeHandler::new()),
94 ));
95 } else {
96 handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
97 }
98
99 convert_html_to_markdown(&body[..], &mut handlers)
100 }
101 ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
102 ContentType::Json => {
103 let json: serde_json::Value = serde_json::from_slice(&body)?;
104
105 Ok(format!(
106 "```json\n{}\n```",
107 serde_json::to_string_pretty(&json)?
108 ))
109 }
110 }
111 }
112}
113
114impl Tool for FetchTool {
115 fn name(&self) -> String {
116 "fetch".to_string()
117 }
118
119 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
120 true
121 }
122
123 fn description(&self) -> String {
124 include_str!("./fetch_tool/description.md").to_string()
125 }
126
127 fn icon(&self) -> IconName {
128 IconName::Globe
129 }
130
131 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
132 json_schema_for::<FetchToolInput>(format)
133 }
134
135 fn ui_text(&self, input: &serde_json::Value) -> String {
136 match serde_json::from_value::<FetchToolInput>(input.clone()) {
137 Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
138 Err(_) => "Fetch URL".to_string(),
139 }
140 }
141
142 fn run(
143 self: Arc<Self>,
144 input: serde_json::Value,
145 _request: Arc<LanguageModelRequest>,
146 _project: Entity<Project>,
147 _action_log: Entity<ActionLog>,
148 _model: Arc<dyn LanguageModel>,
149 _window: Option<AnyWindowHandle>,
150 cx: &mut App,
151 ) -> ToolResult {
152 let input = match serde_json::from_value::<FetchToolInput>(input) {
153 Ok(input) => input,
154 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
155 };
156
157 let text = cx.background_spawn({
158 let http_client = self.http_client.clone();
159 let url = input.url.clone();
160 async move { Self::build_message(http_client, &url).await }
161 });
162
163 cx.foreground_executor()
164 .spawn(async move {
165 let text = text.await?;
166 if text.trim().is_empty() {
167 bail!("no textual content found");
168 }
169
170 Ok(text.into())
171 })
172 .into()
173 }
174}