assistant.rs

  1mod ambient_context;
  2pub mod assistant_panel;
  3pub mod assistant_settings;
  4mod codegen;
  5mod completion_provider;
  6mod prompt_library;
  7mod prompts;
  8mod saved_conversation;
  9mod search;
 10mod streaming_diff;
 11
 12use ambient_context::AmbientContextSnapshot;
 13pub use assistant_panel::AssistantPanel;
 14use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
 15use client::{proto, Client};
 16use command_palette_hooks::CommandPaletteFilter;
 17pub(crate) use completion_provider::*;
 18use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
 19pub(crate) use saved_conversation::*;
 20use serde::{Deserialize, Serialize};
 21use settings::{Settings, SettingsStore};
 22use std::{
 23    fmt::{self, Display},
 24    sync::Arc,
 25};
 26
 27actions!(
 28    assistant,
 29    [
 30        Assist,
 31        Split,
 32        CycleMessageRole,
 33        QuoteSelection,
 34        ToggleFocus,
 35        ResetKey,
 36        InlineAssist,
 37        InsertActivePrompt,
 38        ToggleIncludeConversation,
 39        ToggleHistory,
 40        ApplyEdit
 41    ]
 42);
 43
 44#[derive(
 45    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 46)]
 47struct MessageId(usize);
 48
 49#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 50#[serde(rename_all = "lowercase")]
 51pub enum Role {
 52    User,
 53    Assistant,
 54    System,
 55}
 56
 57impl Role {
 58    pub fn cycle(&mut self) {
 59        *self = match self {
 60            Role::User => Role::Assistant,
 61            Role::Assistant => Role::System,
 62            Role::System => Role::User,
 63        }
 64    }
 65}
 66
 67impl Display for Role {
 68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 69        match self {
 70            Role::User => write!(f, "user"),
 71            Role::Assistant => write!(f, "assistant"),
 72            Role::System => write!(f, "system"),
 73        }
 74    }
 75}
 76
 77#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 78pub enum LanguageModel {
 79    ZedDotDev(ZedDotDevModel),
 80    OpenAi(OpenAiModel),
 81    Anthropic(AnthropicModel),
 82}
 83
 84impl Default for LanguageModel {
 85    fn default() -> Self {
 86        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 87    }
 88}
 89
 90impl LanguageModel {
 91    pub fn telemetry_id(&self) -> String {
 92        match self {
 93            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 94            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
 95            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 96        }
 97    }
 98
 99    pub fn display_name(&self) -> String {
100        match self {
101            LanguageModel::OpenAi(model) => model.display_name().into(),
102            LanguageModel::Anthropic(model) => model.display_name().into(),
103            LanguageModel::ZedDotDev(model) => model.display_name().into(),
104        }
105    }
106
107    pub fn max_token_count(&self) -> usize {
108        match self {
109            LanguageModel::OpenAi(model) => model.max_token_count(),
110            LanguageModel::Anthropic(model) => model.max_token_count(),
111            LanguageModel::ZedDotDev(model) => model.max_token_count(),
112        }
113    }
114
115    pub fn id(&self) -> &str {
116        match self {
117            LanguageModel::OpenAi(model) => model.id(),
118            LanguageModel::Anthropic(model) => model.id(),
119            LanguageModel::ZedDotDev(model) => model.id(),
120        }
121    }
122}
123
124#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
125pub struct LanguageModelRequestMessage {
126    pub role: Role,
127    pub content: String,
128}
129
130impl LanguageModelRequestMessage {
131    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
132        proto::LanguageModelRequestMessage {
133            role: match self.role {
134                Role::User => proto::LanguageModelRole::LanguageModelUser,
135                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
136                Role::System => proto::LanguageModelRole::LanguageModelSystem,
137            } as i32,
138            content: self.content.clone(),
139            tool_calls: Vec::new(),
140            tool_call_id: None,
141        }
142    }
143}
144
145#[derive(Debug, Default, Serialize)]
146pub struct LanguageModelRequest {
147    pub model: LanguageModel,
148    pub messages: Vec<LanguageModelRequestMessage>,
149    pub stop: Vec<String>,
150    pub temperature: f32,
151}
152
153impl LanguageModelRequest {
154    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
155        proto::CompleteWithLanguageModel {
156            model: self.model.id().to_string(),
157            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
158            stop: self.stop.clone(),
159            temperature: self.temperature,
160            tool_choice: None,
161            tools: Vec::new(),
162        }
163    }
164}
165
166#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
167pub struct LanguageModelResponseMessage {
168    pub role: Option<Role>,
169    pub content: Option<String>,
170}
171
172#[derive(Deserialize, Debug)]
173pub struct LanguageModelUsage {
174    pub prompt_tokens: u32,
175    pub completion_tokens: u32,
176    pub total_tokens: u32,
177}
178
179#[derive(Deserialize, Debug)]
180pub struct LanguageModelChoiceDelta {
181    pub index: u32,
182    pub delta: LanguageModelResponseMessage,
183    pub finish_reason: Option<String>,
184}
185
186#[derive(Clone, Debug, Serialize, Deserialize)]
187struct MessageMetadata {
188    role: Role,
189    status: MessageStatus,
190    // todo!("delete this")
191    #[serde(skip)]
192    ambient_context: AmbientContextSnapshot,
193}
194
195#[derive(Clone, Debug, Serialize, Deserialize)]
196enum MessageStatus {
197    Pending,
198    Done,
199    Error(SharedString),
200}
201
202/// The state pertaining to the Assistant.
203#[derive(Default)]
204struct Assistant {
205    /// Whether the Assistant is enabled.
206    enabled: bool,
207}
208
209impl Global for Assistant {}
210
211impl Assistant {
212    const NAMESPACE: &'static str = "assistant";
213
214    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
215        if self.enabled == enabled {
216            return;
217        }
218
219        self.enabled = enabled;
220
221        if !enabled {
222            CommandPaletteFilter::update_global(cx, |filter, _cx| {
223                filter.hide_namespace(Self::NAMESPACE);
224            });
225
226            return;
227        }
228
229        CommandPaletteFilter::update_global(cx, |filter, _cx| {
230            filter.show_namespace(Self::NAMESPACE);
231        });
232    }
233}
234
235pub fn init(client: Arc<Client>, cx: &mut AppContext) {
236    cx.set_global(Assistant::default());
237    AssistantSettings::register(cx);
238    completion_provider::init(client, cx);
239    assistant_panel::init(cx);
240
241    CommandPaletteFilter::update_global(cx, |filter, _cx| {
242        filter.hide_namespace(Assistant::NAMESPACE);
243    });
244    Assistant::update_global(cx, |assistant, cx| {
245        let settings = AssistantSettings::get_global(cx);
246
247        assistant.set_enabled(settings.enabled, cx);
248    });
249    cx.observe_global::<SettingsStore>(|cx| {
250        Assistant::update_global(cx, |assistant, cx| {
251            let settings = AssistantSettings::get_global(cx);
252
253            assistant.set_enabled(settings.enabled, cx);
254        });
255    })
256    .detach();
257}
258
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Test-only, opt-in logging: install `env_logger` only when `RUST_LOG` is readable.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}