package anthropic

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/param"
	"github.com/charmbracelet/fantasy/ai"
)

const (
	// Name is the default provider name, also used to key provider options.
	Name = "anthropic"
	// DefaultURL is the default Anthropic API base URL.
	DefaultURL = "https://api.anthropic.com"
)

// options holds the provider configuration collected from Options.
type options struct {
	baseURL string
	apiKey  string
	name    string
	headers map[string]string
	client  option.HTTPClient
}

// provider implements ai.Provider on top of the Anthropic SDK.
type provider struct {
	options options
}

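// Option configures a provider constructed by New.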
type Option = func(*options)

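// New creates an Anthropic provider. Options left unset fall back to their
// defaults: DefaultURL for the base URL and Name for the provider name.
//
// A minimal usage sketch (the model ID shown is illustrative):
//
//	p := New(WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
//	model, err := p.LanguageModel("claude-3-5-sonnet-latest")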
func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, Name)
	return &provider{options: providerOptions}
}

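// WithBaseURL overrides the default Anthropic API base URL.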
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

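// WithAPIKey sets the API key used to authenticate requests.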
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

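// WithName overrides the provider name attached to models.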
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

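// WithHeaders adds extra headers to every request.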
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

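// WithHTTPClient sets a custom HTTP client for the underlying SDK.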
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

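// LanguageModel implements ai.Provider, returning a language model bound to
// the given model ID and this provider's configuration.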
func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	anthropicClientOptions := []option.RequestOption{}
	if a.options.apiKey != "" {
		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(a.options.apiKey))
	}
	if a.options.baseURL != "" {
		anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(a.options.baseURL))
	}

	for key, value := range a.options.headers {
		anthropicClientOptions = append(anthropicClientOptions, option.WithHeader(key, value))
	}

	if a.options.client != nil {
		anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(a.options.client))
	}
	return languageModel{
		modelID:  modelID,
		provider: a.options.name,
		options:  a.options,
		client:   anthropic.NewClient(anthropicClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   anthropic.Client
	options  options
}

// Model implements ai.LanguageModel.
func (a languageModel) Model() string {
	return a.modelID
}

// Provider implements ai.LanguageModel.
func (a languageModel) Provider() string {
	return a.provider
}

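// prepareParams translates an ai.Call into Anthropic message parameters,
// collecting warnings for settings the API does not support.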
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
	params := &anthropic.MessageNewParams{}
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
		}
	}
	sendReasoning := true
	if providerOptions.SendReasoning != nil {
		sendReasoning = *providerOptions.SendReasoning
	}
	systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning)

	if call.FrequencyPenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "FrequencyPenalty",
		})
	}
	if call.PresencePenalty != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "PresencePenalty",
		})
	}

	params.System = systemBlocks
	params.Messages = messages
	params.Model = anthropic.Model(a.modelID)
	params.MaxTokens = 4096

	if call.MaxOutputTokens != nil {
		params.MaxTokens = *call.MaxOutputTokens
	}

	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopK != nil {
		params.TopK = param.NewOpt(*call.TopK)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}

	if providerOptions.Thinking != nil {
		thinkingBudget := providerOptions.Thinking.BudgetTokens
		if thinkingBudget == 0 {
			return nil, nil, ai.NewUnsupportedFunctionalityError("thinking requires a budget", "")
		}
		params.Thinking = anthropic.ThinkingConfigParamOfEnabled(thinkingBudget)
		// Sampling parameters are not supported when thinking is enabled, so
		// clear any that were set and surface a warning for each.
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Temperature",
				Details: "temperature is not supported when thinking is enabled",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported when thinking is enabled",
			})
		}
		if call.TopK != nil {
			params.TopK = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopK",
				Details: "TopK is not supported when thinking is enabled",
			})
		}
		// Reserve extra output room for the thinking budget.
		params.MaxTokens += thinkingBudget
	}

	if len(call.Tools) > 0 {
		disableParallelToolUse := false
		if providerOptions.DisableParallelToolUse != nil {
			disableParallelToolUse = *providerOptions.DisableParallelToolUse
		}
		tools, toolChoice, toolWarnings := toTools(call.Tools, call.ToolChoice, disableParallelToolUse)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}

	return params, warnings, nil
}

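// ParseOptions decodes raw provider option data into *ProviderOptions.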
func (a *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
	var options ProviderOptions
	if err := ai.ParseOptions(data, &options); err != nil {
		return nil, err
	}
	return &options, nil
}

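// Name returns the canonical provider name.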
func (a *provider) Name() string {
	return Name
}

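// getCacheControl extracts Anthropic cache-control settings from provider
// options, if present.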
func getCacheControl(providerOptions ai.ProviderOptions) *CacheControl {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if options, ok := anthropicOptions.(*ProviderCacheControlOptions); ok {
			return &options.CacheControl
		}
	}
	return nil
}

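// getReasoningMetadata extracts reasoning metadata (a thinking signature or
// redacted thinking data) from provider options, if present.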
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningOptionMetadata {
	if anthropicOptions, ok := providerOptions[Name]; ok {
		if reasoning, ok := anthropicOptions.(*ReasoningOptionMetadata); ok {
			return reasoning
		}
	}
	return nil
}

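// messageBlock groups consecutive messages that share an Anthropic role.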
type messageBlock struct {
	Role     ai.MessageRole
	Messages []ai.Message
}

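// groupIntoBlocks coalesces consecutive same-role messages into blocks. Tool
// messages join user blocks because the Anthropic API expects tool results
// inside user turns.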
func groupIntoBlocks(prompt ai.Prompt) []*messageBlock {
	var blocks []*messageBlock

	var currentBlock *messageBlock

	for _, msg := range prompt {
		role := msg.Role
		// Tool results are sent to the API as user content, so they share a
		// block with adjacent user messages.
		if role == ai.MessageRoleTool {
			role = ai.MessageRoleUser
		}
		switch role {
		case ai.MessageRoleSystem, ai.MessageRoleUser, ai.MessageRoleAssistant:
		default:
			continue
		}
		if currentBlock == nil || currentBlock.Role != role {
			currentBlock = &messageBlock{
				Role:     role,
				Messages: []ai.Message{},
			}
			blocks = append(blocks, currentBlock)
		}
		currentBlock.Messages = append(currentBlock.Messages, msg)
	}
	return blocks
}

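// toTools converts generic tool definitions and the tool-choice setting into
// their Anthropic equivalents, emitting warnings for unsupported tools.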
func toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			required := []string{}
			var properties any
			if props, ok := ft.InputSchema["properties"]; ok {
				properties = props
			}
			// The schema may carry "required" either as []string (built in
			// code) or as []any (decoded from JSON); accept both.
			if req, ok := ft.InputSchema["required"]; ok {
				switch reqVal := req.(type) {
				case []string:
					required = reqVal
				case []any:
					for _, r := range reqVal {
						if s, ok := r.(string); ok {
							required = append(required, s)
						}
					}
				}
			}
			cacheControl := getCacheControl(ft.ProviderOptions)

			anthropicTool := anthropic.ToolParam{
				Name:        ft.Name,
				Description: anthropic.String(ft.Description),
				InputSchema: anthropic.ToolInputSchemaParam{
					Properties: properties,
					Required:   required,
				},
			}
			if cacheControl != nil {
				anthropicTool.CacheControl = anthropic.NewCacheControlEphemeralParam()
			}
			anthropicTools = append(anthropicTools, anthropic.ToolUnionParam{OfTool: &anthropicTool})
			continue
		}
		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		if disableParallelToolCalls {
			anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
				OfAuto: &anthropic.ToolChoiceAutoParam{
					Type:                   "auto",
					DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
				},
			}
		}
		return anthropicTools, anthropicToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAuto: &anthropic.ToolChoiceAutoParam{
				Type:                   "auto",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceRequired:
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfAny: &anthropic.ToolChoiceAnyParam{
				Type:                   "any",
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	case ai.ToolChoiceNone:
		return anthropicTools, anthropicToolChoice, warnings
	default:
		// Any other value names a specific tool the model must call.
		anthropicToolChoice = &anthropic.ToolChoiceUnionParam{
			OfTool: &anthropic.ToolChoiceToolParam{
				Type:                   "tool",
				Name:                   string(*toolChoice),
				DisableParallelToolUse: param.NewOpt(disableParallelToolCalls),
			},
		}
	}
	return anthropicTools, anthropicToolChoice, warnings
}

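// toPrompt splits a prompt into system text blocks and Anthropic messages,
// applying cache control and, when enabled, reasoning content.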
func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) {
	var systemBlocks []anthropic.TextBlockParam
	var messages []anthropic.MessageParam
	var warnings []ai.CallWarning

	blocks := groupIntoBlocks(prompt)
	finishedSystemBlock := false
	for _, block := range blocks {
		switch block.Role {
		case ai.MessageRoleSystem:
			if finishedSystemBlock {
				// Skip system messages that are separated from the leading
				// system block by user/assistant messages.
				// TODO: see if we need to send an error here?
				continue
			}
			finishedSystemBlock = true
			for _, msg := range block.Messages {
				for _, part := range msg.Content {
					cacheControl := getCacheControl(part.Options())
					text, ok := ai.AsMessagePart[ai.TextPart](part)
					if !ok {
						continue
					}
					textBlock := anthropic.TextBlockParam{
						Text: text.Text,
					}
					if cacheControl != nil {
						textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
					}
					systemBlocks = append(systemBlocks, textBlock)
				}
			}

		case ai.MessageRoleUser:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				if msg.Role == ai.MessageRoleUser {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						switch part.GetType() {
						case ai.ContentTypeText:
							text, ok := ai.AsMessagePart[ai.TextPart](part)
							if !ok {
								continue
							}
							textBlock := &anthropic.TextBlockParam{
								Text: text.Text,
							}
							if cacheControl != nil {
								textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
								OfText: textBlock,
							})
						case ai.ContentTypeFile:
							file, ok := ai.AsMessagePart[ai.FilePart](part)
							if !ok {
								continue
							}
							// TODO: handle other file types
							if !strings.HasPrefix(file.MediaType, "image/") {
								continue
							}

							base64Encoded := base64.StdEncoding.EncodeToString(file.Data)
							imageBlock := anthropic.NewImageBlockBase64(file.MediaType, base64Encoded)
							if cacheControl != nil {
								imageBlock.OfImage.CacheControl = anthropic.NewCacheControlEphemeralParam()
							}
							anthropicContent = append(anthropicContent, imageBlock)
						}
					}
				} else if msg.Role == ai.MessageRoleTool {
					for i, part := range msg.Content {
						isLastPart := i == len(msg.Content)-1
						cacheControl := getCacheControl(part.Options())
						if cacheControl == nil && isLastPart {
							cacheControl = getCacheControl(msg.ProviderOptions)
						}
						result, ok := ai.AsMessagePart[ai.ToolResultPart](part)
						if !ok {
							continue
						}
						toolResultBlock := anthropic.ToolResultBlockParam{
							ToolUseID: result.ToolCallID,
						}
						switch result.Output.GetType() {
						case ai.ToolResultContentTypeText:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Text,
									},
								},
							}
						case ai.ToolResultContentTypeMedia:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentMedia](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfImage: anthropic.NewImageBlockBase64(content.MediaType, content.Data).OfImage,
								},
							}
						case ai.ToolResultContentTypeError:
							content, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Output)
							if !ok {
								continue
							}
							toolResultBlock.Content = []anthropic.ToolResultBlockParamContentUnion{
								{
									OfText: &anthropic.TextBlockParam{
										Text: content.Error.Error(),
									},
								},
							}
							toolResultBlock.IsError = param.NewOpt(true)
						}
						if cacheControl != nil {
							toolResultBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfToolResult: &toolResultBlock,
						})
					}
				}
			}
			messages = append(messages, anthropic.NewUserMessage(anthropicContent...))
		case ai.MessageRoleAssistant:
			var anthropicContent []anthropic.ContentBlockParamUnion
			for _, msg := range block.Messages {
				for i, part := range msg.Content {
					isLastPart := i == len(msg.Content)-1
					cacheControl := getCacheControl(part.Options())
					if cacheControl == nil && isLastPart {
						cacheControl = getCacheControl(msg.ProviderOptions)
					}
					switch part.GetType() {
					case ai.ContentTypeText:
						text, ok := ai.AsMessagePart[ai.TextPart](part)
						if !ok {
							continue
						}
						textBlock := &anthropic.TextBlockParam{
							Text: text.Text,
						}
						if cacheControl != nil {
							textBlock.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, anthropic.ContentBlockParamUnion{
							OfText: textBlock,
						})
					case ai.ContentTypeReasoning:
						reasoning, ok := ai.AsMessagePart[ai.ReasoningPart](part)
						if !ok {
							continue
						}
						if !sendReasoningData {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "sending reasoning content is disabled for this model",
							})
							continue
						}
						reasoningMetadata := getReasoningMetadata(part.Options())
						if reasoningMetadata == nil {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
							continue
						}

						if reasoningMetadata.Signature != "" {
							anthropicContent = append(anthropicContent, anthropic.NewThinkingBlock(reasoningMetadata.Signature, reasoning.Text))
						} else if reasoningMetadata.RedactedData != "" {
							anthropicContent = append(anthropicContent, anthropic.NewRedactedThinkingBlock(reasoningMetadata.RedactedData))
						} else {
							warnings = append(warnings, ai.CallWarning{
								Type:    "other",
								Message: "unsupported reasoning metadata",
							})
						}
					case ai.ContentTypeToolCall:
						toolCall, ok := ai.AsMessagePart[ai.ToolCallPart](part)
						if !ok {
							continue
						}
						if toolCall.ProviderExecuted {
							// TODO: implement provider executed call
							continue
						}

						var inputMap map[string]any
						if err := json.Unmarshal([]byte(toolCall.Input), &inputMap); err != nil {
							continue
						}
						toolUseBlock := anthropic.NewToolUseBlock(toolCall.ToolCallID, inputMap, toolCall.ToolName)
						if cacheControl != nil {
							toolUseBlock.OfToolUse.CacheControl = anthropic.NewCacheControlEphemeralParam()
						}
						anthropicContent = append(anthropicContent, toolUseBlock)
					case ai.ContentTypeToolResult:
						// TODO: implement provider executed tool result
					}
				}
			}
			messages = append(messages, anthropic.NewAssistantMessage(anthropicContent...))
		}
	}
	return systemBlocks, messages, warnings
}

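// handleError converts Anthropic SDK errors into ai API call errors,
// preserving the status code, response headers, and request/response dumps.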
func (a languageModel) handleError(err error) error {
	var apiErr *anthropic.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Error(),
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

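// mapFinishReason maps Anthropic stop reasons onto ai finish reasons.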
func mapFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "end_turn", "pause_turn", "stop_sequence":
		return ai.FinishReasonStop
	case "max_tokens":
		return ai.FinishReasonLength
	case "tool_use":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

// Generate implements ai.LanguageModel.
func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := a.client.Messages.New(ctx, *params)
	if err != nil {
		return nil, a.handleError(err)
	}

	var content []ai.Content
	for _, block := range response.Content {
		switch block.Type {
		case "text":
			text, ok := block.AsAny().(anthropic.TextBlock)
			if !ok {
				continue
			}
			content = append(content, ai.TextContent{
				Text: text.Text,
			})
		case "thinking":
			reasoning, ok := block.AsAny().(anthropic.ThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: reasoning.Thinking,
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						Signature: reasoning.Signature,
					},
				},
			})
		case "redacted_thinking":
			reasoning, ok := block.AsAny().(anthropic.RedactedThinkingBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ReasoningContent{
				Text: "",
				ProviderMetadata: ai.ProviderMetadata{
					Name: &ReasoningOptionMetadata{
						RedactedData: reasoning.Data,
					},
				},
			})
		case "tool_use":
			toolUse, ok := block.AsAny().(anthropic.ToolUseBlock)
			if !ok {
				continue
			}
			content = append(content, ai.ToolCallContent{
				ToolCallID:       toolUse.ID,
				ToolName:         toolUse.Name,
				Input:            string(toolUse.Input),
				ProviderExecuted: false,
			})
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:         response.Usage.InputTokens,
			OutputTokens:        response.Usage.OutputTokens,
			TotalTokens:         response.Usage.InputTokens + response.Usage.OutputTokens,
			CacheCreationTokens: response.Usage.CacheCreationInputTokens,
			CacheReadTokens:     response.Usage.CacheReadInputTokens,
		},
		FinishReason:     mapFinishReason(string(response.StopReason)),
		ProviderMetadata: ai.ProviderMetadata{},
		Warnings:         warnings,
	}, nil
}

// Stream implements ai.LanguageModel.
func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := a.prepareParams(call)
	if err != nil {
		return nil, err
	}

	stream := a.client.Messages.NewStreaming(ctx, *params)
	acc := anthropic.Message{}
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}

		for stream.Next() {
			chunk := stream.Current()
			// Accumulate chunks so completed blocks and final usage can be
			// read back below; accumulation failures are tolerated rather
			// than aborting the stream.
			_ = acc.Accumulate(chunk)
			switch chunk.Type {
			case "content_block_start":
				contentBlockType := chunk.ContentBlock.Type
				switch contentBlockType {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "redacted_thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningStart,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								RedactedData: chunk.ContentBlock.Data,
							},
						},
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputStart,
						ID:            chunk.ContentBlock.ID,
						ToolCallName:  chunk.ContentBlock.Name,
						ToolCallInput: "",
					}) {
						return
					}
				}
			case "content_block_stop":
				if int(chunk.Index) >= len(acc.Content) {
					continue
				}
				contentBlock := acc.Content[int(chunk.Index)]
				switch contentBlock.Type {
				case "text":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeTextEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "thinking":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningEnd,
						ID:   fmt.Sprintf("%d", chunk.Index),
					}) {
						return
					}
				case "tool_use":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeToolInputEnd,
						ID:   contentBlock.ID,
					}) {
						return
					}
					// The accumulated block now holds the complete input, so
					// emit the full tool call as well.
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolCall,
						ID:            contentBlock.ID,
						ToolCallName:  contentBlock.Name,
						ToolCallInput: string(contentBlock.Input),
					}) {
						return
					}
				}
			case "content_block_delta":
				switch chunk.Delta.Type {
				case "text_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Text,
					}) {
						return
					}
				case "thinking_delta":
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeReasoningDelta,
						ID:    fmt.Sprintf("%d", chunk.Index),
						Delta: chunk.Delta.Thinking,
					}) {
						return
					}
				case "signature_delta":
					if !yield(ai.StreamPart{
						Type: ai.StreamPartTypeReasoningDelta,
						ID:   fmt.Sprintf("%d", chunk.Index),
						ProviderMetadata: ai.ProviderMetadata{
							Name: &ReasoningOptionMetadata{
								Signature: chunk.Delta.Signature,
							},
						},
					}) {
						return
					}
				case "input_json_delta":
					if int(chunk.Index) >= len(acc.Content) {
						continue
					}
					contentBlock := acc.Content[int(chunk.Index)]
					if !yield(ai.StreamPart{
						Type:          ai.StreamPartTypeToolInputDelta,
						ID:            contentBlock.ID,
						ToolCallInput: chunk.Delta.PartialJSON,
					}) {
						return
					}
				}
			case "message_stop":
				// The finish part is emitted after the loop from the
				// accumulated message.
			}
		}

		if err := stream.Err(); err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: a.handleError(err),
			})
			return
		}
		yield(ai.StreamPart{
			Type:         ai.StreamPartTypeFinish,
			ID:           acc.ID,
			FinishReason: mapFinishReason(string(acc.StopReason)),
			Usage: ai.Usage{
				InputTokens:         acc.Usage.InputTokens,
				OutputTokens:        acc.Usage.OutputTokens,
				TotalTokens:         acc.Usage.InputTokens + acc.Usage.OutputTokens,
				CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
				CacheReadTokens:     acc.Usage.CacheReadInputTokens,
			},
			ProviderMetadata: ai.ProviderMetadata{},
		})
	}, nil
}