mention.rs

  1use agent::ThreadId;
  2use anyhow::{Context as _, Result, bail};
  3use prompt_store::{PromptId, UserPromptId};
  4use std::{
  5    fmt,
  6    ops::Range,
  7    path::{Path, PathBuf},
  8};
  9use url::Url;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub enum MentionUri {
 13    File(PathBuf),
 14    Symbol {
 15        path: PathBuf,
 16        name: String,
 17        line_range: Range<u32>,
 18    },
 19    Thread {
 20        id: ThreadId,
 21        name: String,
 22    },
 23    TextThread {
 24        path: PathBuf,
 25        name: String,
 26    },
 27    Rule {
 28        id: PromptId,
 29        name: String,
 30    },
 31    Selection {
 32        path: PathBuf,
 33        line_range: Range<u32>,
 34    },
 35    Fetch {
 36        url: Url,
 37    },
 38}
 39
 40impl MentionUri {
 41    pub fn parse(input: &str) -> Result<Self> {
 42        let url = url::Url::parse(input)?;
 43        let path = url.path();
 44        match url.scheme() {
 45            "file" => {
 46                if let Some(fragment) = url.fragment() {
 47                    let range = fragment
 48                        .strip_prefix("L")
 49                        .context("Line range must start with \"L\"")?;
 50                    let (start, end) = range
 51                        .split_once(":")
 52                        .context("Line range must use colon as separator")?;
 53                    let line_range = start
 54                        .parse::<u32>()
 55                        .context("Parsing line range start")?
 56                        .checked_sub(1)
 57                        .context("Line numbers should be 1-based")?
 58                        ..end
 59                            .parse::<u32>()
 60                            .context("Parsing line range end")?
 61                            .checked_sub(1)
 62                            .context("Line numbers should be 1-based")?;
 63                    if let Some(name) = single_query_param(&url, "symbol")? {
 64                        Ok(Self::Symbol {
 65                            name,
 66                            path: path.into(),
 67                            line_range,
 68                        })
 69                    } else {
 70                        Ok(Self::Selection {
 71                            path: path.into(),
 72                            line_range,
 73                        })
 74                    }
 75                } else {
 76                    let file_path =
 77                        PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
 78
 79                    Ok(Self::File(file_path))
 80                }
 81            }
 82            "zed" => {
 83                if let Some(thread_id) = path.strip_prefix("/agent/thread/") {
 84                    let name = single_query_param(&url, "name")?.context("Missing thread name")?;
 85                    Ok(Self::Thread {
 86                        id: thread_id.into(),
 87                        name,
 88                    })
 89                } else if let Some(path) = path.strip_prefix("/agent/text-thread/") {
 90                    let name = single_query_param(&url, "name")?.context("Missing thread name")?;
 91                    Ok(Self::TextThread {
 92                        path: path.into(),
 93                        name,
 94                    })
 95                } else if let Some(rule_id) = path.strip_prefix("/agent/rule/") {
 96                    let name = single_query_param(&url, "name")?.context("Missing rule name")?;
 97                    let rule_id = UserPromptId(rule_id.parse()?);
 98                    Ok(Self::Rule {
 99                        id: rule_id.into(),
100                        name,
101                    })
102                } else {
103                    bail!("invalid zed url: {:?}", input);
104                }
105            }
106            "http" | "https" => Ok(MentionUri::Fetch { url }),
107            other => bail!("unrecognized scheme {:?}", other),
108        }
109    }
110
111    fn name(&self) -> String {
112        match self {
113            MentionUri::File(path) => path
114                .file_name()
115                .unwrap_or_default()
116                .to_string_lossy()
117                .into_owned(),
118            MentionUri::Symbol { name, .. } => name.clone(),
119            MentionUri::Thread { name, .. } => name.clone(),
120            MentionUri::TextThread { name, .. } => name.clone(),
121            MentionUri::Rule { name, .. } => name.clone(),
122            MentionUri::Selection {
123                path, line_range, ..
124            } => selection_name(path, line_range),
125            MentionUri::Fetch { url } => url.to_string(),
126        }
127    }
128
129    pub fn as_link<'a>(&'a self) -> MentionLink<'a> {
130        MentionLink(self)
131    }
132
133    pub fn to_uri(&self) -> Url {
134        match self {
135            MentionUri::File(path) => {
136                let mut url = Url::parse("file:///").unwrap();
137                url.set_path(&path.to_string_lossy());
138                url
139            }
140            MentionUri::Symbol {
141                path,
142                name,
143                line_range,
144            } => {
145                let mut url = Url::parse("file:///").unwrap();
146                url.set_path(&path.to_string_lossy());
147                url.query_pairs_mut().append_pair("symbol", name);
148                url.set_fragment(Some(&format!(
149                    "L{}:{}",
150                    line_range.start + 1,
151                    line_range.end + 1
152                )));
153                url
154            }
155            MentionUri::Selection { path, line_range } => {
156                let mut url = Url::parse("file:///").unwrap();
157                url.set_path(&path.to_string_lossy());
158                url.set_fragment(Some(&format!(
159                    "L{}:{}",
160                    line_range.start + 1,
161                    line_range.end + 1
162                )));
163                url
164            }
165            MentionUri::Thread { name, id } => {
166                let mut url = Url::parse("zed:///").unwrap();
167                url.set_path(&format!("/agent/thread/{id}"));
168                url.query_pairs_mut().append_pair("name", name);
169                url
170            }
171            MentionUri::TextThread { path, name } => {
172                let mut url = Url::parse("zed:///").unwrap();
173                url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy()));
174                url.query_pairs_mut().append_pair("name", name);
175                url
176            }
177            MentionUri::Rule { name, id } => {
178                let mut url = Url::parse("zed:///").unwrap();
179                url.set_path(&format!("/agent/rule/{id}"));
180                url.query_pairs_mut().append_pair("name", name);
181                url
182            }
183            MentionUri::Fetch { url } => url.clone(),
184        }
185    }
186}
187
188pub struct MentionLink<'a>(&'a MentionUri);
189
190impl fmt::Display for MentionLink<'_> {
191    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192        write!(f, "[@{}]({})", self.0.name(), self.0.to_uri())
193    }
194}
195
196fn single_query_param(url: &Url, name: &'static str) -> Result<Option<String>> {
197    let pairs = url.query_pairs().collect::<Vec<_>>();
198    match pairs.as_slice() {
199        [] => Ok(None),
200        [(k, v)] => {
201            if k != name {
202                bail!("invalid query parameter")
203            }
204
205            Ok(Some(v.to_string()))
206        }
207        _ => bail!("too many query pairs"),
208    }
209}
210
/// Human-readable label for a line selection: the file name followed by the
/// 1-based `start:end` line numbers. For example, the 0-based range `0..9` in
/// `/src/main.rs` becomes `main.rs (1:10)`.
///
/// A path with no final component (such as `/`) yields an empty file name.
///
/// The bounds are widened to `u64` before adding 1. Otherwise `u32::MAX` would
/// overflow: a panic in debug builds, a wrap to 0 in release builds.
pub fn selection_name(path: &Path, line_range: &Range<u32>) -> String {
    let file_name = path.file_name().unwrap_or_default();
    format!(
        "{} ({}:{})",
        file_name.to_string_lossy(),
        u64::from(line_range.start) + 1,
        u64::from(line_range.end) + 1
    )
}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_file_uri() {
        let file_uri = "file:///path/to/file.rs";
        let parsed = MentionUri::parse(file_uri).unwrap();
        match &parsed {
            MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"),
            _ => panic!("Expected File variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), file_uri);
    }

    #[test]
    fn test_parse_symbol_uri() {
        let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20";
        let parsed = MentionUri::parse(symbol_uri).unwrap();
        match &parsed {
            MentionUri::Symbol {
                path,
                name,
                line_range,
            } => {
                assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
                assert_eq!(name, "MySymbol");
                assert_eq!(line_range.start, 9);
                assert_eq!(line_range.end, 19);
            }
            _ => panic!("Expected Symbol variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), symbol_uri);
    }

    #[test]
    fn test_parse_selection_uri() {
        let selection_uri = "file:///path/to/file.rs#L5:15";
        let parsed = MentionUri::parse(selection_uri).unwrap();
        match &parsed {
            MentionUri::Selection { path, line_range } => {
                assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
                assert_eq!(line_range.start, 4);
                assert_eq!(line_range.end, 14);
            }
            _ => panic!("Expected Selection variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), selection_uri);
    }

    #[test]
    fn test_parse_thread_uri() {
        let thread_uri = "zed:///agent/thread/session123?name=Thread+name";
        let parsed = MentionUri::parse(thread_uri).unwrap();
        match &parsed {
            MentionUri::Thread {
                id: thread_id,
                name,
            } => {
                assert_eq!(thread_id.to_string(), "session123");
                assert_eq!(name, "Thread name");
            }
            _ => panic!("Expected Thread variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), thread_uri);
    }

    #[test]
    fn test_parse_rule_uri() {
        let rule_uri = "zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52?name=Some+rule";
        let parsed = MentionUri::parse(rule_uri).unwrap();
        match &parsed {
            MentionUri::Rule { id, name } => {
                assert_eq!(id.to_string(), "d8694ff2-90d5-4b6f-be33-33c1763acd52");
                assert_eq!(name, "Some rule");
            }
            _ => panic!("Expected Rule variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), rule_uri);
    }

    #[test]
    fn test_parse_fetch_http_uri() {
        let http_uri = "http://example.com/path?query=value#fragment";
        let parsed = MentionUri::parse(http_uri).unwrap();
        match &parsed {
            MentionUri::Fetch { url } => {
                assert_eq!(url.to_string(), http_uri);
            }
            _ => panic!("Expected Fetch variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), http_uri);
    }

    #[test]
    fn test_parse_fetch_https_uri() {
        let https_uri = "https://example.com/api/endpoint";
        let parsed = MentionUri::parse(https_uri).unwrap();
        match &parsed {
            MentionUri::Fetch { url } => {
                assert_eq!(url.to_string(), https_uri);
            }
            _ => panic!("Expected Fetch variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), https_uri);
    }

    #[test]
    fn test_selection_link() {
        let mention = MentionUri::Selection {
            path: PathBuf::from("/path/to/file.rs"),
            line_range: 4..14,
        };
        assert_eq!(
            mention.as_link().to_string(),
            "[@file.rs (5:15)](file:///path/to/file.rs#L5:15)"
        );
    }

    #[test]
    fn test_invalid_scheme() {
        assert!(MentionUri::parse("ftp://example.com").is_err());
        assert!(MentionUri::parse("ssh://example.com").is_err());
        assert!(MentionUri::parse("unknown://example.com").is_err());
    }

    #[test]
    fn test_invalid_zed_path() {
        assert!(MentionUri::parse("zed:///invalid/path").is_err());
        assert!(MentionUri::parse("zed:///agent/unknown/test").is_err());
    }

    #[test]
    fn test_missing_name() {
        assert!(MentionUri::parse("zed:///agent/thread/session123").is_err());
        assert!(MentionUri::parse("zed:///agent/text-thread/some/path").is_err());
        assert!(
            MentionUri::parse("zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52").is_err()
        );
    }

    #[test]
    fn test_invalid_line_range_format() {
        // Missing L prefix
        assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err());

        // Missing colon separator
        assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err());

        // Invalid numbers
        assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err());
        assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err());
    }

    #[test]
    fn test_invalid_query_parameters() {
        // The query must come before the fragment. Anything after `#` belongs to
        // the fragment, so the line-range parse would fail first and the query
        // validation would never run. The messages below prove the error really
        // comes from the query check.

        // Invalid query parameter name
        let err = MentionUri::parse("file:///path/to/file.rs?invalid=test#L10:20").unwrap_err();
        assert_eq!(err.to_string(), "invalid query parameter");

        // Too many query parameters
        let err = MentionUri::parse("file:///path/to/file.rs?symbol=test&another=param#L10:20")
            .unwrap_err();
        assert_eq!(err.to_string(), "too many query pairs");
    }

    #[test]
    fn test_zero_based_line_numbers() {
        // Test that 0-based line numbers are rejected (should be 1-based)
        assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err());
        assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err());
        assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err());
    }
}