diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f25cbdc7849c9f9f3d55e34206faaca82834960c..53c17f8b28d0d4de6be03873acb797c598ff8ed0 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -423,7 +423,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy return getSessionErr } a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata)) - _, sessionErr := a.sessions.Save(ctx, updatedSession) + _, sessionErr := a.sessions.SaveWithModels(ctx, updatedSession, map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg, + config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg, + }) if sessionErr != nil { return sessionErr } @@ -723,7 +726,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan currentSession.SummaryMessageID = summaryMessage.ID currentSession.CompletionTokens = usage.OutputTokens currentSession.PromptTokens = 0 - _, err = a.sessions.Save(genCtx, currentSession) + _, err = a.sessions.SaveWithModels(genCtx, currentSession, map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg, + config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg, + }) return err } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 1130438a61217b3491bd21e388376f303631e9ec..60f65c8d4514f1a69e99a8466b920d7021d164af 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -476,7 +476,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep), tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls), tools.NewSourcegraphTool(nil), - tools.NewTodosTool(c.sessions), + tools.NewTodosTool(c.sessions, c.cfg), tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.skillTracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...), tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), ) @@ -1051,7 +1051,7 @@ func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionI parentSession.Cost += childSession.Cost - if _, err := c.sessions.Save(ctx, parentSession); err != nil { + if _, err := c.sessions.SaveWithModels(ctx, parentSession, c.cfg.Config().Models); err != nil { return fmt.Errorf("save parent session: %w", err) } diff --git a/internal/agent/tools/todos.go b/internal/agent/tools/todos.go index 2f69f7bf84581d9f0e0776d73660bef7ba34ba43..ea6dca52d54814a6e0e83b71f6c7d77b8c5502db 100644 --- a/internal/agent/tools/todos.go +++ b/internal/agent/tools/todos.go @@ -6,6 +6,7 @@ import ( "fmt" "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/session" ) @@ -33,7 +34,7 @@ type TodosResponseMetadata struct { Total int `json:"total"` } -func NewTodosTool(sessions session.Service) fantasy.AgentTool { +func NewTodosTool(sessions session.Service, cfg *config.ConfigStore) fantasy.AgentTool { return fantasy.NewAgentTool( TodosToolName, FirstLineDescription(todosDescription), @@ -96,7 +97,7 @@ func NewTodosTool(sessions session.Service) fantasy.AgentTool { } currentSession.Todos = todos - _, err = sessions.Save(ctx, currentSession) + _, err = sessions.SaveWithModels(ctx, currentSession, cfg.Config().Models) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("failed to save todos: %w", err) } diff --git a/internal/app/resolve_session_test.go b/internal/app/resolve_session_test.go index 9b0c7af736fa9637c095c7851da3460bacf737a2..7bc79063b71401c4b4886eb981a58aceb75ac921 100644 --- a/internal/app/resolve_session_test.go +++ b/internal/app/resolve_session_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/stretchr/testify/require" @@ -60,6 +61,14 @@ func (m *mockSessionService) Save(_ context.Context, s session.Session) (session return s, nil } +func (m *mockSessionService) SaveWithModels(_ context.Context, s session.Session, _ map[config.SelectedModelType]config.SelectedModel) (session.Session, error) { + return s, nil +} + +func (m *mockSessionService) UpdateSessionModels(context.Context, string, map[config.SelectedModelType]config.SelectedModel) error { + return nil +} + func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error { return nil } diff --git a/internal/backend/session.go b/internal/backend/session.go index 10e21ed8932ccbc990a525785166517cd231595c..6465d9eedfeaf20b52a9460f8ac25a4aae91ed54 100644 --- a/internal/backend/session.go +++ b/internal/backend/session.go @@ -3,6 +3,7 @@ package backend import ( "context" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/session" @@ -105,6 +106,15 @@ func (b *Backend) DeleteSession(ctx context.Context, workspaceID, sessionID stri return ws.Sessions.Delete(ctx, sessionID) } +// UpdateSessionModels updates the models for a session in the given workspace. +func (b *Backend) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + ws, err := b.GetWorkspace(workspaceID) + if err != nil { + return err + } + return ws.Sessions.UpdateSessionModels(ctx, sessionID, models) +} + // ListUserMessages returns user-role messages for a session. func (b *Backend) ListUserMessages(ctx context.Context, workspaceID, sessionID string) ([]message.Message, error) { ws, err := b.GetWorkspace(workspaceID) diff --git a/internal/client/proto.go b/internal/client/proto.go index f444cfc04f1e185a4551d6ac43eae4d99f3a02ba..ab125d7b1091ea68fb250aa1389c2d881706904e 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -585,6 +585,23 @@ func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string) return nil } +// UpdateSessionModels updates the models for a session. +func (c *Client) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/models", workspaceID, sessionID), nil, + jsonBody(struct { + Models map[config.SelectedModelType]config.SelectedModel `json:"models"` + }{Models: models}), + http.Header{"Content-Type": []string{"application/json"}}) + if err != nil { + return fmt.Errorf("failed to update session models: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusNoContent { + return fmt.Errorf("failed to update session models: status code %d", rsp.StatusCode) + } + return nil +} + // ListUserMessages retrieves user-role messages for a session as proto types. func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil) diff --git a/internal/config/config.go b/internal/config/config.go index cee8ab8c4964bce56aa7c3ddffe98af115498776..e010a81b91ebb81ebb139372f1b8e3394cb72427 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -89,6 +89,27 @@ type SelectedModel struct { ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for the model"` } +func (a SelectedModel) Equal(b SelectedModel) bool { + return a.Model == b.Model && + a.Provider == b.Provider && + a.ReasoningEffort == b.ReasoningEffort && + a.Think == b.Think && + a.MaxTokens == b.MaxTokens && + ptrEqual(a.Temperature, b.Temperature) && + ptrEqual(a.TopP, b.TopP) && + ptrEqual(a.TopK, b.TopK) && + ptrEqual(a.FrequencyPenalty, b.FrequencyPenalty) && + ptrEqual(a.PresencePenalty, b.PresencePenalty) && + maps.Equal(a.ProviderOptions, b.ProviderOptions) +} + +func ptrEqual[T comparable](a, b *T) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + type ProviderConfig struct { // The provider's id. ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"` diff --git a/internal/config/selected_model_test.go b/internal/config/selected_model_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1a96b92e876043caef8dfe101ddcf8cab347ecd8 --- /dev/null +++ b/internal/config/selected_model_test.go @@ -0,0 +1,143 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSelectedModelEqual(t *testing.T) { + t.Parallel() + + t.Run("equal models", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.True(t, a.Equal(b)) + }) + + t.Run("different model", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o-mini", Provider: "openai"} + require.False(t, a.Equal(b)) + }) + + t.Run("different provider", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "anthropic"} + require.False(t, a.Equal(b)) + }) + + t.Run("different reasoning effort", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "high"} + b := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "low"} + require.False(t, a.Equal(b)) + }) + + t.Run("different think", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: true} + b := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: false} + require.False(t, a.Equal(b)) + }) + + t.Run("different max tokens", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 4096} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 8192} + require.False(t, a.Equal(b)) + }) + + t.Run("both nil pointers", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.True(t, a.Equal(b)) + }) + + t.Run("one nil one non-nil pointer", func(t *testing.T) { + t.Parallel() + temp := 0.7 + a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.False(t, a.Equal(b)) + require.False(t, b.Equal(a)) + }) + + t.Run("both non-nil equal pointers", func(t *testing.T) { + t.Parallel() + temp := 0.7 + a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp} + require.True(t, a.Equal(b)) + }) + + t.Run("both non-nil different pointers", func(t *testing.T) { + t.Parallel() + t1 := 0.7 + t2 := 0.9 + a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t1} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t2} + require.False(t, a.Equal(b)) + }) + + t.Run("nil ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai"} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.True(t, a.Equal(b)) + }) + + t.Run("empty ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}} + require.True(t, a.Equal(b)) + }) + + t.Run("different ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}} + b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "b"}} + require.False(t, a.Equal(b)) + }) + + t.Run("one nil one non-nil ProviderOptions", func(t *testing.T) { + t.Parallel() + a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}} + b := SelectedModel{Model: "gpt-4o", Provider: "openai"} + require.False(t, a.Equal(b)) + }) +} + +func TestPtrEqual(t *testing.T) { + t.Parallel() + + t.Run("both nil", func(t *testing.T) { + t.Parallel() + require.True(t, ptrEqual[int](nil, nil)) + }) + + t.Run("one nil", func(t *testing.T) { + t.Parallel() + v := 42 + require.False(t, ptrEqual(&v, nil)) + require.False(t, ptrEqual(nil, &v)) + }) + + t.Run("both equal", func(t *testing.T) { + t.Parallel() + v := 42 + require.True(t, ptrEqual(&v, &v)) + }) + + t.Run("both different", func(t *testing.T) { + t.Parallel() + a := 42 + b := 43 + require.False(t, ptrEqual(&a, &b)) + }) +} diff --git a/internal/db/db.go b/internal/db/db.go index 6237bba5892b240ddd3c3f018926e2eb0ef4355a..45e84f6e2246109ee1869a27b473db7b193c3f74 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -11,10 +11,10 @@ import ( ) type DBTX interface { - ExecContext(context.Context, string, ...any) (sql.Result, error) + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...any) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...any) *sql.Row + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row } func New(db DBTX) *Queries { @@ -132,6 +132,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil { return nil, fmt.Errorf("error preparing query UpdateSession: %w", err) } + if q.updateSessionModelsStmt, err = db.PrepareContext(ctx, updateSessionModels); err != nil { + return nil, fmt.Errorf("error preparing query UpdateSessionModels: %w", err) + } if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil { return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err) } @@ -320,6 +323,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing updateSessionStmt: %w", cerr) } } + if q.updateSessionModelsStmt != nil { + if cerr := q.updateSessionModelsStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing updateSessionModelsStmt: %w", cerr) + } + } if q.updateSessionTitleAndUsageStmt != nil { if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr) @@ -400,6 +408,7 @@ type Queries struct { renameSessionStmt *sql.Stmt updateMessageStmt *sql.Stmt updateSessionStmt *sql.Stmt + updateSessionModelsStmt *sql.Stmt updateSessionTitleAndUsageStmt *sql.Stmt } @@ -443,6 +452,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { renameSessionStmt: q.renameSessionStmt, updateMessageStmt: q.updateMessageStmt, updateSessionStmt: q.updateSessionStmt, + updateSessionModelsStmt: q.updateSessionModelsStmt, updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt, } } diff --git a/internal/db/migrations/20260401000000_add_models_to_sessions.sql b/internal/db/migrations/20260401000000_add_models_to_sessions.sql new file mode 100644 index 0000000000000000000000000000000000000000..639989eae3f6e4ef04791ab6aad422fbcc84b77d --- /dev/null +++ b/internal/db/migrations/20260401000000_add_models_to_sessions.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE sessions ADD COLUMN models TEXT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE sessions DROP COLUMN models; +-- +goose StatementEnd \ No newline at end of file diff --git a/internal/db/models.go b/internal/db/models.go index 20034fb00a935bed7c4cfe4906dba66dd380ed64..c9220a4ab714cd2ad4a5d1d23688bcb5981b1150 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -49,4 +49,5 @@ type Session struct { CreatedAt int64 `json:"created_at"` SummaryMessageID sql.NullString `json:"summary_message_id"` Todos sql.NullString `json:"todos"` + Models sql.NullString `json:"models"` } diff --git a/internal/db/querier.go b/internal/db/querier.go index 9031505a3db825f2c21d83e005046323bde3a6c2..3886d6388c0123636e560e8d4a9b977993f5be39 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -45,6 +45,7 @@ type Querier interface { RenameSession(ctx context.Context, arg RenameSessionParams) error UpdateMessage(ctx context.Context, arg UpdateMessageParams) error UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) + UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error } diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index 685948e60e84ec4df66e4d5d1c9645a9ff1fb43f..4bf7ee8febbfd611f28f1b4d32e5c9f1690fd4e3 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -33,7 +33,7 @@ INSERT INTO sessions ( null, strftime('%s', 'now'), strftime('%s', 'now') -) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models ` type CreateSessionParams struct { @@ -69,6 +69,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ) return i, err } @@ -84,7 +85,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error { } const getLastSession = `-- name: GetLastSession :one -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models FROM sessions ORDER BY updated_at DESC LIMIT 1 @@ -105,12 +106,13 @@ func (q *Queries) GetLastSession(ctx context.Context) (Session, error) { &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ) return i, err } const getSessionByID = `-- name: GetSessionByID :one -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models FROM sessions WHERE id = ? LIMIT 1 ` @@ -130,12 +132,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ) return i, err } const listSessions = `-- name: ListSessions :many -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models FROM sessions WHERE parent_session_id is NULL ORDER BY updated_at DESC @@ -162,6 +165,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, ); err != nil { return nil, err } @@ -203,7 +207,7 @@ SET cost = ?, todos = ? WHERE id = ? -RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos +RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models ` type UpdateSessionParams struct { @@ -239,6 +243,39 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S &i.CreatedAt, &i.SummaryMessageID, &i.Todos, + &i.Models, + ) + return i, err +} + +const updateSessionModels = `-- name: UpdateSessionModels :one +UPDATE sessions +SET models = ? +WHERE id = ? +RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models +` + +type UpdateSessionModelsParams struct { + Models sql.NullString `json:"models"` + ID string `json:"id"` +} + +func (q *Queries) UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error) { + row := q.queryRow(ctx, q.updateSessionModelsStmt, updateSessionModels, arg.Models, arg.ID) + var i Session + err := row.Scan( + &i.ID, + &i.ParentSessionID, + &i.Title, + &i.MessageCount, + &i.PromptTokens, + &i.CompletionTokens, + &i.Cost, + &i.UpdatedAt, + &i.CreatedAt, + &i.SummaryMessageID, + &i.Todos, + &i.Models, ) return i, err } diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index 44c1609ecfbc3867bea827088fcbcff6e718427b..75ab5fa5ae7efb50f1e32328fedd46adce77a9cd 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -69,6 +69,12 @@ SET title = ? WHERE id = ?; +-- name: UpdateSessionModels :one +UPDATE sessions +SET models = ? +WHERE id = ? +RETURNING *; + -- name: DeleteSession :exec DELETE FROM sessions WHERE id = ?; diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 9c84c6c8bf0c2f14da75933f8eebd7a36ca534ba..6d7ee4b150a132013a80c2070588e20ef8bdf456 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -56,7 +56,7 @@ type AgentSession struct { // IsZero checks if the AgentSession is zero-valued. func (a AgentSession) IsZero() bool { - return a == AgentSession{} + return a.ID == "" && !a.IsBusy } // PermissionAction represents an action taken on a permission request. diff --git a/internal/proto/session.go b/internal/proto/session.go index 846ac592017e6ce447c6c6a94535d9317adad7d8..fdf31b6c7746998db5c4c24a7820551357a5c33c 100644 --- a/internal/proto/session.go +++ b/internal/proto/session.go @@ -1,15 +1,39 @@ package proto +// SelectedModelType represents the type of model selection (large or small). +type SelectedModelType string + +const ( + SelectedModelTypeLarge SelectedModelType = "large" + SelectedModelTypeSmall SelectedModelType = "small" +) + +// SelectedModel represents a model selection with provider and configuration. +type SelectedModel struct { + Model string `json:"model"` + Provider string `json:"provider"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Think bool `json:"think,omitempty"` + MaxTokens int64 `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int64 `json:"top_k,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + ProviderOptions map[string]any `json:"provider_options,omitempty"` +} + // Session represents a session in the proto layer. type Session struct { - ID string `json:"id"` - ParentSessionID string `json:"parent_session_id"` - Title string `json:"title"` - MessageCount int64 `json:"message_count"` - PromptTokens int64 `json:"prompt_tokens"` - CompletionTokens int64 `json:"completion_tokens"` - SummaryMessageID string `json:"summary_message_id"` - Cost float64 `json:"cost"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` + ID string `json:"id"` + ParentSessionID string `json:"parent_session_id"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + SummaryMessageID string `json:"summary_message_id"` + Cost float64 `json:"cost"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + Models map[SelectedModelType]SelectedModel `json:"models,omitempty"` } diff --git a/internal/proto/session_test.go b/internal/proto/session_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7634be230f744e11836ed1e85545256f09c31bac --- /dev/null +++ b/internal/proto/session_test.go @@ -0,0 +1,168 @@ +package proto + +import ( + "encoding/json" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/require" +) + +func TestModelsRoundTrip(t *testing.T) { + t.Parallel() + + t.Run("nil models", func(t *testing.T) { + t.Parallel() + s := Session{ + ID: "test-id", + Title: "Test", + } + data, err := json.Marshal(s) + require.NoError(t, err) + var decoded Session + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Nil(t, decoded.Models) + }) + + t.Run("populated models", func(t *testing.T) { + t.Parallel() + temp := 0.7 + s := Session{ + ID: "test-id", + Title: "Test", + Models: map[SelectedModelType]SelectedModel{ + SelectedModelTypeLarge: { + Model: "gpt-4o", + Provider: "openai", + ReasoningEffort: "high", + Think: true, + MaxTokens: 4096, + Temperature: &temp, + ProviderOptions: map[string]any{"key": "value"}, + }, + SelectedModelTypeSmall: { + Model: "gpt-4o-mini", + Provider: "openai", + }, + }, + } + data, err := json.Marshal(s) + require.NoError(t, err) + var decoded Session + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Equal(t, "gpt-4o", decoded.Models[SelectedModelTypeLarge].Model) + require.Equal(t, "openai", decoded.Models[SelectedModelTypeLarge].Provider) + require.Equal(t, "high", decoded.Models[SelectedModelTypeLarge].ReasoningEffort) + require.True(t, decoded.Models[SelectedModelTypeLarge].Think) + require.Equal(t, int64(4096), decoded.Models[SelectedModelTypeLarge].MaxTokens) + require.NotNil(t, decoded.Models[SelectedModelTypeLarge].Temperature) + require.Equal(t, 0.7, *decoded.Models[SelectedModelTypeLarge].Temperature) + require.Equal(t, "gpt-4o-mini", decoded.Models[SelectedModelTypeSmall].Model) + }) + + t.Run("empty map models", func(t *testing.T) { + t.Parallel() + s := Session{ + ID: "test-id", + Title: "Test", + Models: map[SelectedModelType]SelectedModel{}, + } + data, err := json.Marshal(s) + require.NoError(t, err) + var decoded Session + require.NoError(t, json.Unmarshal(data, &decoded)) + // Empty map with omitempty is dropped during marshaling. + require.Nil(t, decoded.Models) + }) +} + +func TestProtoToDomainRoundTrip(t *testing.T) { + t.Parallel() + + t.Run("models through proto", func(t *testing.T) { + t.Parallel() + temp := 0.7 + domainModels := map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Model: "gpt-4o", + Provider: "openai", + ReasoningEffort: "high", + Think: true, + MaxTokens: 4096, + Temperature: &temp, + ProviderOptions: map[string]any{"key": "value"}, + }, + } + + // Domain → Proto + protoModels := convertModelsToProtoLocal(domainModels) + require.Equal(t, SelectedModelTypeLarge, SelectedModelType(config.SelectedModelTypeLarge)) + require.Equal(t, "gpt-4o", protoModels[SelectedModelTypeLarge].Model) + require.Equal(t, "openai", protoModels[SelectedModelTypeLarge].Provider) + + // Proto → Domain + result := convertModelsFromProtoLocal(protoModels) + require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model) + require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider) + require.Equal(t, "high", result[config.SelectedModelTypeLarge].ReasoningEffort) + require.True(t, result[config.SelectedModelTypeLarge].Think) + require.Equal(t, int64(4096), result[config.SelectedModelTypeLarge].MaxTokens) + require.NotNil(t, result[config.SelectedModelTypeLarge].Temperature) + require.Equal(t, 0.7, *result[config.SelectedModelTypeLarge].Temperature) + }) + + t.Run("nil models round-trip", func(t *testing.T) { + t.Parallel() + protoModels := convertModelsToProtoLocal(nil) + require.Nil(t, protoModels) + + domainModels := convertModelsFromProtoLocal(nil) + require.Nil(t, domainModels) + }) +} + +func convertModelsToProtoLocal(models map[config.SelectedModelType]config.SelectedModel) map[SelectedModelType]SelectedModel { + if models == nil { + return nil + } + result := make(map[SelectedModelType]SelectedModel, len(models)) + for k, v := range models { + result[SelectedModelType(k)] = SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result +} + +func convertModelsFromProtoLocal(models map[SelectedModelType]SelectedModel) map[config.SelectedModelType]config.SelectedModel { + if models == nil { + return nil + } + result := make(map[config.SelectedModelType]config.SelectedModel, len(models)) + for k, v := range models { + result[config.SelectedModelType(k)] = config.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result +} diff --git a/internal/server/events.go b/internal/server/events.go index 752311666bb6fcc2b1efde4d037711eaafaa0162..d3b1341a1e8166c2eb70d804f100a5b2fd08e614 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -8,6 +8,7 @@ import ( "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" @@ -137,9 +138,33 @@ func sessionToProto(s session.Session) proto.Session { Cost: s.Cost, CreatedAt: s.CreatedAt, UpdatedAt: s.UpdatedAt, + Models: convertModelsToProto(s.Models), } } +func convertModelsToProto(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel { + if models == nil { + return nil + } + result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models)) + for k, v := range models { + result[proto.SelectedModelType(k)] = proto.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result +} + func fileToProto(f history.File) proto.File { return proto.File{ ID: f.ID, diff --git a/internal/server/proto.go b/internal/server/proto.go index af34131810a7af8c3672fe460198d25afe9ba064..2e9f0ddef2f7b3c072b751ba689be8d6e4517655 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/session" ) @@ -460,6 +461,39 @@ func (c *controllerV1) handleDeleteWorkspaceSession(w http.ResponseWriter, r *ht w.WriteHeader(http.StatusOK) } +// handlePostWorkspaceSessionModels updates the models for a session. +// +// @Summary Update session models +// @Tags sessions +// @Accept json +// @Param id path string true "Workspace ID" +// @Param sid path string true "Session ID" +// @Param body body map[config.SelectedModelType]config.SelectedModel true "Models" +// @Success 204 +// @Failure 400 {object} proto.Error +// @Failure 500 {object} proto.Error +// @Router /workspaces/{id}/sessions/{sid}/models [post] +func (c *controllerV1) handlePostWorkspaceSessionModels(w http.ResponseWriter, r *http.Request) { + workspaceID := r.PathValue("id") + sessionID := r.PathValue("sid") + + var req struct { + Models map[config.SelectedModelType]config.SelectedModel `json:"models"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if err := c.backend.UpdateSessionModels(r.Context(), workspaceID, sessionID, req.Models); err != nil { + c.handleError(w, r, err) + return + } + + w.WriteHeader(http.StatusNoContent) +} + // handleGetWorkspaceSessionUserMessages returns user messages for a session. // // @Summary Get user messages for session diff --git a/internal/server/server.go b/internal/server/server.go index 9ac4dba4c908050a0381b49258941d1b3a931970..275c79ec031a50876d7c72e8e2868b1f6c378950 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -122,6 +122,7 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server { mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession) mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession) mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession) + mux.HandleFunc("POST /v1/workspaces/{id}/sessions/{sid}/models", c.handlePostWorkspaceSessionModels) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages) mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages) diff --git a/internal/session/session.go b/internal/session/session.go index 66bd9f4c9a12916d02c6d22ed7d51f81d74efdfd..0271bf9a293095c259f114bce2de5e0d901e5ec1 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -8,6 +8,7 @@ import ( "log/slog" "strings" + "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" "github.com/charmbracelet/crush/internal/pubsub" @@ -56,6 +57,7 @@ type Session struct { SummaryMessageID string Cost float64 Todos []Todo + Models map[config.SelectedModelType]config.SelectedModel CreatedAt int64 UpdatedAt int64 } @@ -69,6 +71,8 @@ type Service interface { GetLast(ctx context.Context) (Session, error) List(ctx context.Context) ([]Session, error) Save(ctx context.Context, session Session) (Session, error) + SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) + UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error Rename(ctx context.Context, id string, title string) error Delete(ctx context.Context, id string) error @@ -242,6 +246,10 @@ func (s service) fromDBItem(item db.Session) Session { if err != nil { slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err) } + models, err := unmarshalModels(item.Models.String) + if err != nil { + slog.Error("Failed to unmarshal models", "session_id", item.ID, "error", err) + } return Session{ ID: item.ID, ParentSessionID: item.ParentSessionID.String, @@ -252,6 +260,7 @@ func (s service) fromDBItem(item db.Session) Session { SummaryMessageID: item.SummaryMessageID.String, Cost: item.Cost, Todos: todos, + Models: models, CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, } @@ -279,6 +288,51 @@ func unmarshalTodos(data string) ([]Todo, error) { return todos, nil } +func marshalModels(models map[config.SelectedModelType]config.SelectedModel) (string, error) { + if len(models) == 0 { + return "", nil + } + data, err := json.Marshal(models) + if err != nil { + return "", err + } + return string(data), nil +} + +func unmarshalModels(data string) (map[config.SelectedModelType]config.SelectedModel, error) { + if data == "" { + return nil, nil + } + var models map[config.SelectedModelType]config.SelectedModel + if err := json.Unmarshal([]byte(data), &models); err != nil { + return nil, err + } + return models, nil +} + +func (s *service) UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error { + modelsJSON, err := marshalModels(models) + if err != nil { + return fmt.Errorf("failed to marshal models: %w", err) + } + _, err = s.q.UpdateSessionModels(ctx, db.UpdateSessionModelsParams{ + Models: sql.NullString{String: modelsJSON, Valid: modelsJSON != ""}, + ID: id, + }) + return err +} + +func (s *service) SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) { + saved, err := s.Save(ctx, session) + if err != nil { + return Session{}, err + } + if err := s.UpdateSessionModels(ctx, session.ID, models); err != nil { + return Session{}, fmt.Errorf("failed to persist models: %w", err) + } + return saved, nil +} + func NewService(q *db.Queries, conn *sql.DB) Service { broker := pubsub.NewBroker[Session]() return &service{ diff --git a/internal/session/session_test.go b/internal/session/session_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c8f06795006081e7a05750fd6970b2d79076819f --- /dev/null +++ b/internal/session/session_test.go @@ -0,0 +1,151 @@ +package session + +import ( + "database/sql" + "testing" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/db" + "github.com/stretchr/testify/require" +) + +func TestMarshalModels(t *testing.T) { + t.Parallel() + + t.Run("empty", func(t *testing.T) { + t.Parallel() + result, err := marshalModels(map[config.SelectedModelType]config.SelectedModel{}) + require.NoError(t, err) + require.Equal(t, "", result) + }) + + t.Run("nil", func(t *testing.T) { + t.Parallel() + result, err := marshalModels(nil) + require.NoError(t, err) + require.Equal(t, "", result) + }) + + t.Run("single entry", func(t *testing.T) { + t.Parallel() + models := map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Model: "claude-sonnet-4-20250514", + Provider: "anthropic", + }, + } + result, err := marshalModels(models) + require.NoError(t, err) + require.Contains(t, result, "claude-sonnet-4-20250514") + require.Contains(t, result, "anthropic") + }) + + t.Run("round-trip", func(t *testing.T) { + t.Parallel() + temp := 0.7 + topP := 0.9 + topK := int64(50) + freqPen := 0.1 + presPen := 0.2 + models := map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Model: "gpt-4o", + Provider: "openai", + ReasoningEffort: "high", + Think: true, + MaxTokens: 4096, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + FrequencyPenalty: &freqPen, + PresencePenalty: &presPen, + ProviderOptions: map[string]any{"key": "value"}, + }, + config.SelectedModelTypeSmall: { + Model: "gpt-4o-mini", + Provider: "openai", + }, + } + data, err := marshalModels(models) + require.NoError(t, err) + result, err := unmarshalModels(data) + require.NoError(t, err) + require.Equal(t, models, result) + }) +} + +func TestUnmarshalModels(t *testing.T) { + t.Parallel() + + t.Run("empty string", func(t *testing.T) { + t.Parallel() + result, err := unmarshalModels("") + require.NoError(t, err) + require.Nil(t, result) + }) + + t.Run("valid JSON", func(t *testing.T) { + t.Parallel() + data := `{"large":{"model":"gpt-4o","provider":"openai"}}` + result, err := unmarshalModels(data) + require.NoError(t, err) + require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model) + require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider) + }) + + t.Run("invalid JSON", func(t *testing.T) { + t.Parallel() + _, err := unmarshalModels("{invalid}") + require.Error(t, err) + }) +} + +func TestFromDBItemWithModels(t *testing.T) { + t.Parallel() + + t.Run("null models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{Valid: false} + result := service{}.fromDBItem(item) + require.Nil(t, result.Models) + }) + + t.Run("empty models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{String: "", Valid: false} + result := service{}.fromDBItem(item) + require.Nil(t, result.Models) + }) + + t.Run("valid models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{ + String: `{"large":{"model":"gpt-4o","provider":"openai"}}`, + Valid: true, + } + result := service{}.fromDBItem(item) + require.NotNil(t, result.Models) + require.Equal(t, "gpt-4o", result.Models[config.SelectedModelTypeLarge].Model) + }) + + t.Run("invalid JSON models", func(t *testing.T) { + t.Parallel() + item := testDBSession() + item.Models = sql.NullString{ + String: "{invalid}", + Valid: true, + } + result := service{}.fromDBItem(item) + require.Nil(t, result.Models) + }) +} + +func testDBSession() db.Session { + return db.Session{ + ID: "test-id", + Title: "Test", + } +} diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index e5724987898a755f06a2b2b4643cce9ff94a3d83..15e8ec8cdb7b5f084cdda9e2d860a6933fa45efb 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -498,6 +498,29 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.setState(uiChat, m.focus) m.session = msg.session m.sessionFiles = msg.files + if msg.session.Models != nil { + cfg := m.com.Config() + anyChanged := false + for modelType, selectedModel := range msg.session.Models { + if current, ok := cfg.Models[modelType]; ok && current.Equal(selectedModel) { + continue + } + if cfg.GetModel(selectedModel.Provider, selectedModel.Model) != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, modelType, selectedModel); err != nil { + slog.Error("Failed to restore model", "type", modelType, "error", err) + } + anyChanged = true + } + } + if anyChanged { + cmds = append(cmds, func() tea.Msg { + if err := m.com.Workspace.UpdateAgentModel(context.TODO()); err != nil { + return util.ReportError(err) + } + return nil + }) + } + } cmds = append(cmds, m.startLSPs(msg.lspFilePaths())) msgs, err := m.com.Workspace.ListMessages(context.Background(), m.session.ID) if err != nil { @@ -1376,6 +1399,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { } return util.NewInfoMsg("Thinking mode " + status) }) + if m.session != nil { + cmds = append(cmds, m.saveSessionModels()) + } m.dialog.CloseDialog(dialog.CommandsID) case dialog.ActionToggleTransparentBackground: cmds = append(cmds, func() tea.Msg { @@ -1465,6 +1491,10 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { return util.NewInfoMsg(modelMsg) }) + if m.session != nil { + cmds = append(cmds, m.saveSessionModels()) + } + m.dialog.CloseDialog(dialog.APIKeyInputID) m.dialog.CloseDialog(dialog.OAuthID) m.dialog.CloseDialog(dialog.ModelsID) @@ -1505,6 +1535,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { m.com.Workspace.UpdateAgentModel(context.TODO()) return util.NewInfoMsg("Reasoning effort set to " + msg.Effort) }) + if m.session != nil { + cmds = append(cmds, m.saveSessionModels()) + } m.dialog.CloseDialog(dialog.ReasoningID) case dialog.ActionPermissionResponse: m.dialog.CloseDialog(dialog.PermissionsID) @@ -3652,3 +3685,18 @@ func renderLogo(t *styles.Styles, compact bool, width int) string { Width: width, }) } + +func (m *UI) saveSessionModels() tea.Cmd { + session := m.session + if session == nil { + return nil + } + models := m.com.Config().Models + return func() tea.Msg { + err := m.com.Workspace.UpdateSessionModels(context.Background(), session.ID, models) + if err != nil { + return util.ReportError(err) + } + return nil + } +} diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go index 57b1228e7eacb28a16141283ee2703a33511bd18..7a5c7a01aaefa96195e07c3aaaebfacf9b0faa91 100644 --- a/internal/workspace/app_workspace.go +++ b/internal/workspace/app_workspace.go @@ -59,6 +59,10 @@ func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) erro return w.app.Sessions.Delete(ctx, sessionID) } +func (w *AppWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + return w.app.Sessions.UpdateSessionModels(ctx, sessionID, models) +} + func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID) } diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index 7c4e1408882cc70859ea2ab05981461d262513e9..1dbda7aa3ee5b9a1c73cee2681ecd383dd9e98ed 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -119,6 +119,10 @@ func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) e return w.client.DeleteSession(ctx, w.workspaceID(), sessionID) } +func (w *ClientWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error { + return w.client.UpdateSessionModels(ctx, w.workspaceID(), sessionID, models) +} + func (w *ClientWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string { return fmt.Sprintf("%s$$%s", messageID, toolCallID) } @@ -673,6 +677,7 @@ func protoToSession(s proto.Session) session.Session { Cost: s.Cost, CreatedAt: s.CreatedAt, UpdatedAt: s.UpdatedAt, + Models: convertModelsFromProto(s.Models), } } @@ -769,5 +774,52 @@ func sessionToProto(s session.Session) proto.Session { Cost: s.Cost, CreatedAt: s.CreatedAt, UpdatedAt: s.UpdatedAt, + Models: convertModelsToProtoClient(s.Models), + } +} + +func convertModelsFromProto(models map[proto.SelectedModelType]proto.SelectedModel) map[config.SelectedModelType]config.SelectedModel { + if models == nil { + return nil + } + result := make(map[config.SelectedModelType]config.SelectedModel, len(models)) + for k, v := range models { + result[config.SelectedModelType(k)] = config.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } } + return result +} + +func convertModelsToProtoClient(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel { + if models == nil { + return nil + } + result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models)) + for k, v := range models { + result[proto.SelectedModelType(k)] = proto.SelectedModel{ + Model: v.Model, + Provider: v.Provider, + ReasoningEffort: v.ReasoningEffort, + Think: v.Think, + MaxTokens: v.MaxTokens, + Temperature: v.Temperature, + TopP: v.TopP, + TopK: v.TopK, + FrequencyPenalty: v.FrequencyPenalty, + PresencePenalty: v.PresencePenalty, + ProviderOptions: v.ProviderOptions, + } + } + return result } diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go index 02c54c616f3251140bbee441451c3a4cb14845bd..486edf01bd8fce6f975234a1840237aba4416929 100644 --- a/internal/workspace/workspace.go +++ b/internal/workspace/workspace.go @@ -65,6 +65,7 @@ type Workspace interface { ListSessions(ctx context.Context) ([]session.Session, error) SaveSession(ctx context.Context, sess session.Session) (session.Session, error) DeleteSession(ctx context.Context, sessionID string) error + UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error CreateAgentToolSessionID(messageID, toolCallID string) string ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)