assistant.rs

  1pub mod assistant_panel;
  2pub mod assistant_settings;
  3mod codegen;
  4mod completion_provider;
  5mod prompts;
  6mod saved_conversation;
  7mod streaming_diff;
  8
  9pub use assistant_panel::AssistantPanel;
 10use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
 11use chrono::{DateTime, Local};
 12use client::{proto, Client};
 13pub(crate) use completion_provider::*;
 14use gpui::{actions, AppContext, SharedString};
 15pub(crate) use saved_conversation::*;
 16use serde::{Deserialize, Serialize};
 17use settings::Settings;
 18use std::{
 19    fmt::{self, Display},
 20    sync::Arc,
 21};
 22
// Declares the assistant's actions through gpui's `actions!` macro. The first
// argument (`assistant`) groups them under one name; the bracketed list gives
// one identifier per action.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
    ]
);
 37
 38#[derive(
 39    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 40)]
 41struct MessageId(usize);
 42
 43#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 44#[serde(rename_all = "lowercase")]
 45pub enum Role {
 46    User,
 47    Assistant,
 48    System,
 49}
 50
 51impl Role {
 52    pub fn cycle(&mut self) {
 53        *self = match self {
 54            Role::User => Role::Assistant,
 55            Role::Assistant => Role::System,
 56            Role::System => Role::User,
 57        }
 58    }
 59}
 60
 61impl Display for Role {
 62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
 63        match self {
 64            Role::User => write!(f, "user"),
 65            Role::Assistant => write!(f, "assistant"),
 66            Role::System => write!(f, "system"),
 67        }
 68    }
 69}
 70
 71#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
 72pub enum LanguageModel {
 73    ZedDotDev(ZedDotDevModel),
 74    OpenAi(OpenAiModel),
 75}
 76
 77impl Default for LanguageModel {
 78    fn default() -> Self {
 79        LanguageModel::ZedDotDev(ZedDotDevModel::default())
 80    }
 81}
 82
 83impl LanguageModel {
 84    pub fn telemetry_id(&self) -> String {
 85        match self {
 86            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
 87            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
 88        }
 89    }
 90
 91    pub fn display_name(&self) -> String {
 92        match self {
 93            LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
 94            LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
 95        }
 96    }
 97
 98    pub fn max_token_count(&self) -> usize {
 99        match self {
100            LanguageModel::OpenAi(model) => model.max_token_count(),
101            LanguageModel::ZedDotDev(model) => model.max_token_count(),
102        }
103    }
104
105    pub fn id(&self) -> &str {
106        match self {
107            LanguageModel::OpenAi(model) => model.id(),
108            LanguageModel::ZedDotDev(model) => model.id(),
109        }
110    }
111}
112
113#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
114pub struct LanguageModelRequestMessage {
115    pub role: Role,
116    pub content: String,
117}
118
119impl LanguageModelRequestMessage {
120    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
121        proto::LanguageModelRequestMessage {
122            role: match self.role {
123                Role::User => proto::LanguageModelRole::LanguageModelUser,
124                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
125                Role::System => proto::LanguageModelRole::LanguageModelSystem,
126            } as i32,
127            content: self.content.clone(),
128        }
129    }
130}
131
132#[derive(Debug, Default, Serialize)]
133pub struct LanguageModelRequest {
134    pub model: LanguageModel,
135    pub messages: Vec<LanguageModelRequestMessage>,
136    pub stop: Vec<String>,
137    pub temperature: f32,
138}
139
140impl LanguageModelRequest {
141    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
142        proto::CompleteWithLanguageModel {
143            model: self.model.id().to_string(),
144            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
145            stop: self.stop.clone(),
146            temperature: self.temperature,
147        }
148    }
149}
150
151#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
152pub struct LanguageModelResponseMessage {
153    pub role: Option<Role>,
154    pub content: Option<String>,
155}
156
157#[derive(Deserialize, Debug)]
158pub struct LanguageModelUsage {
159    pub prompt_tokens: u32,
160    pub completion_tokens: u32,
161    pub total_tokens: u32,
162}
163
164#[derive(Deserialize, Debug)]
165pub struct LanguageModelChoiceDelta {
166    pub index: u32,
167    pub delta: LanguageModelResponseMessage,
168    pub finish_reason: Option<String>,
169}
170
171#[derive(Clone, Debug, Serialize, Deserialize)]
172struct MessageMetadata {
173    role: Role,
174    sent_at: DateTime<Local>,
175    status: MessageStatus,
176}
177
178#[derive(Clone, Debug, Serialize, Deserialize)]
179enum MessageStatus {
180    Pending,
181    Done,
182    Error(SharedString),
183}
184
/// Initializes the assistant.
///
/// Registers `AssistantSettings` with `cx`, then initializes the completion
/// provider (which receives `client`), then the assistant panel.
// NOTE(review): the later steps presumably rely on the settings being
// registered first — keep this order unless verified otherwise.
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    AssistantSettings::register(cx);
    completion_provider::init(client, cx);
    assistant_panel::init(cx);
}
190
// Test-only constructor (registered through `ctor`): initializes `env_logger`
// when `RUST_LOG` can be read from the environment, and does nothing otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // `env::var` fails when the variable is unset or not valid Unicode.
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}