initial working agent

Kujtim Hoxha created

Change summary

cmd/root.go                              |  7 +-
internal/llm/agent/title.go              | 31 ++++++++++
internal/llm/llm.go                      | 31 +++++++++-
internal/llm/models/models.go            | 48 +++++++++++++--
internal/tui/components/repl/messages.go | 75 ++++++++++++++++++++++++-
internal/tui/components/repl/sessions.go | 31 ++++++++++
6 files changed, 201 insertions(+), 22 deletions(-)

Detailed changes

cmd/root.go 🔗

@@ -109,8 +109,6 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
 	}
 }
 
-// Execute adds all child commands to the root command and sets flags appropriately.
-// This is called by main.main(). It only needs to happen once to the rootCmd.
 func Execute() {
 	err := rootCmd.Execute()
 	if err != nil {
@@ -131,13 +129,14 @@ func loadConfig() {
 
 	// LLM
 	viper.SetDefault("models.big", string(models.DefaultBigModel))
-	viper.SetDefault("models.little", string(models.DefaultLittleModel))
+	viper.SetDefault("models.small", string(models.DefaultLittleModel))
 	viper.SetDefault("providers.openai.key", os.Getenv("OPENAI_API_KEY"))
 	viper.SetDefault("providers.anthropic.key", os.Getenv("ANTHROPIC_API_KEY"))
+	viper.SetDefault("providers.groq.key", os.Getenv("GROQ_API_KEY"))
 	viper.SetDefault("providers.common.max_tokens", 4000)
 
 	viper.SetDefault("agents.default", "coder")
-	//
+
 	viper.ReadInConfig()
 
 	workdir, err := os.Getwd()

internal/llm/agent/title.go 🔗

@@ -0,0 +1,31 @@
+package agent
+
+import (
+	"context"
+
+	"github.com/cloudwego/eino/schema"
+	"github.com/kujtimiihoxha/termai/internal/llm/models"
+	"github.com/spf13/viper"
+)
+
+func GenerateTitle(ctx context.Context, content string) (string, error) {
+	model, err := models.GetModel(ctx, models.ModelID(viper.GetString("models.small")))
+	if err != nil {
+		return "", err
+	}
+	out, err := model.Generate(
+		ctx,
+		[]*schema.Message{
+			schema.SystemMessage(`- you will generate a short title based on the first message a user begins a conversation with
+      - ensure it is not more than 80 characters long
+      - the title should be a summary of the user's message
+      - do not use quotes or colons
+      - the entire text you return will be used as the title`),
+			schema.UserMessage(content),
+		},
+	)
+	if err != nil {
+		return "", err
+	}
+	return out.Content, nil
+}

internal/llm/llm.go 🔗

@@ -11,6 +11,7 @@ import (
 	"github.com/cloudwego/eino/schema"
 	"github.com/google/uuid"
 	"github.com/kujtimiihoxha/termai/internal/llm/agent"
+	"github.com/kujtimiihoxha/termai/internal/llm/models"
 	"github.com/kujtimiihoxha/termai/internal/logging"
 	"github.com/kujtimiihoxha/termai/internal/message"
 	"github.com/kujtimiihoxha/termai/internal/pubsub"
@@ -88,7 +89,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
 	}
 
 	log.Printf("Request: %s", content)
-	agent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
+	currentAgent, systemMessage, err := agent.GetAgent(s.ctx, viper.GetString("agents.default"))
 	if err != nil {
 		s.Publish(AgentErrorEvent, AgentEvent{
 			ID:        id,
@@ -110,6 +111,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
 	for _, m := range history {
 		messages = append(messages, &m.MessageData)
 	}
+
 	builder := callbacks.NewHandlerBuilder()
 	builder.OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
 		i, ok := input.(*eModel.CallbackInput)
@@ -140,7 +142,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
 		return ctx
 	})
 
-	out, err := agent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
+	out, err := currentAgent.Generate(s.ctx, messages, enioAgent.WithComposeOptions(compose.WithCallbacks(builder.Build())))
 	if err != nil {
 		s.Publish(AgentErrorEvent, AgentEvent{
 			ID:        id,
@@ -153,6 +155,7 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
 		return
 	}
 	usage := out.ResponseMeta.Usage
+	s.messages.Create(sessionID, *out)
 	if usage != nil {
 		log.Printf("Prompt Tokens: %d, Completion Tokens: %d, Total Tokens: %d", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
 		session, err := s.sessions.Get(sessionID)
@@ -170,6 +173,29 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
 		session.PromptTokens += int64(usage.PromptTokens)
 		session.CompletionTokens += int64(usage.CompletionTokens)
 		// TODO: calculate cost
+		model := models.SupportedModels[models.ModelID(viper.GetString("models.big"))]
+		session.Cost += float64(usage.PromptTokens)*(model.CostPer1MIn/1_000_000) +
+			float64(usage.CompletionTokens)*(model.CostPer1MOut/1_000_000)
+		var newTitle string
+		if len(history) == 1 {
+			// first message: generate the title
+			newTitle, err = agent.GenerateTitle(s.ctx, content)
+			if err != nil {
+				s.Publish(AgentErrorEvent, AgentEvent{
+					ID:        id,
+					Type:      AgentMessageTypeError,
+					AgentID:   RootAgent,
+					MessageID: "",
+					SessionID: sessionID,
+					Content:   err.Error(),
+				})
+				return
+			}
+		}
+		if newTitle != "" {
+			session.Title = newTitle
+		}
+
 		_, err = s.sessions.Save(session)
 		if err != nil {
 			s.Publish(AgentErrorEvent, AgentEvent{
@@ -183,7 +209,6 @@ func (s *service) handleRequest(id string, sessionID string, content string) {
 			return
 		}
 	}
-	s.messages.Create(sessionID, *out)
 }
 
 func (s *service) SendRequest(sessionID string, content string) {

internal/llm/models/models.go 🔗

@@ -3,6 +3,7 @@ package models
 import (
 	"context"
 	"errors"
+	"log"
 
 	"github.com/cloudwego/eino-ext/components/model/claude"
 	"github.com/cloudwego/eino-ext/components/model/openai"
@@ -16,10 +17,12 @@ type (
 )
 
 type Model struct {
-	ID       ModelID       `json:"id"`
-	Name     string        `json:"name"`
-	Provider ModelProvider `json:"provider"`
-	APIModel string        `json:"api_model"` // Actual value used when calling the API
+	ID           ModelID       `json:"id"`
+	Name         string        `json:"name"`
+	Provider     ModelProvider `json:"provider"`
+	APIModel     string        `json:"api_model"`
+	CostPer1MIn  float64       `json:"cost_per_1m_in"`
+	CostPer1MOut float64       `json:"cost_per_1m_out"`
 }
 
 const (
@@ -52,6 +55,9 @@ const (
 	// Meta
 	Llama3    ModelID = "llama-3"
 	Llama270B ModelID = "llama-2-70b"
+	// GROQ
+	GroqLlama3SpecDec ModelID = "groq-llama-3-spec-dec"
+	GroqQwen32BCoder  ModelID = "qwen-2.5-coder-32b"
 )
 
 const (
@@ -61,6 +67,7 @@ const (
 	ProviderXAI       ModelProvider = "xai"
 	ProviderDeepSeek  ModelProvider = "deepseek"
 	ProviderMeta      ModelProvider = "meta"
+	ProviderGroq      ModelProvider = "groq"
 )
 
 var SupportedModels = map[ModelID]Model{
@@ -72,10 +79,12 @@ var SupportedModels = map[ModelID]Model{
 		APIModel: "gpt-4o",
 	},
 	GPT4oMini: {
-		ID:       GPT4oMini,
-		Name:     "GPT-4o Mini",
-		Provider: ProviderOpenAI,
-		APIModel: "gpt-4o-mini",
+		ID:           GPT4oMini,
+		Name:         "GPT-4o Mini",
+		Provider:     ProviderOpenAI,
+		APIModel:     "gpt-4o-mini",
+		CostPer1MIn:  0.150,
+		CostPer1MOut: 0.600,
 	},
 	GPT45: {
 		ID:       GPT45,
@@ -172,10 +181,25 @@ var SupportedModels = map[ModelID]Model{
 		Provider: ProviderMeta,
 		APIModel: "llama-2-70b",
 	},
+
+	// GROQ
+	GroqLlama3SpecDec: {
+		ID:       GroqLlama3SpecDec,
+		Name:     "GROQ LLaMA 3 SpecDec",
+		Provider: ProviderGroq,
+		APIModel: "llama-3.3-70b-specdec",
+	},
+	GroqQwen32BCoder: {
+		ID:       GroqQwen32BCoder,
+		Name:     "GROQ Qwen 2.5 Coder 32B",
+		Provider: ProviderGroq,
+		APIModel: "qwen-2.5-coder-32b",
+	},
 }
 
 func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
 	provider := SupportedModels[model].Provider
+	log.Printf("Provider: %s", provider)
 	maxTokens := viper.GetInt("providers.common.max_tokens")
 	switch provider {
 	case ProviderOpenAI:
@@ -191,6 +215,14 @@ func GetModel(ctx context.Context, model ModelID) (model.ChatModel, error) {
 			MaxTokens: maxTokens,
 		})
 
+	case ProviderGroq:
+		return openai.NewChatModel(ctx, &openai.ChatModelConfig{
+			BaseURL:   "https://api.groq.com/openai/v1",
+			APIKey:    viper.GetString("providers.groq.key"),
+			Model:     string(SupportedModels[model].APIModel),
+			MaxTokens: &maxTokens,
+		})
+
 	}
 	return nil, errors.New("unsupported provider")
 }

internal/tui/components/repl/messages.go 🔗

@@ -1,22 +1,33 @@
 package repl
 
 import (
+	"github.com/charmbracelet/bubbles/key"
+	"github.com/charmbracelet/bubbles/viewport"
 	tea "github.com/charmbracelet/bubbletea"
 	"github.com/charmbracelet/lipgloss"
 	"github.com/kujtimiihoxha/termai/internal/app"
 	"github.com/kujtimiihoxha/termai/internal/message"
 	"github.com/kujtimiihoxha/termai/internal/pubsub"
 	"github.com/kujtimiihoxha/termai/internal/session"
+	"github.com/kujtimiihoxha/termai/internal/tui/layout"
 )
 
+type MessagesCmp interface {
+	tea.Model
+	layout.Focusable
+	layout.Bordered
+	layout.Sizeable
+	layout.Bindings
+}
+
 type messagesCmp struct {
 	app      *app.App
 	messages []message.Message
 	session  session.Session
-}
-
-func (m *messagesCmp) Init() tea.Cmd {
-	return nil
+	viewport viewport.Model
+	width    int
+	height   int
+	focused  bool
 }
 
 func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -25,6 +36,12 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		if msg.Type == pubsub.CreatedEvent {
 			m.messages = append(m.messages, msg.Payload)
 		}
+	case pubsub.Event[session.Session]:
+		if msg.Type == pubsub.UpdatedEvent {
+			if m.session.ID == msg.Payload.ID {
+				m.session = msg.Payload
+			}
+		}
 	case SelectedSessionMsg:
 		m.session, _ = m.app.Sessions.Get(msg.SessionID)
 		m.messages, _ = m.app.Messages.List(m.session.ID)
@@ -40,7 +57,55 @@ func (i *messagesCmp) View() string {
 	return lipgloss.JoinVertical(lipgloss.Top, stringMessages...)
 }
 
-func NewMessagesCmp(app *app.App) tea.Model {
+// BindingKeys implements MessagesCmp.
+func (m *messagesCmp) BindingKeys() []key.Binding {
+	return []key.Binding{}
+}
+
+// Blur implements MessagesCmp.
+func (m *messagesCmp) Blur() tea.Cmd {
+	m.focused = false
+	return nil
+}
+
+// BorderText implements MessagesCmp.
+func (m *messagesCmp) BorderText() map[layout.BorderPosition]string {
+	title := m.session.Title
+	if len(title) > 20 {
+		title = title[:20] + "..."
+	}
+	return map[layout.BorderPosition]string{
+		layout.TopLeftBorder: title,
+	}
+}
+
+// Focus implements MessagesCmp.
+func (m *messagesCmp) Focus() tea.Cmd {
+	m.focused = true
+	return nil
+}
+
+// GetSize implements MessagesCmp.
+func (m *messagesCmp) GetSize() (int, int) {
+	return m.width, m.height
+}
+
+// IsFocused implements MessagesCmp.
+func (m *messagesCmp) IsFocused() bool {
+	return m.focused
+}
+
+// SetSize implements MessagesCmp.
+func (m *messagesCmp) SetSize(width int, height int) {
+	m.width = width
+	m.height = height
+}
+
+func (m *messagesCmp) Init() tea.Cmd {
+	return nil
+}
+
+func NewMessagesCmp(app *app.App) MessagesCmp {
 	return &messagesCmp{
 		app:      app,
 		messages: []message.Message{},

internal/tui/components/repl/sessions.go 🔗

@@ -2,6 +2,7 @@ package repl
 
 import (
 	"fmt"
+	"strings"
 
 	"github.com/charmbracelet/bubbles/key"
 	"github.com/charmbracelet/bubbles/list"
@@ -82,7 +83,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			items[i] = listItem{
 				id:    s.ID,
 				title: s.Title,
-				desc:  fmt.Sprintf("Tokens: %d, Cost: %.2f", s.PromptTokens+s.CompletionTokens, s.Cost),
+				desc:  formatTokensAndCost(s.PromptTokens+s.CompletionTokens, s.Cost),
 			}
 		}
 		return i, i.list.SetItems(items)
@@ -94,7 +95,7 @@ func (i *sessionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 				s := item.(listItem)
 				if s.id == msg.Payload.ID {
 					s.title = msg.Payload.Title
-					s.desc = fmt.Sprintf("Tokens: %d, Cost: %.2f", msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
+					s.desc = formatTokensAndCost(msg.Payload.PromptTokens+msg.Payload.CompletionTokens, msg.Payload.Cost)
 					items[idx] = s
 					break
 				}
@@ -169,6 +170,32 @@ func (i *sessionsCmp) BindingKeys() []key.Binding {
 	return append(layout.KeyMapToSlice(i.list.KeyMap), sessionKeyMapValue.Select)
 }
 
+func formatTokensAndCost(tokens int64, cost float64) string {
+	// Format tokens in human-readable format (e.g., 110K, 1.2M)
+	var formattedTokens string
+	switch {
+	case tokens >= 1_000_000:
+		formattedTokens = fmt.Sprintf("%.1fM", float64(tokens)/1_000_000)
+	case tokens >= 1_000:
+		formattedTokens = fmt.Sprintf("%.1fK", float64(tokens)/1_000)
+	default:
+		formattedTokens = fmt.Sprintf("%d", tokens)
+	}
+
+	// Remove .0 suffix if present
+	if strings.HasSuffix(formattedTokens, ".0K") {
+		formattedTokens = strings.Replace(formattedTokens, ".0K", "K", 1)
+	}
+	if strings.HasSuffix(formattedTokens, ".0M") {
+		formattedTokens = strings.Replace(formattedTokens, ".0M", "M", 1)
+	}
+
+	// Format cost with $ symbol and 2 decimal places
+	formattedCost := fmt.Sprintf("$%.2f", cost)
+
+	return fmt.Sprintf("Tokens: %s, Cost: %s", formattedTokens, formattedCost)
+}
+
 func NewSessionsCmp(app *app.App) SessionsCmp {
 	listDelegate := list.NewDefaultDelegate()
 	defaultItemStyle := list.NewDefaultItemStyles()