Detailed changes
@@ -53,6 +53,7 @@ termai -d
### Keyboard Shortcuts
#### Global Shortcuts
+
- `?`: Toggle help panel
- `Ctrl+C` or `q`: Quit application
- `L`: View logs
@@ -60,10 +61,12 @@ termai -d
- `Esc`: Close current view/dialog or return to normal mode
#### Session Management
+
- `N`: Create new session
- `Enter` or `Space`: Select session (in sessions list)
#### Editor Shortcuts (Vim-like)
+
- `i`: Enter insert mode
- `Esc`: Enter normal mode
- `v`: Enter visual mode
@@ -72,6 +75,7 @@ termai -d
- `Ctrl+S`: Send message (in insert mode)
#### Navigation
+
- Arrow keys: Navigate through lists and content
- Page Up/Down: Scroll through content
@@ -112,16 +116,6 @@ go build -o termai
./termai
```
-### Important: Building the Diff Script
-
-Before building or running the application, you must first build the diff script by running:
-
-```bash
-go run cmd/diff/main.go
-```
-
-This command generates the necessary JavaScript file (`index.mjs`) used by the diff functionality in the application.
-
## Acknowledgments
TermAI builds upon the work of several open source projects and developers:
@@ -1,102 +0,0 @@
-package main
-
-import (
- "fmt"
- "io"
- "os"
- "os/exec"
- "path/filepath"
-)
-
-func main() {
- // Create a temporary directory
- tempDir, err := os.MkdirTemp("", "git-split-diffs")
- if err != nil {
- fmt.Printf("Error creating temp directory: %v\n", err)
- os.Exit(1)
- }
- defer func() {
- fmt.Printf("Cleaning up temporary directory: %s\n", tempDir)
- os.RemoveAll(tempDir)
- }()
- fmt.Printf("Created temporary directory: %s\n", tempDir)
-
- // Clone the repository with minimum depth
- fmt.Println("Cloning git-split-diffs repository with minimum depth...")
- cmd := exec.Command("git", "clone", "--depth=1", "https://github.com/kujtimiihoxha/git-split-diffs", tempDir)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- if err := cmd.Run(); err != nil {
- fmt.Printf("Error cloning repository: %v\n", err)
- os.Exit(1)
- }
-
- // Run npm install
- fmt.Println("Running npm install...")
- cmdNpmInstall := exec.Command("npm", "install")
- cmdNpmInstall.Dir = tempDir
- cmdNpmInstall.Stdout = os.Stdout
- cmdNpmInstall.Stderr = os.Stderr
- if err := cmdNpmInstall.Run(); err != nil {
- fmt.Printf("Error running npm install: %v\n", err)
- os.Exit(1)
- }
-
- // Run npm run build
- fmt.Println("Running npm run build...")
- cmdNpmBuild := exec.Command("npm", "run", "build")
- cmdNpmBuild.Dir = tempDir
- cmdNpmBuild.Stdout = os.Stdout
- cmdNpmBuild.Stderr = os.Stderr
- if err := cmdNpmBuild.Run(); err != nil {
- fmt.Printf("Error running npm run build: %v\n", err)
- os.Exit(1)
- }
-
- destDir := filepath.Join(".", "internal", "assets", "diff")
- destFile := filepath.Join(destDir, "index.mjs")
-
- // Make sure the destination directory exists
- if err := os.MkdirAll(destDir, 0o755); err != nil {
- fmt.Printf("Error creating destination directory: %v\n", err)
- os.Exit(1)
- }
-
- // Copy the file
- srcFile := filepath.Join(tempDir, "build", "index.mjs")
- fmt.Printf("Copying %s to %s\n", srcFile, destFile)
- if err := copyFile(srcFile, destFile); err != nil {
- fmt.Printf("Error copying file: %v\n", err)
- os.Exit(1)
- }
-
- fmt.Println("Successfully completed the process!")
-}
-
-// copyFile copies a file from src to dst
-func copyFile(src, dst string) error {
- sourceFile, err := os.Open(src)
- if err != nil {
- return err
- }
- defer sourceFile.Close()
-
- destFile, err := os.Create(dst)
- if err != nil {
- return err
- }
- defer destFile.Close()
-
- _, err = io.Copy(destFile, sourceFile)
- if err != nil {
- return err
- }
-
- // Make sure the file is written to disk
- err = destFile.Sync()
- if err != nil {
- return err
- }
-
- return nil
-}
@@ -9,7 +9,6 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/kujtimiihoxha/termai/internal/app"
- "github.com/kujtimiihoxha/termai/internal/assets"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/llm/agent"
@@ -52,11 +51,6 @@ var rootCmd = &cobra.Command{
return err
}
- err = assets.WriteAssets()
- if err != nil {
- logging.Error("Error writing assets: %v", err)
- }
-
// Connect DB, this will also run migrations
conn, err := db.Connect()
if err != nil {
@@ -67,7 +61,11 @@ var rootCmd = &cobra.Command{
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- app := app.New(ctx, conn)
+ app, err := app.New(ctx, conn)
+ if err != nil {
+ logging.Error("Failed to create app: %v", err)
+ return err
+ }
// Set up the TUI
zone.NewGlobal()
@@ -31,7 +31,6 @@ require (
github.com/muesli/reflow v0.3.0
github.com/muesli/termenv v0.16.0
github.com/openai/openai-go v0.1.0-beta.2
- github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3
github.com/spf13/cobra v1.9.1
github.com/spf13/viper v1.20.0
github.com/stretchr/testify v1.10.0
@@ -107,6 +106,7 @@ require (
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect
github.com/sahilm/fuzzy v0.1.1 // indirect
+ github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
github.com/skeema/knownhosts v1.3.1 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.12.0 // indirect
@@ -9,6 +9,7 @@ import (
"github.com/kujtimiihoxha/termai/internal/db"
"github.com/kujtimiihoxha/termai/internal/history"
+ "github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/message"
@@ -22,6 +23,8 @@ type App struct {
Files history.Service
Permissions permission.Service
+ CoderAgent agent.Service
+
LSPClients map[string]*lsp.Client
clientsMutex sync.RWMutex
@@ -31,7 +34,7 @@ type App struct {
watcherWG sync.WaitGroup
}
-func New(ctx context.Context, conn *sql.DB) *App {
+func New(ctx context.Context, conn *sql.DB) (*App, error) {
q := db.New(conn)
sessions := session.NewService(q)
messages := message.NewService(q)
@@ -45,9 +48,22 @@ func New(ctx context.Context, conn *sql.DB) *App {
LSPClients: make(map[string]*lsp.Client),
}
+ var err error
+ app.CoderAgent, err = agent.NewCoderAgent(
+
+ app.Permissions,
+ app.Sessions,
+ app.Messages,
+ app.LSPClients,
+ )
+ if err != nil {
+ logging.Error("Failed to create coder agent", err)
+ return nil, err
+ }
+
app.initLSPClients(ctx)
- return app
+ return app, nil
}
// Shutdown performs a clean shutdown of the application
@@ -1,73 +0,0 @@
-{
- "SYNTAX_HIGHLIGHTING_THEME": "dark-plus",
- "DEFAULT_COLOR": {
- "color": "#ffffff",
- "backgroundColor": "#212121"
- },
- "COMMIT_HEADER_COLOR": {
- "color": "#cccccc"
- },
- "COMMIT_HEADER_LABEL_COLOR": {
- "color": "#00000022"
- },
- "COMMIT_SHA_COLOR": {
- "color": "#00eeaa"
- },
- "COMMIT_AUTHOR_COLOR": {
- "color": "#00aaee"
- },
- "COMMIT_DATE_COLOR": {
- "color": "#cccccc"
- },
- "COMMIT_MESSAGE_COLOR": {
- "color": "#cccccc"
- },
- "COMMIT_TITLE_COLOR": {
- "modifiers": [
- "bold"
- ]
- },
- "FILE_NAME_COLOR": {
- "color": "#ffdd99"
- },
- "BORDER_COLOR": {
- "color": "#ffdd9966",
- "modifiers": [
- "dim"
- ]
- },
- "HUNK_HEADER_COLOR": {
- "modifiers": [
- "dim"
- ]
- },
- "DELETED_WORD_COLOR": {
- "color": "#ffcccc",
- "backgroundColor": "#ff000033"
- },
- "INSERTED_WORD_COLOR": {
- "color": "#ccffcc",
- "backgroundColor": "#00ff0033"
- },
- "DELETED_LINE_NO_COLOR": {
- "color": "#00000022",
- "backgroundColor": "#00000022"
- },
- "INSERTED_LINE_NO_COLOR": {
- "color": "#00000022",
- "backgroundColor": "#00000022"
- },
- "UNMODIFIED_LINE_NO_COLOR": {
- "color": "#666666"
- },
- "DELETED_LINE_COLOR": {
- "color": "#cc6666",
- "backgroundColor": "#3a3030"
- },
- "INSERTED_LINE_COLOR": {
- "color": "#66cc66",
- "backgroundColor": "#303a30"
- },
- "UNMODIFIED_LINE_COLOR": {},
- "MISSING_LINE_COLOR": {}
-}
@@ -1,6 +0,0 @@
-package assets
-
-import "embed"
-
-//go:embed diff
-var FS embed.FS
@@ -1,60 +0,0 @@
-package assets
-
-import (
- "os"
- "path/filepath"
-
- "github.com/kujtimiihoxha/termai/internal/config"
-)
-
-func WriteAssets() error {
- appCfg := config.Get()
- appWd := config.WorkingDirectory()
- scriptDir := filepath.Join(
- appWd,
- appCfg.Data.Directory,
- "diff",
- )
- scriptPath := filepath.Join(scriptDir, "index.mjs")
- // Before, run the script in cmd/diff/main.go to build this file
- if _, err := os.Stat(scriptPath); err != nil {
- scriptData, err := FS.ReadFile("diff/index.mjs")
- if err != nil {
- return err
- }
-
- err = os.MkdirAll(scriptDir, 0o755)
- if err != nil {
- return err
- }
- err = os.WriteFile(scriptPath, scriptData, 0o755)
- if err != nil {
- return err
- }
- }
-
- themeDir := filepath.Join(
- appWd,
- appCfg.Data.Directory,
- "themes",
- )
-
- themePath := filepath.Join(themeDir, "dark.json")
-
- if _, err := os.Stat(themePath); err != nil {
- themeData, err := FS.ReadFile("diff/themes/dark.json")
- if err != nil {
- return err
- }
-
- err = os.MkdirAll(themeDir, 0o755)
- if err != nil {
- return err
- }
- err = os.WriteFile(themePath, themeData, 0o755)
- if err != nil {
- return err
- }
- }
- return nil
-}
@@ -11,7 +11,6 @@ import (
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing/object"
- "github.com/kujtimiihoxha/termai/internal/config"
)
type DiffStats struct {
@@ -197,32 +196,32 @@ func isSplitDiffsAvailable() bool {
}
func formatWithSplitDiffs(diffText string, width int) (string, error) {
- var cmd *exec.Cmd
+ args := []string{
+ "--color",
+ }
- appCfg := config.Get()
- appWd := config.WorkingDirectory()
- script := filepath.Join(
- appWd,
- appCfg.Data.Directory,
- "diff",
- "index.mjs",
- )
+ var diffCmd *exec.Cmd
- cmd = exec.Command("node", script, "--color")
+ if _, err := exec.LookPath("git-split-diffs-opencode"); err == nil {
+ fullArgs := append([]string{"git-split-diffs-opencode"}, args...)
+ diffCmd = exec.Command(fullArgs[0], fullArgs[1:]...)
+ } else {
+ npxArgs := append([]string{"git-split-diffs-opencode"}, args...)
+ diffCmd = exec.Command("npx", npxArgs...)
+ }
- cmd.Env = append(os.Environ(), fmt.Sprintf("COLUMNS=%d", width))
+ diffCmd.Env = append(os.Environ(), fmt.Sprintf("DIFF_COLUMNS=%d", width))
- cmd.Stdin = strings.NewReader(diffText)
+ diffCmd.Stdin = strings.NewReader(diffText)
var out bytes.Buffer
- cmd.Stdout = &out
+ diffCmd.Stdout = &out
var stderr bytes.Buffer
- cmd.Stderr = &stderr
+ diffCmd.Stderr = &stderr
- err := cmd.Run()
- if err != nil {
- return "", fmt.Errorf("git-split-diffs error: %v, stderr: %s", err, stderr.String())
+ if err := diffCmd.Run(); err != nil {
+ return "", fmt.Errorf("git-split-diffs-opencode error: %w, stderr: %s", err, stderr.String())
}
return out.String(), nil
@@ -5,14 +5,16 @@ import (
"encoding/json"
"fmt"
- "github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
"github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/session"
)
type agentTool struct {
- parentSessionID string
- app *app.App
+ sessions session.Service
+ messages message.Service
+ lspClients map[string]*lsp.Client
}
const (
@@ -46,12 +48,17 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("prompt is required"), nil
}
- agent, err := NewTaskAgent(b.app)
+ sessionID, messageID := tools.GetContextValues(ctx)
+ if sessionID == "" || messageID == "" {
+ return tools.NewTextErrorResponse("session ID and message ID are required"), nil
+ }
+
+ agent, err := NewTaskAgent(b.lspClients)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil
}
- session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session")
+ session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
}
@@ -61,7 +68,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
}
- messages, err := b.app.Messages.List(ctx, session.ID)
+ messages, err := b.messages.List(ctx, session.ID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil
}
@@ -74,11 +81,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
return tools.NewTextErrorResponse("no assistant message found"), nil
}
- updatedSession, err := b.app.Sessions.Get(ctx, session.ID)
+ updatedSession, err := b.sessions.Get(ctx, session.ID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
- parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID)
+ parentSession, err := b.sessions.Get(ctx, sessionID)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
@@ -87,16 +94,19 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
parentSession.PromptTokens += updatedSession.PromptTokens
parentSession.CompletionTokens += updatedSession.CompletionTokens
- _, err = b.app.Sessions.Save(ctx, parentSession)
+ _, err = b.sessions.Save(ctx, parentSession)
if err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
}
return tools.NewTextResponse(response.Content().String()), nil
}
-func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool {
+func NewAgentTool(
+ Sessions session.Service,
+ Messages message.Service,
+) tools.BaseTool {
return &agentTool{
- parentSessionID: parentSessionID,
- app: app,
+ sessions: Sessions,
+ messages: Messages,
}
}
@@ -7,7 +7,6 @@ import (
"strings"
"sync"
- "github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/prompt"
@@ -15,22 +14,118 @@ import (
"github.com/kujtimiihoxha/termai/internal/llm/tools"
"github.com/kujtimiihoxha/termai/internal/logging"
"github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/session"
)
-type Agent interface {
+// Common errors
+var (
+ ErrProviderNotEnabled = errors.New("provider is not enabled")
+ ErrRequestCancelled = errors.New("request cancelled by user")
+ ErrSessionBusy = errors.New("session is currently processing another request")
+)
+
+// Service defines the interface for generating responses
+type Service interface {
Generate(ctx context.Context, sessionID string, content string) error
+ Cancel(sessionID string) error
}
type agent struct {
- *app.App
+ sessions session.Service
+ messages message.Service
model models.Model
tools []tools.BaseTool
agent provider.Provider
titleGenerator provider.Provider
+ activeRequests sync.Map // map[sessionID]context.CancelFunc
+}
+
+// NewAgent creates a new agent instance with the given model and tools
+func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
+ agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize providers: %w", err)
+ }
+
+ return &agent{
+ model: model,
+ tools: tools,
+ sessions: sessions,
+ messages: messages,
+ agent: agentProvider,
+ titleGenerator: titleGenerator,
+ activeRequests: sync.Map{},
+ }, nil
+}
+
+// Cancel cancels an active request by session ID
+func (a *agent) Cancel(sessionID string) error {
+ if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
+ if cancel, ok := cancelFunc.(context.CancelFunc); ok {
+ logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
+ cancel()
+ return nil
+ }
+ }
+ return errors.New("no active request found for this session")
}
-func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
- response, err := c.titleGenerator.SendMessages(
+// Generate starts the generation process
+func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
+ // Check if this session already has an active request
+ if _, busy := a.activeRequests.Load(sessionID); busy {
+ return ErrSessionBusy
+ }
+
+ // Create a cancellable context
+ genCtx, cancel := context.WithCancel(ctx)
+
+ // Store cancel function to allow user cancellation
+ a.activeRequests.Store(sessionID, cancel)
+
+ // Launch the generation in a goroutine
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
+ }
+ }()
+ defer a.activeRequests.Delete(sessionID)
+ defer cancel()
+
+ if err := a.generate(genCtx, sessionID, content); err != nil {
+ if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
+ // Log the error (avoid logging cancellations as they're expected)
+ logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
+
+ // You may want to create an error message in the chat
+ bgCtx := context.Background()
+ errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
+ _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
+ Role: message.System,
+ Parts: []message.ContentPart{
+ message.TextContent{
+ Text: errorMsg,
+ },
+ },
+ })
+ if createErr != nil {
+ logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
+ }
+ }
+ }
+ }()
+
+ return nil
+}
+
+// IsSessionBusy checks if a session currently has an active request
+func (a *agent) IsSessionBusy(sessionID string) bool {
+ _, busy := a.activeRequests.Load(sessionID)
+ return busy
+} // handleTitleGeneration asynchronously generates a title for new sessions
+func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
+ response, err := a.titleGenerator.SendMessages(
ctx,
[]message.Message{
{
@@ -45,25 +140,30 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
nil,
)
if err != nil {
+ logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
return
}
- session, err := c.Sessions.Get(ctx, sessionID)
+ session, err := a.sessions.Get(ctx, sessionID)
if err != nil {
+ logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
return
}
+
if response.Content != "" {
- session.Title = response.Content
- session.Title = strings.TrimSpace(session.Title)
+ session.Title = strings.TrimSpace(response.Content)
session.Title = strings.ReplaceAll(session.Title, "\n", " ")
- c.Sessions.Save(ctx, session)
+ if _, err := a.sessions.Save(ctx, session); err != nil {
+ logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
+ }
}
}
-func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
- session, err := c.Sessions.Get(ctx, sessionID)
+// TrackUsage updates token usage statistics for the session
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
+ session, err := a.sessions.Get(ctx, sessionID)
if err != nil {
- return err
+ return fmt.Errorf("failed to get session: %w", err)
}
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
@@ -75,189 +175,241 @@ func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
session.CompletionTokens += usage.OutputTokens
session.PromptTokens += usage.InputTokens
- _, err = c.Sessions.Save(ctx, session)
- return err
+ _, err = a.sessions.Save(ctx, session)
+ if err != nil {
+ return fmt.Errorf("failed to save session: %w", err)
+ }
+ return nil
}
-func (c *agent) processEvent(
+// processEvent handles different types of events during generation
+func (a *agent) processEvent(
ctx context.Context,
sessionID string,
assistantMsg *message.Message,
event provider.ProviderEvent,
) error {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ // Continue processing
+ }
+
switch event.Type {
case provider.EventThinkingDelta:
assistantMsg.AppendReasoningContent(event.Content)
- return c.Messages.Update(ctx, *assistantMsg)
+ return a.messages.Update(ctx, *assistantMsg)
case provider.EventContentDelta:
assistantMsg.AppendContent(event.Content)
- return c.Messages.Update(ctx, *assistantMsg)
+ return a.messages.Update(ctx, *assistantMsg)
case provider.EventError:
if errors.Is(event.Error, context.Canceled) {
- return nil
+ logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
+ return context.Canceled
}
logging.ErrorPersist(event.Error.Error())
return event.Error
case provider.EventWarning:
logging.WarnPersist(event.Info)
- return nil
case provider.EventInfo:
logging.InfoPersist(event.Info)
case provider.EventComplete:
assistantMsg.SetToolCalls(event.Response.ToolCalls)
assistantMsg.AddFinish(event.Response.FinishReason)
- err := c.Messages.Update(ctx, *assistantMsg)
- if err != nil {
- return err
+ if err := a.messages.Update(ctx, *assistantMsg); err != nil {
+ return fmt.Errorf("failed to update message: %w", err)
}
- return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage)
+ return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
}
return nil
}
-func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
- var wg sync.WaitGroup
+// ExecuteTools runs all tool calls sequentially and returns the results
+func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
toolResults := make([]message.ToolResult, len(toolCalls))
- mutex := &sync.Mutex{}
- errChan := make(chan error, 1)
// Create a child context that can be canceled
ctx, cancel := context.WithCancel(ctx)
defer cancel()
- for i, tc := range toolCalls {
- wg.Add(1)
- go func(index int, toolCall message.ToolCall) {
- defer wg.Done()
+ // Check if already canceled before starting any execution
+ if ctx.Err() != nil {
+ // Mark all tools as canceled
+ for i, toolCall := range toolCalls {
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: "Tool execution canceled by user",
+ IsError: true,
+ }
+ }
+ return toolResults, ctx.Err()
+ }
- // Check if context is already canceled
- select {
- case <-ctx.Done():
- mutex.Lock()
- toolResults[index] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: "Tool execution canceled",
+ for i, toolCall := range toolCalls {
+ // Check for cancellation before executing each tool
+ select {
+ case <-ctx.Done():
+ // Mark this and all remaining tools as canceled
+ for j := i; j < len(toolCalls); j++ {
+ toolResults[j] = message.ToolResult{
+ ToolCallID: toolCalls[j].ID,
+ Content: "Tool execution canceled by user",
IsError: true,
}
- mutex.Unlock()
-
- // Send cancellation error to error channel if it's empty
- select {
- case errChan <- ctx.Err():
- default:
- }
- return
- default:
}
+ return toolResults, ctx.Err()
+ default:
+ // Continue processing
+ }
- response := ""
- isError := false
- found := false
-
- for _, tool := range tls {
- if tool.Info().Name == toolCall.Name {
- found = true
- toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
- ID: toolCall.ID,
- Name: toolCall.Name,
- Input: toolCall.Input,
- })
-
- if toolErr != nil {
- if errors.Is(toolErr, context.Canceled) {
- response = "Tool execution canceled"
-
- // Send cancellation error to error channel if it's empty
- select {
- case errChan <- ctx.Err():
- default:
- }
- } else {
- response = fmt.Sprintf("error running tool: %s", toolErr)
- }
- isError = true
+ response := ""
+ isError := false
+ found := false
+
+ // Find and execute the appropriate tool
+ for _, tool := range tls {
+ if tool.Info().Name == toolCall.Name {
+ found = true
+ toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
+ ID: toolCall.ID,
+ Name: toolCall.Name,
+ Input: toolCall.Input,
+ })
+
+ if toolErr != nil {
+ if errors.Is(toolErr, context.Canceled) {
+ response = "Tool execution canceled by user"
} else {
- response = toolResult.Content
- isError = toolResult.IsError
+ response = fmt.Sprintf("Error running tool: %s", toolErr)
}
- break
+ isError = true
+ } else {
+ response = toolResult.Content
+ isError = toolResult.IsError
}
+ break
}
+ }
- if !found {
- response = fmt.Sprintf("tool not found: %s", toolCall.Name)
- isError = true
- }
-
- mutex.Lock()
- defer mutex.Unlock()
-
- toolResults[index] = message.ToolResult{
- ToolCallID: toolCall.ID,
- Content: response,
- IsError: isError,
- }
- }(i, tc)
- }
-
- // Wait for all goroutines to finish or context to be canceled
- done := make(chan struct{})
- go func() {
- wg.Wait()
- close(done)
- }()
+ if !found {
+ response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
+ isError = true
+ }
- select {
- case <-done:
- // All tools completed successfully
- case err := <-errChan:
- // One of the tools encountered a cancellation
- return toolResults, err
- case <-ctx.Done():
- // Context was canceled externally
- return toolResults, ctx.Err()
+ toolResults[i] = message.ToolResult{
+ ToolCallID: toolCall.ID,
+ Content: response,
+ IsError: isError,
+ }
}
return toolResults, nil
}
-func (c *agent) handleToolExecution(
+// handleToolExecution processes tool calls and creates tool result messages
+func (a *agent) handleToolExecution(
ctx context.Context,
assistantMsg message.Message,
) (*message.Message, error) {
+ select {
+ case <-ctx.Done():
+ // If cancelled, create tool results that indicate cancellation
+ if len(assistantMsg.ToolCalls()) > 0 {
+ toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
+ for _, tc := range assistantMsg.ToolCalls() {
+ toolResults = append(toolResults, message.ToolResult{
+ ToolCallID: tc.ID,
+ Content: "Tool execution canceled by user",
+ IsError: true,
+ })
+ }
+
+ // Use background context to ensure the message is created even if original context is cancelled
+ bgCtx := context.Background()
+ parts := make([]message.ContentPart, 0)
+ for _, toolResult := range toolResults {
+ parts = append(parts, toolResult)
+ }
+ msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
+ Role: message.Tool,
+ Parts: parts,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
+ }
+ return &msg, ctx.Err()
+ }
+ return nil, ctx.Err()
+ default:
+ // Continue processing
+ }
+
if len(assistantMsg.ToolCalls()) == 0 {
return nil, nil
}
- toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
+ toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
if err != nil {
+ // If error is from cancellation, still return the partial results we have
+ if errors.Is(err, context.Canceled) {
+ // Use background context to ensure the message is created even if original context is cancelled
+ bgCtx := context.Background()
+ parts := make([]message.ContentPart, 0)
+ for _, toolResult := range toolResults {
+ parts = append(parts, toolResult)
+ }
+
+ msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
+ Role: message.Tool,
+ Parts: parts,
+ })
+ if createErr != nil {
+ logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
+ return nil, err
+ }
+ return &msg, err
+ }
return nil, err
}
- parts := make([]message.ContentPart, 0)
+
+ parts := make([]message.ContentPart, 0, len(toolResults))
for _, toolResult := range toolResults {
parts = append(parts, toolResult)
}
- msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
+
+ msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create tool message: %w", err)
+ }
- return &msg, err
+ return &msg, nil
}
-func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
+// generate handles the main generation workflow
+func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
- messages, err := c.Messages.List(ctx, sessionID)
+
+ // Handle context cancellation at any point
+ if err := ctx.Err(); err != nil {
+ return ErrRequestCancelled
+ }
+
+ messages, err := a.messages.List(ctx, sessionID)
if err != nil {
- return err
+ return fmt.Errorf("failed to list messages: %w", err)
}
if len(messages) == 0 {
- go c.handleTitleGeneration(ctx, sessionID, content)
+ titleCtx := context.Background()
+ go a.handleTitleGeneration(titleCtx, sessionID, content)
}
- userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
+ userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.User,
Parts: []message.ContentPart{
message.TextContent{
@@ -266,133 +418,125 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
},
})
if err != nil {
- return err
+ return fmt.Errorf("failed to create user message: %w", err)
}
messages = append(messages, userMsg)
+
for {
+ // Check for cancellation before each iteration
select {
case <-ctx.Done():
- assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- })
- if err != nil {
- return err
- }
- assistantMsg.AddFinish("canceled")
- c.Messages.Update(ctx, assistantMsg)
- return context.Canceled
+ return ErrRequestCancelled
default:
// Continue processing
}
- eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
+ eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
if err != nil {
if errors.Is(err, context.Canceled) {
- assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
- Role: message.Assistant,
- Parts: []message.ContentPart{},
- })
- if err != nil {
- return err
- }
- assistantMsg.AddFinish("canceled")
- c.Messages.Update(ctx, assistantMsg)
- return context.Canceled
+ return ErrRequestCancelled
}
- return err
+ return fmt.Errorf("failed to stream response: %w", err)
}
- assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
+ assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{},
- Model: c.model.ID,
+ Model: a.model.ID,
})
if err != nil {
- return err
+ return fmt.Errorf("failed to create assistant message: %w", err)
}
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
+
+ // Process events from the LLM provider
for event := range eventChan {
- err = c.processEvent(ctx, sessionID, &assistantMsg, event)
- if err != nil {
+ if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
if errors.Is(err, context.Canceled) {
+ // Mark as canceled but don't create separate message
assistantMsg.AddFinish("canceled")
- c.Messages.Update(ctx, assistantMsg)
- return context.Canceled
+ _ = a.messages.Update(context.Background(), assistantMsg)
+ return ErrRequestCancelled
}
assistantMsg.AddFinish("error:" + err.Error())
- c.Messages.Update(ctx, assistantMsg)
- return err
+ _ = a.messages.Update(ctx, assistantMsg)
+ return fmt.Errorf("event processing error: %w", err)
}
+ // Check for cancellation during event processing
select {
case <-ctx.Done():
+ // Mark as canceled
assistantMsg.AddFinish("canceled")
- c.Messages.Update(ctx, assistantMsg)
- return context.Canceled
+ _ = a.messages.Update(context.Background(), assistantMsg)
+ return ErrRequestCancelled
default:
}
}
- // Check for context cancellation before tool execution
+ // Check for cancellation before tool execution
select {
case <-ctx.Done():
- assistantMsg.AddFinish("canceled")
- c.Messages.Update(ctx, assistantMsg)
- return context.Canceled
+ assistantMsg.AddFinish("canceled_by_user")
+ _ = a.messages.Update(context.Background(), assistantMsg)
+ return ErrRequestCancelled
default:
- // Continue processing
}
- msg, err := c.handleToolExecution(ctx, assistantMsg)
+ // Execute any tool calls
+ toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
if err != nil {
if errors.Is(err, context.Canceled) {
- assistantMsg.AddFinish("canceled")
- c.Messages.Update(ctx, assistantMsg)
- return context.Canceled
+ assistantMsg.AddFinish("canceled_by_user")
+ _ = a.messages.Update(context.Background(), assistantMsg)
+ return ErrRequestCancelled
}
- return err
+ return fmt.Errorf("tool execution error: %w", err)
}
- c.Messages.Update(ctx, assistantMsg)
+ if err := a.messages.Update(ctx, assistantMsg); err != nil {
+ return fmt.Errorf("failed to update assistant message: %w", err)
+ }
+ // If no tool calls, we're done
if len(assistantMsg.ToolCalls()) == 0 {
break
}
+ // Add messages for next iteration
messages = append(messages, assistantMsg)
- if msg != nil {
- messages = append(messages, *msg)
+ if toolMsg != nil {
+ messages = append(messages, *toolMsg)
}
- // Check for context cancellation after tool execution
+ // Check for cancellation after tool execution
select {
case <-ctx.Done():
- assistantMsg.AddFinish("canceled")
- c.Messages.Update(ctx, assistantMsg)
- return context.Canceled
+ return ErrRequestCancelled
default:
- // Continue processing
}
}
+
return nil
}
+// getAgentProviders initializes the LLM providers based on the chosen model
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
maxTokens := config.Get().Model.CoderMaxTokens
providerConfig, ok := config.Get().Providers[model.Provider]
if !ok || providerConfig.Disabled {
- return nil, nil, errors.New("provider is not enabled")
+ return nil, nil, ErrProviderNotEnabled
}
+
var agentProvider provider.Provider
var titleGenerator provider.Provider
+ var err error
switch model.Provider {
case models.ProviderOpenAI:
- var err error
agentProvider, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.CoderOpenAISystemPrompt(),
@@ -402,8 +546,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithOpenAIKey(providerConfig.APIKey),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
}
+
titleGenerator, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.TitlePrompt(),
@@ -413,10 +558,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithOpenAIKey(providerConfig.APIKey),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
}
+
case models.ProviderAnthropic:
- var err error
agentProvider, err = provider.NewAnthropicProvider(
provider.WithAnthropicSystemMessage(
prompt.CoderAnthropicSystemPrompt(),
@@ -426,8 +571,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithAnthropicModel(model),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
}
+
titleGenerator, err = provider.NewAnthropicProvider(
provider.WithAnthropicSystemMessage(
prompt.TitlePrompt(),
@@ -437,11 +583,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithAnthropicModel(model),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
}
case models.ProviderGemini:
- var err error
agentProvider, err = provider.NewGeminiProvider(
ctx,
provider.WithGeminiSystemMessage(
@@ -452,8 +597,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithGeminiModel(model),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
}
+
titleGenerator, err = provider.NewGeminiProvider(
ctx,
provider.WithGeminiSystemMessage(
@@ -464,10 +610,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithGeminiModel(model),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
}
+
case models.ProviderGROQ:
- var err error
agentProvider, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.CoderAnthropicSystemPrompt(),
@@ -478,8 +624,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
}
+
titleGenerator, err = provider.NewOpenAIProvider(
provider.WithOpenAISystemMessage(
prompt.TitlePrompt(),
@@ -490,11 +637,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
}
case models.ProviderBedrock:
- var err error
agentProvider, err = provider.NewBedrockProvider(
provider.WithBedrockSystemMessage(
prompt.CoderAnthropicSystemPrompt(),
@@ -503,19 +649,21 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
provider.WithBedrockModel(model),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
}
+
titleGenerator, err = provider.NewBedrockProvider(
provider.WithBedrockSystemMessage(
prompt.TitlePrompt(),
),
- provider.WithBedrockMaxTokens(maxTokens),
+ provider.WithBedrockMaxTokens(80),
provider.WithBedrockModel(model),
)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
}
-
+ default:
+ return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
}
return agentProvider, titleGenerator, nil
@@ -4,71 +4,60 @@ import (
"context"
"errors"
- "github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
+ "github.com/kujtimiihoxha/termai/internal/message"
+ "github.com/kujtimiihoxha/termai/internal/permission"
+ "github.com/kujtimiihoxha/termai/internal/session"
)
type coderAgent struct {
- *agent
+ Service
}
-func (c *coderAgent) setAgentTool(sessionID string) {
- inx := -1
- for i, tool := range c.tools {
- if tool.Info().Name == AgentToolName {
- inx = i
- break
- }
- }
- if inx == -1 {
- c.tools = append(c.tools, NewAgentTool(sessionID, c.App))
- } else {
- c.tools[inx] = NewAgentTool(sessionID, c.App)
- }
-}
-
-func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error {
- c.setAgentTool(sessionID)
- return c.generate(ctx, sessionID, content)
-}
-
-func NewCoderAgent(app *app.App) (Agent, error) {
+func NewCoderAgent(
+ permissions permission.Service,
+ sessions session.Service,
+ messages message.Service,
+ lspClients map[string]*lsp.Client,
+) (Service, error) {
model, ok := models.SupportedModels[config.Get().Model.Coder]
if !ok {
return nil, errors.New("model not supported")
}
ctx := context.Background()
- agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
+ otherTools := GetMcpTools(ctx, permissions)
+ if len(lspClients) > 0 {
+ otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
+ }
+ agent, err := NewAgent(
+ ctx,
+ sessions,
+ messages,
+ model,
+ append(
+ []tools.BaseTool{
+ tools.NewBashTool(permissions),
+ tools.NewEditTool(lspClients, permissions),
+ tools.NewFetchTool(permissions),
+ tools.NewGlobTool(),
+ tools.NewGrepTool(),
+ tools.NewLsTool(),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients),
+ tools.NewWriteTool(lspClients, permissions),
+ NewAgentTool(sessions, messages),
+ }, otherTools...,
+ ),
+ )
if err != nil {
return nil, err
}
- otherTools := GetMcpTools(ctx, app.Permissions)
- if len(app.LSPClients) > 0 {
- otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients))
- }
return &coderAgent{
- agent: &agent{
- App: app,
- tools: append(
- []tools.BaseTool{
- tools.NewBashTool(app.Permissions),
- tools.NewEditTool(app.LSPClients, app.Permissions),
- tools.NewFetchTool(app.Permissions),
- tools.NewGlobTool(),
- tools.NewGrepTool(),
- tools.NewLsTool(),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(app.LSPClients),
- tools.NewWriteTool(app.LSPClients, app.Permissions),
- }, otherTools...,
- ),
- model: model,
- agent: agentProvider,
- titleGenerator: titleGenerator,
- },
+ agent,
}, nil
}
@@ -4,10 +4,10 @@ import (
"context"
"errors"
- "github.com/kujtimiihoxha/termai/internal/app"
"github.com/kujtimiihoxha/termai/internal/config"
"github.com/kujtimiihoxha/termai/internal/llm/models"
"github.com/kujtimiihoxha/termai/internal/llm/tools"
+ "github.com/kujtimiihoxha/termai/internal/lsp"
)
type taskAgent struct {
@@ -18,7 +18,7 @@ func (c *taskAgent) Generate(ctx context.Context, sessionID string, content stri
return c.generate(ctx, sessionID, content)
}
-func NewTaskAgent(app *app.App) (Agent, error) {
+func NewTaskAgent(lspClients map[string]*lsp.Client) (Service, error) {
model, ok := models.SupportedModels[config.Get().Model.Coder]
if !ok {
return nil, errors.New("model not supported")
@@ -31,13 +31,12 @@ func NewTaskAgent(app *app.App) (Agent, error) {
}
return &taskAgent{
agent: &agent{
- App: app,
tools: []tools.BaseTool{
tools.NewGlobTool(),
tools.NewGrepTool(),
tools.NewLsTool(),
tools.NewSourcegraphTool(),
- tools.NewViewTool(app.LSPClients),
+ tools.NewViewTool(lspClients),
},
model: model,
agent: agentProvider,
@@ -57,7 +57,9 @@ func cleanupMessages(messages []message.Message) []message.Message {
// First pass: filter out canceled messages
var cleanedMessages []message.Message
for _, msg := range messages {
- if msg.FinishReason() != "canceled" {
+ if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 {
+ // if there are toolCalls this means we want to return it to the LLM telling it that those tools have been
+ // cancelled
cleanedMessages = append(cleanedMessages, msg)
}
}
@@ -190,7 +190,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
return er, fmt.Errorf("failed to create parent directories: %w", err)
}
- sessionID, messageID := getContextValues(ctx)
+ sessionID, messageID := GetContextValues(ctx)
if sessionID == "" || messageID == "" {
return er, fmt.Errorf("session ID and message ID are required for creating a new file")
}
@@ -277,7 +277,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
newContent := oldContent[:index] + oldContent[index+len(oldString):]
- sessionID, messageID := getContextValues(ctx)
+ sessionID, messageID := GetContextValues(ctx)
if sessionID == "" || messageID == "" {
return er, fmt.Errorf("session ID and message ID are required for creating a new file")
@@ -365,7 +365,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
- sessionID, messageID := getContextValues(ctx)
+ sessionID, messageID := GetContextValues(ctx)
if sessionID == "" || messageID == "" {
return er, fmt.Errorf("session ID and message ID are required for creating a new file")
@@ -409,4 +409,3 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
return er, nil
}
-
@@ -66,7 +66,7 @@ type BaseTool interface {
Run(ctx context.Context, params ToolCall) (ToolResponse, error)
}
-func getContextValues(ctx context.Context) (string, string) {
+func GetContextValues(ctx context.Context) (string, string) {
sessionID := ctx.Value(SessionIDContextKey)
messageID := ctx.Value(MessageIDContextKey)
if sessionID == nil {
@@ -144,7 +144,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
}
}
- sessionID, messageID := getContextValues(ctx)
+ sessionID, messageID := GetContextValues(ctx)
if sessionID == "" || messageID == "" {
return NewTextErrorResponse("session ID or message ID is missing"), nil
}
@@ -7,7 +7,6 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/kujtimiihoxha/termai/internal/app"
- "github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
"github.com/kujtimiihoxha/termai/internal/tui/styles"
"github.com/kujtimiihoxha/termai/internal/tui/util"
@@ -168,11 +167,6 @@ func (m *editorCmp) Send() tea.Cmd {
return util.ReportWarn("Assistant is still working on the previous message")
}
- a, err := agent.NewCoderAgent(m.app)
- if err != nil {
- return util.ReportError(err)
- }
-
content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
if len(content) == 0 {
return util.ReportWarn("Message is empty")
@@ -181,7 +175,7 @@ func (m *editorCmp) Send() tea.Cmd {
m.cancelMessage = cancel
go func() {
defer cancel()
- a.Generate(ctx, m.sessionID, content)
+ m.app.CoderAgent.Generate(ctx, m.sessionID, content)
m.cancelMessage = nil
}()
@@ -6,7 +6,6 @@ import (
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/kujtimiihoxha/termai/internal/app"
- "github.com/kujtimiihoxha/termai/internal/llm/agent"
"github.com/kujtimiihoxha/termai/internal/session"
"github.com/kujtimiihoxha/termai/internal/tui/components/chat"
"github.com/kujtimiihoxha/termai/internal/tui/layout"
@@ -23,6 +22,7 @@ type chatPage struct {
type ChatKeyMap struct {
NewSession key.Binding
+ Cancel key.Binding
}
var keyMap = ChatKeyMap{
@@ -30,6 +30,10 @@ var keyMap = ChatKeyMap{
key.WithKeys("ctrl+n"),
key.WithHelp("ctrl+n", "new session"),
),
+ Cancel: key.NewBinding(
+ key.WithKeys("ctrl+x"),
+ key.WithHelp("ctrl+x", "cancel"),
+ ),
}
func (p *chatPage) Init() tea.Cmd {
@@ -106,15 +110,8 @@ func (p *chatPage) sendMessage(text string) tea.Cmd {
}
cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session)))
}
- // TODO: move this to a service
- a, err := agent.NewCoderAgent(p.app)
- if err != nil {
- return util.ReportError(err)
- }
- go func() {
- a.Generate(context.Background(), p.session.ID, text)
- }()
+ p.app.CoderAgent.Generate(context.Background(), p.session.ID, text)
return tea.Batch(cmds...)
}