1use std::rc::Rc;
2use std::sync::Arc;
3use std::{borrow::Cow, cell::RefCell};
4
5use crate::schema::json_schema_for;
6use action_log::ActionLog;
7use anyhow::{Context as _, Result, anyhow, bail};
8use assistant_tool::{Tool, ToolResult};
9use futures::AsyncReadExt as _;
10use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
11use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
12use http_client::{AsyncBody, HttpClientWithUrl};
13use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
14use project::Project;
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use ui::IconName;
18use util::markdown::MarkdownEscaped;
19
/// How a fetched response body is turned into the text handed back by the tool.
///
/// The variant order is significant for the derived `PartialOrd`/`Ord`
/// implementations and must not be rearranged casually.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum ContentType {
    /// The body is HTML and is converted to Markdown.
    Html,
    /// The body is returned as-is after UTF-8 validation.
    Plaintext,
    /// The body is parsed as JSON and pretty-printed inside a fenced code block.
    Json,
}
26
27#[derive(Debug, Serialize, Deserialize, JsonSchema)]
28pub struct FetchToolInput {
29 /// The URL to fetch.
30 url: String,
31}
32
33pub struct FetchTool {
34 http_client: Arc<HttpClientWithUrl>,
35}
36
37impl FetchTool {
38 pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
39 Self { http_client }
40 }
41
42 async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
43 let url = if !url.starts_with("https://") && !url.starts_with("http://") {
44 Cow::Owned(format!("https://{url}"))
45 } else {
46 Cow::Borrowed(url)
47 };
48
49 let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
50
51 let mut body = Vec::new();
52 response
53 .body_mut()
54 .read_to_end(&mut body)
55 .await
56 .context("error reading response body")?;
57
58 if response.status().is_client_error() {
59 let text = String::from_utf8_lossy(body.as_slice());
60 bail!(
61 "status error {}, response: {text:?}",
62 response.status().as_u16()
63 );
64 }
65
66 let Some(content_type) = response.headers().get("content-type") else {
67 bail!("missing Content-Type header");
68 };
69 let content_type = content_type
70 .to_str()
71 .context("invalid Content-Type header")?;
72 let content_type = match content_type {
73 "text/html" | "application/xhtml+xml" => ContentType::Html,
74 "application/json" => ContentType::Json,
75 _ => ContentType::Plaintext,
76 };
77
78 match content_type {
79 ContentType::Html => {
80 let mut handlers: Vec<TagHandler> = vec![
81 Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
82 Rc::new(RefCell::new(markdown::ParagraphHandler)),
83 Rc::new(RefCell::new(markdown::HeadingHandler)),
84 Rc::new(RefCell::new(markdown::ListHandler)),
85 Rc::new(RefCell::new(markdown::TableHandler::new())),
86 Rc::new(RefCell::new(markdown::StyledTextHandler)),
87 ];
88 if url.contains("wikipedia.org") {
89 use html_to_markdown::structure::wikipedia;
90
91 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
92 handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
93 handlers.push(Rc::new(
94 RefCell::new(wikipedia::WikipediaCodeHandler::new()),
95 ));
96 } else {
97 handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
98 }
99
100 convert_html_to_markdown(&body[..], &mut handlers)
101 }
102 ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
103 ContentType::Json => {
104 let json: serde_json::Value = serde_json::from_slice(&body)?;
105
106 Ok(format!(
107 "```json\n{}\n```",
108 serde_json::to_string_pretty(&json)?
109 ))
110 }
111 }
112 }
113}
114
115impl Tool for FetchTool {
116 fn name(&self) -> String {
117 "fetch".to_string()
118 }
119
120 fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
121 true
122 }
123
124 fn may_perform_edits(&self) -> bool {
125 false
126 }
127
128 fn description(&self) -> String {
129 include_str!("./fetch_tool/description.md").to_string()
130 }
131
132 fn icon(&self) -> IconName {
133 IconName::ToolWeb
134 }
135
136 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
137 json_schema_for::<FetchToolInput>(format)
138 }
139
140 fn ui_text(&self, input: &serde_json::Value) -> String {
141 match serde_json::from_value::<FetchToolInput>(input.clone()) {
142 Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
143 Err(_) => "Fetch URL".to_string(),
144 }
145 }
146
147 fn run(
148 self: Arc<Self>,
149 input: serde_json::Value,
150 _request: Arc<LanguageModelRequest>,
151 _project: Entity<Project>,
152 _action_log: Entity<ActionLog>,
153 _model: Arc<dyn LanguageModel>,
154 _window: Option<AnyWindowHandle>,
155 cx: &mut App,
156 ) -> ToolResult {
157 let input = match serde_json::from_value::<FetchToolInput>(input) {
158 Ok(input) => input,
159 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
160 };
161
162 let text = cx.background_spawn({
163 let http_client = self.http_client.clone();
164 async move { Self::build_message(http_client, &input.url).await }
165 });
166
167 cx.foreground_executor()
168 .spawn(async move {
169 let text = text.await?;
170 if text.trim().is_empty() {
171 bail!("no textual content found");
172 }
173
174 Ok(text.into())
175 })
176 .into()
177 }
178}