Detailed changes
@@ -109,8 +109,6 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
}
}
-// Execute adds all child commands to the root command and sets flags appropriately.
-// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
err := rootCmd.Execute()
if err != nil {
@@ -131,13 +129,14 @@ func loadConfig() {
// LLM
viper.SetDefault("models.big", string(models.DefaultBigModel))
- viper.SetDefault("models.little", string(models.DefaultLittleModel))
+ viper.SetDefault("models.small", string(models.DefaultLittleModel))
viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY"))
viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY"))
+ viper.SetDefault("providers.groq.key", os.Getenv("GROQ_API_KEY"))
viper.SetDefault("providers.common.max_tokens", 4000)
viper.SetDefault("agents.default", "coder")
- //
+
viper.ReadInConfig()
workdir, err := os.Getwd()
@@ -0,0 +1,31 @@
+package agent
+
+import (
+ "context"
+
+ "github.com/cloudwego/eino/schema"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
+ "github.com/spf13/viper"
+)
+
+// GenerateTitle asks a language model for a short conversation title derived
+// from content, the first message of the conversation.
+//
+// The model is the one configured under the viper key "models.small" and is
+// resolved through models.GetModel. It is sent a fixed system prompt describing
+// the title rules, followed by content as the user message.
+//
+// It returns the model's reply text unmodified. The 80-character limit and the
+// "no quotes or colons" rule are only requested in the prompt; nothing here
+// enforces them on the reply. Errors from resolving the model or from the
+// generation call are returned as-is together with an empty title.
+func GenerateTitle(ctx context.Context, content string) (string, error) {
+ model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
+ if err != nil {
+ return "", err
+ }
+ out, err := model.Generate(
+ ctx,
+ []*schema.Message{
+ schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
+ - ensure it is not more than 80 characters long
+ - the title should be a summary of the user's message
+ - do not use quotes or colons
+ - the entire text you return will be used as the title`),
+ schema.UserMessage(content),
+ },
+ )
+ if err != nil {
+ return "", err
+ }
+ return out.Content, nil
+}
@@ -11,6 +11,7 @@ import (
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
+ "github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
@@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
}
log.Printf("Request: %s", content)
- agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
+ currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
for _, m := range history {
messages = append(messages, &m.MessageData)
}
+
builder := callbacks.NewHandlerBuilder()
builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
i, ok := input.(*eModel.CallbackInput)
@@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return ctx
})
- out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
+ out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
ID: id,
@@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
usage := out.ResponseMeta.Usage
+ s.messages.Create(sessionID, *out)
if usage != nil {
log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
session, err := s.sessions.Get(sessionID)
@@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
session.PromptTokens += int64(usage.PromptTokens)
session.CompletionTokens += int64(usage.CompletionTokens)
// TODO: calculate cost
+ model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
+ session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
+ float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
+ var newTitle string
+ if len(history) == 1 {
+ // first message generate the title
+ newTitle, err = agent.GenerateTitle(s.ctx, content)
+ if err != nil {
+ s.Publish(AgentErrorEvent, AgentEvent{
+ ID: id,
+ Type: AgentMessageTypeError,
+ AgentID: RootAgent,
+ MessageID: "",
+ SessionID: sessionID,
+ Content: err.Error(),
+ })
+ return
+ }
+ }
+ if newTitle != "" {
+ session.Title = newTitle
+ }
+
_, err = s.sessions.Save(session)
if err != nil {
s.Publish(AgentErrorEvent, AgentEvent{
@@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
return
}
}
- s.messages.Create(sessionID, *out)
}
func (s *service) SendRequest(sessionID string, content string) {
@@ -3,6 +3,7 @@ package models
import (
"context"
"errors"
+ "log"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/openai"
@@ -16,10 +17,12 @@ type (
)
type Model struct {
- ID ModelID `json:"id"`
- Name string `json:"name"`
- Provider ModelProvider `json:"provider"`
- APIModel string `json:"api_model"` // Actual value used when calling the API
+ ID ModelID `json:"id"`
+ Name string `json:"name"`
+ Provider ModelProvider `json:"provider"`
+ APIModel string `json:"api_model"`
+ CostPer1MIn float64 `json:"cost_per_1m_in"`
+ CostPer1MOut float64 `json:"cost_per_1m_out"`
}
const (
@@ -52,6 +55,9 @@ const (
// Meta
Llama3 ModelID = "llama-3"
Llama270B ModelID = "llama-2-70b"
+ // GROQ
+ GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
+ GroqQwen32BCoder ModelID = "qwen-2.5-coder-32b"
)
const (
@@ -61,6 +67,7 @@ const (
ProviderXAI ModelProvider = "xai"
ProviderDeepSeek ModelProvider = "deepseek"
ProviderMeta ModelProvider = "meta"
+ ProviderGroq ModelProvider = "groq"
)
var SupportedModels = map[ModelID]Model{
@@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{
APIModel: "gpt-4o",
},
GPT4oMini: {
- ID: GPT4oMini,
- Name: "GPT-4o Mini",
- Provider: ProviderOpenAI,
- APIModel: "gpt-4o-mini",
+ ID: GPT4oMini,
+ Name: "GPT-4o Mini",
+ Provider: ProviderOpenAI,
+ APIModel: "gpt-4o-mini",
+ CostPer1MIn: 0.150,
+ CostPer1MOut: 0.600,
},
GPT45: {
ID: GPT45,
@@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{
Provider: ProviderMeta,
APIModel: "llama-2-70b",
},
+
+ // GROQ
+ GroqLlama3SpecDec: {
+ ID: GroqLlama3SpecDec,
+ Name: "GROQ LLaMA 3 SpecDec",
+ Provider: ProviderGroq,
+ APIModel: "llama-3.3-70b-specdec",
+ },
+ GroqQwen32BCoder: {
+ ID: GroqQwen32BCoder,
+ Name: "GROQ Qwen 2.5 Coder 32B",
+ Provider: ProviderGroq,
+ APIModel: "qwen-2.5-coder-32b",
+ },
}
func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
provider := SupportedModels[model].Provider
+ log.Printf("Provider: %s", provider)
maxTokens := viper.GetInt("providers.common.max_tokens")
switch provider {
case ProviderOpenAI:
@@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
MaxTokens: maxTokens,
})
+ case ProviderGroq:
+ return openai.NewChatModel(ctx, &openai.ChatModelConfig{
+ BaseURL: "https://api.groq.com/openai/v1",
+ APIKey: viper.GetString("providers.groq.key"),
+ Model: string(SupportedModels[model].APIModel),
+ MaxTokens: &maxTokens,
+ })
+
}
return nil, errors.New("unsupported provider")
}
@@ -1,22 +1,33 @@
package repl
import (
+ "github.com/charmbracelet/bubbles/key"
+ "github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/message"
"github.com/kujtimiihoxha/termai/internal/pubsub"
"github.com/kujtimiihoxha/termai/internal/session"
+ "github.com/kujtimiihoxha/termai/internal/tui/layout"
)
+// MessagesCmp is the interface of the component returned by NewMessagesCmp:
+// a Bubble Tea model that also satisfies the layout package's Bindings,
+// Bordered, Focusable and Sizeable interfaces.
+type MessagesCmp interface {
+	tea.Model
+	layout.Bindings
+	layout.Bordered
+	layout.Focusable
+	layout.Sizeable
+}
+
type messagesCmp struct {
app *app.App
messages []message.Message
session session.Session
-}
-
-func (m *messagesCmp) Init() tea.Cmd {
- return nil
+ viewport viewport.Model
+ width int
+ height int
+ focused bool
}
func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -25,6 +36,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if msg.Type == pubsub.CreatedEvent {
m.messages = append(m.messages, msg.Payload)
}
+ case pubsub.Event[session.Session]:
+ if msg.Type == pubsub.UpdatedEvent {
+ if m.session.ID == msg.Payload.ID {
+ m.session = msg.Payload
+ }
+ }
case SelectedSessionMsg:
m.session, _ = m.app.Sessions.Get(msg.SessionID)
m.messages, _ = m.app.Messages.List(m.session.ID)
@@ -40,7 +57,55 @@ func (i *messagesCmp) View() string {
return lipgloss.JoinVertical(lipgloss.Top, stringMessages...)
}
-func NewMessagesCmp(app *app.App) tea.Model {
+// BindingKeys implements MessagesCmp. The component exposes no key bindings,
+// so the result is always an empty, non-nil slice.
+func (m *messagesCmp) BindingKeys() []key.Binding {
+	return make([]key.Binding, 0)
+}
+
+// Blur implements MessagesCmp by clearing the component's focused flag.
+// The returned command is always nil.
+func (m *messagesCmp) Blur() tea.Cmd {
+	var none tea.Cmd
+	m.focused = false
+	return none
+}
+
+// BorderText implements MessagesCmp.
+//
+// It returns the current session's title keyed by layout.TopLeftBorder.
+// Titles longer than 20 characters are shortened to their first 20
+// characters followed by "...".
+func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
+	const maxTitleLen = 20
+
+	title := m.session.Title
+	// Measure and cut by rune rather than by byte: slicing the string
+	// directly could split a multi-byte UTF-8 character and produce an
+	// invalid string.
+	if runes := []rune(title); len(runes) > maxTitleLen {
+		title = string(runes[:maxTitleLen]) + "..."
+	}
+	return map[layout.BorderPosition]string{
+		layout.TopLeftBorder: title,
+	}
+}
+
+// Focus implements MessagesCmp by setting the component's focused flag.
+// The returned command is always nil.
+func (m *messagesCmp) Focus() tea.Cmd {
+	var none tea.Cmd
+	m.focused = true
+	return none
+}
+
+// GetSize implements MessagesCmp. It reports the width and height most
+// recently stored on the component, in that order.
+func (m *messagesCmp) GetSize() (int, int) {
+	width, height := m.width, m.height
+	return width, height
+}
+
+// IsFocused implements MessagesCmp. It reports the focused flag that Focus
+// sets and Blur clears.
+func (m *messagesCmp) IsFocused() bool {
+	focused := m.focused
+	return focused
+}
+
+// SetSize implements MessagesCmp by recording the given width and height on
+// the component.
+func (m *messagesCmp) SetSize(width int, height int) {
+	m.width, m.height = width, height
+}
+
+// Init returns the component's initial command, which is always nil.
+func (m *messagesCmp) Init() tea.Cmd {
+	var none tea.Cmd
+	return none
+}
+
+func NewMessagesCmp(app *app.App) MessagesCmp {
return &messagesCmp{
app: app,
messages: []message.Message{},
@@ -2,6 +2,7 @@ package repl
import (
"fmt"
+ "strings"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
@@ -82,7 +83,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
items[i] = listItem{
id: s.ID,
title: s.Title,
- desc: fmt.Sprintf("Tokens: %d, Cost: %.2f", s.PromptTokens+s.CompletionTokens, s.Cost),
+ desc: formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost),
}
}
return i, i.list.SetItems(items)
@@ -94,7 +95,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
s := item.(listItem)
if s.id == msg.Payload.ID {
s.title = msg.Payload.Title
- s.desc = fmt.Sprintf("Tokens: %d, Cost: %.2f", msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
+ s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
items[idx] = s
break
}
@@ -169,6 +170,32 @@ func (i *sessionsCmp) BindingKeys() []key.Binding {
return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select)
}
+// formatTokensAndCost renders a token total and a cost as a single display
+// string of the form "Tokens: <count>, Cost: $<cost>".
+//
+// Counts below 1,000 are printed as plain integers. Larger counts are
+// abbreviated to one decimal place with a K (thousands) or M (millions)
+// suffix, and a trailing ".0" is dropped, e.g. 110000 -> "110K" and
+// 1200000 -> "1.2M". The cost is printed with a "$" prefix and two decimal
+// places.
+func formatTokensAndCost(tokens int64, cost float64) string {
+	var number, unit string
+	switch {
+	case tokens >= 1_000_000:
+		number, unit = fmt.Sprintf("%.1f", float64(tokens)/1_000_000), "M"
+	case tokens >= 1_000:
+		number, unit = fmt.Sprintf("%.1f", float64(tokens)/1_000), "K"
+		// Counts just under one million (e.g. 999,999) round up to
+		// "1000.0"; show them as 1M instead of the awkward "1000K".
+		if number == "1000.0" {
+			number, unit = "1.0", "M"
+		}
+	default:
+		number = fmt.Sprintf("%d", tokens)
+	}
+
+	// Drop a redundant ".0" so round values read as "110K" or "2M". Plain
+	// integers never carry this suffix, so they pass through untouched.
+	number = strings.TrimSuffix(number, ".0")
+
+	return fmt.Sprintf("Tokens: %s, Cost: $%.2f", number+unit, cost)
+}
+
func NewSessionsCmp(app *app.App) SessionsCmp {
listDelegate := list.NewDefaultDelegate()
defaultItemStyle := list.NewDefaultItemStyles()