@@ -18,10 +18,11 @@ import (
// Server wraps the MCP server and implements planning tools
type Server struct {
- config *config.Config
- logger *slog.Logger
- planner *planning.Manager
- server *server.MCPServer
+ config *config.Config
+ logger *slog.Logger
+ planner *planning.Manager
+ validator Validator
+ server *server.MCPServer
}
// New creates a new MCP server
@@ -37,9 +38,10 @@ func New(cfg *config.Config, logger *slog.Logger, planner *planning.Manager) (*S
}
s := &Server{
- config: cfg,
- logger: logger,
- planner: planner,
+ config: cfg,
+ logger: logger,
+ planner: planner,
+ validator: NewPlanningValidator(cfg),
}
// Create MCP server
@@ -165,10 +167,15 @@ func (s *Server) registerTools(mcpServer *server.MCPServer) {
func (s *Server) handleSetGoal(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
s.logger.Info("Received project_management__set_goal tool call")
- // Parse and validate request
+ // Parse request
var req SetGoalRequest
- if err := parseAndValidate(request.GetArguments(), &req); err != nil {
- return createErrorResult(fmt.Sprintf("Invalid request: %v", err)), nil
+ if err := parseRequest(request.GetArguments(), &req); err != nil {
+ return createErrorResult(fmt.Sprintf("Invalid request format: %v", err)), nil
+ }
+
+ // Validate request
+ if err := s.validator.ValidateSetGoalRequest(req); err != nil {
+ return createErrorResult(fmt.Sprintf("Validation error: %v", err)), nil
}
// Set goal
@@ -186,10 +193,15 @@ func (s *Server) handleSetGoal(ctx context.Context, request mcp.CallToolRequest)
func (s *Server) handleChangeGoal(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
s.logger.Info("Received project_management__change_goal tool call")
- // Parse and validate request
+ // Parse request
var req ChangeGoalRequest
- if err := parseAndValidate(request.GetArguments(), &req); err != nil {
- return createErrorResult(fmt.Sprintf("Invalid request: %v", err)), nil
+ if err := parseRequest(request.GetArguments(), &req); err != nil {
+ return createErrorResult(fmt.Sprintf("Invalid request format: %v", err)), nil
+ }
+
+ // Validate request
+ if err := s.validator.ValidateChangeGoalRequest(req); err != nil {
+ return createErrorResult(fmt.Sprintf("Validation error: %v", err)), nil
}
// Change goal
@@ -207,10 +219,15 @@ func (s *Server) handleChangeGoal(ctx context.Context, request mcp.CallToolReque
func (s *Server) handleAddTasks(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
s.logger.Info("Received project_management__add_tasks tool call")
- // Parse and validate request
+ // Parse request
var req AddTasksRequest
- if err := parseAndValidate(request.GetArguments(), &req); err != nil {
- return createErrorResult(fmt.Sprintf("Invalid request: %v", err)), nil
+ if err := parseRequest(request.GetArguments(), &req); err != nil {
+ return createErrorResult(fmt.Sprintf("Invalid request format: %v", err)), nil
+ }
+
+ // Validate request
+ if err := s.validator.ValidateAddTasksRequest(req); err != nil {
+ return createErrorResult(fmt.Sprintf("Validation error: %v", err)), nil
}
// Convert MCP task inputs to planning task inputs
@@ -252,10 +269,15 @@ func (s *Server) handleAddTasks(ctx context.Context, request mcp.CallToolRequest
func (s *Server) handleGetTasks(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
s.logger.Info("Received project_management__get_tasks tool call")
- // Parse and validate request
+ // Parse request
var req GetTasksRequest
- if err := parseAndValidate(request.GetArguments(), &req); err != nil {
- return createErrorResult(fmt.Sprintf("Invalid request: %v", err)), nil
+ if err := parseRequest(request.GetArguments(), &req); err != nil {
+ return createErrorResult(fmt.Sprintf("Invalid request format: %v", err)), nil
+ }
+
+ // Validate request
+ if err := s.validator.ValidateGetTasksRequest(req); err != nil {
+ return createErrorResult(fmt.Sprintf("Validation error: %v", err)), nil
}
// Default status to "all" if empty
@@ -278,10 +300,15 @@ func (s *Server) handleGetTasks(ctx context.Context, request mcp.CallToolRequest
func (s *Server) handleUpdateTaskStatuses(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
s.logger.Info("Received project_management__update_task_statuses tool call")
- // Parse and validate request
+ // Parse request
var req UpdateTaskStatusesRequest
- if err := parseAndValidate(request.GetArguments(), &req); err != nil {
- return createErrorResult(fmt.Sprintf("Invalid request: %v", err)), nil
+ if err := parseRequest(request.GetArguments(), &req); err != nil {
+ return createErrorResult(fmt.Sprintf("Invalid request format: %v", err)), nil
+ }
+
+ // Validate request
+ if err := s.validator.ValidateUpdateTaskStatusesRequest(req); err != nil {
+ return createErrorResult(fmt.Sprintf("Validation error: %v", err)), nil
}
// Convert MCP task update inputs to planning task updates
@@ -308,10 +335,15 @@ func (s *Server) handleUpdateTaskStatuses(ctx context.Context, request mcp.CallT
func (s *Server) handleDeleteTasks(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
s.logger.Info("Received project_management__delete_tasks tool call")
- // Parse and validate request
+ // Parse request
var req DeleteTasksRequest
- if err := parseAndValidate(request.GetArguments(), &req); err != nil {
- return createErrorResult(fmt.Sprintf("Invalid request: %v", err)), nil
+ if err := parseRequest(request.GetArguments(), &req); err != nil {
+ return createErrorResult(fmt.Sprintf("Invalid request format: %v", err)), nil
+ }
+
+ // Validate request
+ if err := s.validator.ValidateDeleteTasksRequest(req); err != nil {
+ return createErrorResult(fmt.Sprintf("Validation error: %v", err)), nil
}
// Delete tasks
@@ -0,0 +1,566 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package mcp
+
+import (
+ "strings"
+ "testing"
+
+ "git.sr.ht/~amolith/planning-mcp-server/internal/config"
+)
+
+func TestPlanningValidator_ValidateSetGoalRequest(t *testing.T) {
+ cfg := &config.Config{
+ Planning: config.PlanningConfig{
+ MaxGoalLength: 100,
+ },
+ }
+ validator := NewPlanningValidator(cfg)
+
+ tests := []struct {
+ name string
+ req SetGoalRequest
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "valid request",
+ req: SetGoalRequest{
+ Title: "Test Goal",
+ Description: "Valid description",
+ },
+ wantErr: false,
+ },
+ {
+ name: "empty title",
+ req: SetGoalRequest{
+ Title: "",
+ Description: "Valid description",
+ },
+ wantErr: true,
+ errMsg: "title is required",
+ },
+ {
+ name: "empty description",
+ req: SetGoalRequest{
+ Title: "Valid title",
+ Description: "",
+ },
+ wantErr: true,
+ errMsg: "description is required",
+ },
+ {
+ name: "title too long",
+ req: SetGoalRequest{
+ Title: strings.Repeat("x", 101),
+ Description: "Valid description",
+ },
+ wantErr: true,
+ errMsg: "title too long",
+ },
+ {
+ name: "description too long",
+ req: SetGoalRequest{
+ Title: "Valid title",
+ Description: strings.Repeat("x", 101),
+ },
+ wantErr: true,
+ errMsg: "description too long",
+ },
+ {
+ name: "title at max length",
+ req: SetGoalRequest{
+ Title: strings.Repeat("x", 100),
+ Description: "Valid description",
+ },
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateSetGoalRequest(tt.req)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("ValidateSetGoalRequest() expected error, got nil")
+ return
+ }
+ if !strings.Contains(err.Error(), tt.errMsg) {
+ t.Errorf("ValidateSetGoalRequest() error = %v, want error containing %v", err, tt.errMsg)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("ValidateSetGoalRequest() error = %v, want nil", err)
+ }
+ }
+ })
+ }
+}
+
+func TestPlanningValidator_ValidateChangeGoalRequest(t *testing.T) {
+ cfg := &config.Config{
+ Planning: config.PlanningConfig{
+ MaxGoalLength: 50,
+ },
+ }
+ validator := NewPlanningValidator(cfg)
+
+ tests := []struct {
+ name string
+ req ChangeGoalRequest
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "valid request",
+ req: ChangeGoalRequest{
+ Title: "New Goal",
+ Description: "New description",
+ Reason: "Valid reason",
+ },
+ wantErr: false,
+ },
+ {
+ name: "empty title",
+ req: ChangeGoalRequest{
+ Title: "",
+ Description: "Valid description",
+ Reason: "Valid reason",
+ },
+ wantErr: true,
+ errMsg: "title is required",
+ },
+ {
+ name: "empty description",
+ req: ChangeGoalRequest{
+ Title: "Valid title",
+ Description: "",
+ Reason: "Valid reason",
+ },
+ wantErr: true,
+ errMsg: "description is required",
+ },
+ {
+ name: "empty reason",
+ req: ChangeGoalRequest{
+ Title: "Valid title",
+ Description: "Valid description",
+ Reason: "",
+ },
+ wantErr: true,
+ errMsg: "reason is required",
+ },
+ {
+ name: "reason too long",
+ req: ChangeGoalRequest{
+ Title: "Valid title",
+ Description: "Valid description",
+ Reason: strings.Repeat("x", 51),
+ },
+ wantErr: true,
+ errMsg: "reason too long",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateChangeGoalRequest(tt.req)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("ValidateChangeGoalRequest() expected error, got nil")
+ return
+ }
+ if !strings.Contains(err.Error(), tt.errMsg) {
+ t.Errorf("ValidateChangeGoalRequest() error = %v, want error containing %v", err, tt.errMsg)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("ValidateChangeGoalRequest() error = %v, want nil", err)
+ }
+ }
+ })
+ }
+}
+
+func TestPlanningValidator_ValidateAddTasksRequest(t *testing.T) {
+ cfg := &config.Config{
+ Planning: config.PlanningConfig{
+ MaxTaskLength: 50,
+ },
+ }
+ validator := NewPlanningValidator(cfg)
+
+ tests := []struct {
+ name string
+ req AddTasksRequest
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "valid single task",
+ req: AddTasksRequest{
+ Tasks: []MCPTaskInput{
+ {Title: "Valid task", Description: "Valid description"},
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "valid multiple tasks",
+ req: AddTasksRequest{
+ Tasks: []MCPTaskInput{
+ {Title: "Task 1", Description: "Description 1"},
+ {Title: "Task 2", Description: ""},
+ {Title: "Task 3", Description: "Description 3"},
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "empty tasks array",
+ req: AddTasksRequest{
+ Tasks: []MCPTaskInput{},
+ },
+ wantErr: true,
+ errMsg: "at least one task is required",
+ },
+ {
+ name: "task with empty title",
+ req: AddTasksRequest{
+ Tasks: []MCPTaskInput{
+ {Title: "", Description: "Valid description"},
+ },
+ },
+ wantErr: true,
+ errMsg: "task 0: title is required",
+ },
+ {
+ name: "task with title too long",
+ req: AddTasksRequest{
+ Tasks: []MCPTaskInput{
+ {Title: strings.Repeat("x", 51), Description: "Valid description"},
+ },
+ },
+ wantErr: true,
+ errMsg: "task 0: title too long",
+ },
+ {
+ name: "task with description too long",
+ req: AddTasksRequest{
+ Tasks: []MCPTaskInput{
+ {Title: "Valid title", Description: strings.Repeat("x", 51)},
+ },
+ },
+ wantErr: true,
+ errMsg: "task 0: description too long",
+ },
+ {
+ name: "second task invalid",
+ req: AddTasksRequest{
+ Tasks: []MCPTaskInput{
+ {Title: "Valid task", Description: "Valid description"},
+ {Title: "", Description: "Another description"},
+ },
+ },
+ wantErr: true,
+ errMsg: "task 1: title is required",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateAddTasksRequest(tt.req)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("ValidateAddTasksRequest() expected error, got nil")
+ return
+ }
+ if !strings.Contains(err.Error(), tt.errMsg) {
+ t.Errorf("ValidateAddTasksRequest() error = %v, want error containing %v", err, tt.errMsg)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("ValidateAddTasksRequest() error = %v, want nil", err)
+ }
+ }
+ })
+ }
+}
+
+func TestPlanningValidator_ValidateGetTasksRequest(t *testing.T) {
+ cfg := &config.Config{
+ Planning: config.PlanningConfig{},
+ }
+ validator := NewPlanningValidator(cfg)
+
+ tests := []struct {
+ name string
+ req GetTasksRequest
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "empty status (valid)",
+ req: GetTasksRequest{Status: ""},
+ wantErr: false,
+ },
+ {
+ name: "valid status all",
+ req: GetTasksRequest{Status: "all"},
+ wantErr: false,
+ },
+ {
+ name: "valid status pending",
+ req: GetTasksRequest{Status: "pending"},
+ wantErr: false,
+ },
+ {
+ name: "valid status in_progress",
+ req: GetTasksRequest{Status: "in_progress"},
+ wantErr: false,
+ },
+ {
+ name: "valid status completed",
+ req: GetTasksRequest{Status: "completed"},
+ wantErr: false,
+ },
+ {
+ name: "valid status cancelled",
+ req: GetTasksRequest{Status: "cancelled"},
+ wantErr: false,
+ },
+ {
+ name: "valid status failed",
+ req: GetTasksRequest{Status: "failed"},
+ wantErr: false,
+ },
+ {
+ name: "invalid status",
+ req: GetTasksRequest{Status: "invalid"},
+ wantErr: true,
+ errMsg: "invalid status 'invalid'",
+ },
+ {
+ name: "invalid status case sensitive",
+ req: GetTasksRequest{Status: "Pending"},
+ wantErr: true,
+ errMsg: "invalid status 'Pending'",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateGetTasksRequest(tt.req)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("ValidateGetTasksRequest() expected error, got nil")
+ return
+ }
+ if !strings.Contains(err.Error(), tt.errMsg) {
+ t.Errorf("ValidateGetTasksRequest() error = %v, want error containing %v", err, tt.errMsg)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("ValidateGetTasksRequest() error = %v, want nil", err)
+ }
+ }
+ })
+ }
+}
+
+func TestPlanningValidator_ValidateUpdateTaskStatusesRequest(t *testing.T) {
+ cfg := &config.Config{
+ Planning: config.PlanningConfig{},
+ }
+ validator := NewPlanningValidator(cfg)
+
+ tests := []struct {
+ name string
+ req UpdateTaskStatusesRequest
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "valid single update",
+ req: UpdateTaskStatusesRequest{
+ Tasks: []MCPTaskUpdateInput{
+ {TaskID: "task1", Status: "completed"},
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "valid multiple updates",
+ req: UpdateTaskStatusesRequest{
+ Tasks: []MCPTaskUpdateInput{
+ {TaskID: "task1", Status: "completed"},
+ {TaskID: "task2", Status: "in_progress"},
+ {TaskID: "task3", Status: "failed"},
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "empty updates array",
+ req: UpdateTaskStatusesRequest{
+ Tasks: []MCPTaskUpdateInput{},
+ },
+ wantErr: true,
+ errMsg: "at least one task update is required",
+ },
+ {
+ name: "empty task ID",
+ req: UpdateTaskStatusesRequest{
+ Tasks: []MCPTaskUpdateInput{
+ {TaskID: "", Status: "completed"},
+ },
+ },
+ wantErr: true,
+ errMsg: "task update 0: task_id is required",
+ },
+ {
+ name: "empty status",
+ req: UpdateTaskStatusesRequest{
+ Tasks: []MCPTaskUpdateInput{
+ {TaskID: "task1", Status: ""},
+ },
+ },
+ wantErr: true,
+ errMsg: "task update 0: status is required",
+ },
+ {
+ name: "invalid status",
+ req: UpdateTaskStatusesRequest{
+ Tasks: []MCPTaskUpdateInput{
+ {TaskID: "task1", Status: "invalid"},
+ },
+ },
+ wantErr: true,
+ errMsg: "task update 0: invalid status 'invalid'",
+ },
+ {
+ name: "second update invalid",
+ req: UpdateTaskStatusesRequest{
+ Tasks: []MCPTaskUpdateInput{
+ {TaskID: "task1", Status: "completed"},
+ {TaskID: "", Status: "pending"},
+ },
+ },
+ wantErr: true,
+ errMsg: "task update 1: task_id is required",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateUpdateTaskStatusesRequest(tt.req)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("ValidateUpdateTaskStatusesRequest() expected error, got nil")
+ return
+ }
+ if !strings.Contains(err.Error(), tt.errMsg) {
+ t.Errorf("ValidateUpdateTaskStatusesRequest() error = %v, want error containing %v", err, tt.errMsg)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("ValidateUpdateTaskStatusesRequest() error = %v, want nil", err)
+ }
+ }
+ })
+ }
+}
+
+func TestPlanningValidator_ValidateDeleteTasksRequest(t *testing.T) {
+ cfg := &config.Config{
+ Planning: config.PlanningConfig{},
+ }
+ validator := NewPlanningValidator(cfg)
+
+ tests := []struct {
+ name string
+ req DeleteTasksRequest
+ wantErr bool
+ errMsg string
+ }{
+ {
+ name: "valid single task ID",
+ req: DeleteTasksRequest{
+ TaskIDs: []string{"task1"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "valid multiple task IDs",
+ req: DeleteTasksRequest{
+ TaskIDs: []string{"task1", "task2", "task3"},
+ },
+ wantErr: false,
+ },
+ {
+ name: "empty task IDs array",
+ req: DeleteTasksRequest{
+ TaskIDs: []string{},
+ },
+ wantErr: true,
+ errMsg: "at least one task ID is required",
+ },
+ {
+ name: "empty task ID",
+ req: DeleteTasksRequest{
+ TaskIDs: []string{""},
+ },
+ wantErr: true,
+ errMsg: "task ID 0 is empty",
+ },
+ {
+ name: "second task ID empty",
+ req: DeleteTasksRequest{
+ TaskIDs: []string{"task1", ""},
+ },
+ wantErr: true,
+ errMsg: "task ID 1 is empty",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validator.ValidateDeleteTasksRequest(tt.req)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("ValidateDeleteTasksRequest() expected error, got nil")
+ return
+ }
+ if !strings.Contains(err.Error(), tt.errMsg) {
+ t.Errorf("ValidateDeleteTasksRequest() error = %v, want error containing %v", err, tt.errMsg)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("ValidateDeleteTasksRequest() error = %v, want nil", err)
+ }
+ }
+ })
+ }
+}
+
+func TestNewPlanningValidator(t *testing.T) {
+ cfg := &config.Config{
+ Planning: config.PlanningConfig{
+ MaxGoalLength: 100,
+ MaxTaskLength: 200,
+ },
+ }
+
+ validator := NewPlanningValidator(cfg)
+
+ if validator == nil {
+ t.Error("NewPlanningValidator() returned nil")
+ return
+ }
+
+ if validator.config != cfg {
+ t.Error("NewPlanningValidator() did not set config correctly")
+ }
+}