perplexity.rs

  1use zed::{
  2    http_client::HttpMethod,
  3    http_client::HttpRequest,
  4    serde_json::{self, json},
  5};
  6use zed_extension_api::{self as zed, Result, http_client::RedirectPolicy};
  7
/// Zed extension that answers the `/perplexity` slash command by querying the
/// Perplexity chat-completions API (`https://api.perplexity.ai/chat/completions`).
struct Perplexity;
  9
 10impl zed::Extension for Perplexity {
 11    fn new() -> Self {
 12        Self
 13    }
 14
 15    fn run_slash_command(
 16        &self,
 17        command: zed::SlashCommand,
 18        argument: Vec<String>,
 19        worktree: Option<&zed::Worktree>,
 20    ) -> zed::Result<zed::SlashCommandOutput> {
 21        // Check if the command is 'perplexity'
 22        if command.name != "perplexity" {
 23            return Err("Invalid command. Expected 'perplexity'.".into());
 24        }
 25
 26        let worktree = worktree.ok_or("Worktree is required")?;
 27        // Join arguments with space as the query
 28        let query = argument.join(" ");
 29        if query.is_empty() {
 30            return Ok(zed::SlashCommandOutput {
 31                text: "Error: Query not provided. Please enter a question or topic.".to_string(),
 32                sections: vec![],
 33            });
 34        }
 35
 36        // Get the API key from the environment
 37        let env_vars = worktree.shell_env();
 38        let api_key = env_vars
 39            .iter()
 40            .find(|(key, _)| key == "PERPLEXITY_API_KEY")
 41            .map(|(_, value)| value.clone())
 42            .ok_or("PERPLEXITY_API_KEY not found in environment")?;
 43
 44        // Prepare the request
 45        let request = HttpRequest {
 46            method: HttpMethod::Post,
 47            url: "https://api.perplexity.ai/chat/completions".to_string(),
 48            headers: vec![
 49                ("Authorization".to_string(), format!("Bearer {}", api_key)),
 50                ("Content-Type".to_string(), "application/json".to_string()),
 51            ],
 52            body: Some(
 53                serde_json::to_vec(&json!({
 54                    "model": "llama-3.1-sonar-small-128k-online",
 55                    "messages": [{"role": "user", "content": query}],
 56                    "stream": true,
 57                }))
 58                .unwrap(),
 59            ),
 60            redirect_policy: RedirectPolicy::FollowAll,
 61        };
 62
 63        // Make the HTTP request
 64        match zed::http_client::fetch_stream(&request) {
 65            Ok(stream) => {
 66                let mut full_content = String::new();
 67                let mut buffer = String::new();
 68                while let Ok(Some(chunk)) = stream.next_chunk() {
 69                    buffer.push_str(&String::from_utf8_lossy(&chunk));
 70                    for line in buffer.lines() {
 71                        if let Some(json) = line.strip_prefix("data: ") {
 72                            if let Ok(event) = serde_json::from_str::<StreamEvent>(json) {
 73                                if let Some(choice) = event.choices.first() {
 74                                    full_content.push_str(&choice.delta.content);
 75                                }
 76                            }
 77                        }
 78                    }
 79                    buffer.clear();
 80                }
 81                Ok(zed::SlashCommandOutput {
 82                    text: full_content,
 83                    sections: vec![],
 84                })
 85            }
 86            Err(e) => Ok(zed::SlashCommandOutput {
 87                text: format!("API request failed. Error: {}. API Key: {}", e, api_key),
 88                sections: vec![],
 89            }),
 90        }
 91    }
 92
 93    fn complete_slash_command_argument(
 94        &self,
 95        _command: zed::SlashCommand,
 96        query: Vec<String>,
 97    ) -> zed::Result<Vec<zed::SlashCommandArgumentCompletion>> {
 98        let suggestions = vec!["How do I develop a Zed extension?"];
 99        let query = query.join(" ").to_lowercase();
100
101        Ok(suggestions
102            .into_iter()
103            .filter(|suggestion| suggestion.to_lowercase().contains(&query))
104            .map(|suggestion| zed::SlashCommandArgumentCompletion {
105                label: suggestion.to_string(),
106                new_text: suggestion.to_string(),
107                run_command: true,
108            })
109            .collect())
110    }
111
112    fn language_server_command(
113        &mut self,
114        _language_server_id: &zed_extension_api::LanguageServerId,
115        _worktree: &zed_extension_api::Worktree,
116    ) -> Result<zed_extension_api::Command> {
117        Err("Not implemented".into())
118    }
119}
120
121#[derive(serde::Deserialize)]
122struct StreamEvent {
123    id: String,
124    model: String,
125    created: u64,
126    usage: Usage,
127    object: String,
128    choices: Vec<Choice>,
129}
130
131#[derive(serde::Deserialize)]
132struct Usage {
133    prompt_tokens: u32,
134    completion_tokens: u32,
135    total_tokens: u32,
136}
137
138#[derive(serde::Deserialize)]
139struct Choice {
140    index: u32,
141    finish_reason: Option<String>,
142    message: Message,
143    delta: Delta,
144}
145
146#[derive(serde::Deserialize)]
147struct Message {
148    role: String,
149    content: String,
150}
151
152#[derive(serde::Deserialize)]
153struct Delta {
154    role: String,
155    content: String,
156}
157
// Export `Perplexity` as this crate's Zed extension implementation.
zed::register_extension!(Perplexity);