assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod streaming_diff;
  8
  9pub use assistant_panel::AssistantPanel;
 10use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
 11use client::{proto, Client};
 12use command_palette_hooks::CommandPaletteFilter;
 13pub(crate) use completion_provider::*;
 14use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
 15pub(crate) use saved_conversation::*;
 16use serde::{Deserialize, Serialize};
 17use settings::{Settings, SettingsStore};
 18use std::{
 19    fmt::{self, Display},
 20    sync::Arc,
 21};
 22
// Declares the assistant's actions through gpui's `actions!` macro. The first
// argument (`assistant`) is the namespace the actions are grouped under; it is
// spelled the same as `Assistant::NAMESPACE`, the namespace that is shown or
// hidden in the command palette when the assistant is enabled or disabled.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
        ToggleHistory,
    ]
);
 37
/// Newtype identifier for a message, wrapping a `usize`.
///
/// The derives make it cheap to copy, totally ordered, hashable and
/// serializable, so it can serve as a map key or be persisted.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
 42
/// The author of a message exchanged with a language model.
///
/// Serialized by serde as a lowercase string (`"user"`, `"assistant"`,
/// `"system"`), which is also the text produced by the `Display` impl.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// A message written by the user.
    User,
    /// A message produced by the assistant.
    Assistant,
    /// A system message.
    System,
}
 50
 51impl Role {
 52    pub fn cycle(&mut self) {
 53        *self = match self {
 54            Role::User => Role::Assistant,
 55            Role::Assistant => Role::System,
 56            Role::System => Role::User,
 57        }
 58    }
 59}
 60
 61impl Display for Role {
 62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 63        match self {
 64            Role::User => write!(f, "user"),
 65            Role::Assistant => write!(f, "assistant"),
 66            Role::System => write!(f, "system"),
 67        }
 68    }
 69}
 70
/// A language model the assistant can send requests to, tagged by the
/// provider that serves it.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model offered through zed.dev (the default; see the `Default` impl).
    ZedDotDev(ZedDotDevModel),
    /// A model offered through OpenAI.
    OpenAi(OpenAiModel),
}
 76
 77impl Default for LanguageModel {
 78    fn default() -> Self {
 79        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 80    }
 81}
 82
 83impl LanguageModel {
 84    pub fn telemetry_id(&self) -> String {
 85        match self {
 86            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 87            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 88        }
 89    }
 90
 91    pub fn display_name(&self) -> String {
 92        match self {
 93            LanguageModel::OpenAi(model) => model.display_name().into(),
 94            LanguageModel::ZedDotDev(model) => model.display_name().into(),
 95        }
 96    }
 97
 98    pub fn max_token_count(&self) -> usize {
 99        match self {
100            LanguageModel::OpenAi(model) => model.max_token_count(),
101            LanguageModel::ZedDotDev(model) => model.max_token_count(),
102        }
103    }
104
105    pub fn id(&self) -> &str {
106        match self {
107            LanguageModel::OpenAi(model) => model.id(),
108            LanguageModel::ZedDotDev(model) => model.id(),
109        }
110    }
111}
112
/// A single message included in a [`LanguageModelRequest`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
118
119impl LanguageModelRequestMessage {
120    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
121        proto::LanguageModelRequestMessage {
122            role: match self.role {
123                Role::User => proto::LanguageModelRole::LanguageModelUser,
124                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
125                Role::System => proto::LanguageModelRole::LanguageModelSystem,
126            } as i32,
127            content: self.content.clone(),
128            tool_calls: Vec::new(),
129            tool_call_id: None,
130        }
131    }
132}
133
/// A completion request addressed to a language model.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// The model that should handle the request.
    pub model: LanguageModel,
    /// The messages that make up the request, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Stop strings, forwarded unchanged by `to_proto`
    /// (presumably sequences at which generation should halt — confirm with the provider API).
    pub stop: Vec<String>,
    /// Temperature value, forwarded unchanged by `to_proto`.
    pub temperature: f32,
}
141
142impl LanguageModelRequest {
143    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
144        proto::CompleteWithLanguageModel {
145            model: self.model.id().to_string(),
146            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
147            stop: self.stop.clone(),
148            temperature: self.temperature,
149            tool_choice: None,
150            tools: Vec::new(),
151        }
152    }
153}
154
/// A message received from a language model.
///
/// Both fields are optional; this type is used as the `delta` of a
/// [`LanguageModelChoiceDelta`], so a given value may carry only part of a
/// message (presumably a streamed fragment — confirm against the provider API).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// The author of the message, if present.
    pub role: Option<Role>,
    /// The message text, if present.
    pub content: Option<String>,
}
160
/// Token usage figures deserialized from a language model response.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens in the completion.
    pub completion_tokens: u32,
    /// Total token count as reported in the response
    /// (presumably prompt + completion; not checked here).
    pub total_tokens: u32,
}
167
/// One choice entry deserialized from a language model response, carrying a
/// message delta.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The message content carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Reason the choice finished, if one was reported.
    pub finish_reason: Option<String>,
}
174
/// Metadata tracked alongside a message: who authored it and what state it is in.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    /// The author of the message.
    role: Role,
    /// The current status of the message.
    status: MessageStatus,
}
180
/// The status of a message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// The message is not finished yet.
    Pending,
    /// The message finished without an error.
    Done,
    /// The message failed; the payload is the error text.
    Error(SharedString),
}
187
/// The state pertaining to the Assistant.
///
/// Stored as a gpui global (see `init`). The derived `Default` starts with
/// the assistant disabled.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    ///
    /// Kept in sync with `AssistantSettings::enabled` via `set_enabled`.
    enabled: bool,
}
194
// Marks `Assistant` as a gpui global so it can be stored with `set_global`
// and mutated through `update_global`.
impl Global for Assistant {}
196
197impl Assistant {
198    const NAMESPACE: &'static str = "assistant";
199
200    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
201        if self.enabled == enabled {
202            return;
203        }
204
205        self.enabled = enabled;
206
207        if !enabled {
208            CommandPaletteFilter::update_global(cx, |filter, _cx| {
209                filter.hide_namespace(Self::NAMESPACE);
210            });
211
212            return;
213        }
214
215        CommandPaletteFilter::update_global(cx, |filter, _cx| {
216            filter.show_namespace(Self::NAMESPACE);
217        });
218    }
219}
220
221pub fn init(client: Arc<Client>, cx: &mut AppContext) {
222    cx.set_global(Assistant::default());
223    AssistantSettings::register(cx);
224    completion_provider::init(client, cx);
225    assistant_panel::init(cx);
226
227    CommandPaletteFilter::update_global(cx, |filter, _cx| {
228        filter.hide_namespace(Assistant::NAMESPACE);
229    });
230    cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
231        let settings = AssistantSettings::get_global(cx);
232
233        assistant.set_enabled(settings.enabled, cx);
234    });
235    cx.observe_global::<SettingsStore>(|cx| {
236        cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
237            let settings = AssistantSettings::get_global(cx);
238
239            assistant.set_enabled(settings.enabled, cx);
240        });
241    })
242    .detach();
243}
244
/// Test-only constructor (registered through `ctor`) that initializes
/// `env_logger`, but only when the `RUST_LOG` environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}