ssestream.go

  1// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
  2
  3package ssestream
  4
  5import (
  6	"bufio"
  7	"bytes"
  8	"encoding/json"
  9	"fmt"
 10	"io"
 11	"net/http"
 12	"strings"
 13)
 14
 15type Decoder interface {
 16	Event() Event
 17	Next() bool
 18	Close() error
 19	Err() error
 20}
 21
 22func NewDecoder(res *http.Response) Decoder {
 23	if res == nil || res.Body == nil {
 24		return nil
 25	}
 26
 27	var decoder Decoder
 28	contentType := res.Header.Get("content-type")
 29	if t, ok := decoderTypes[contentType]; ok {
 30		decoder = t(res.Body)
 31	} else {
 32		scn := bufio.NewScanner(res.Body)
 33		scn.Buffer(nil, bufio.MaxScanTokenSize<<4)
 34		decoder = &eventStreamDecoder{rc: res.Body, scn: scn}
 35	}
 36	return decoder
 37}
 38
 39var decoderTypes = map[string](func(io.ReadCloser) Decoder){}
 40
 41func RegisterDecoder(contentType string, decoder func(io.ReadCloser) Decoder) {
 42	decoderTypes[strings.ToLower(contentType)] = decoder
 43}
 44
// Event is a single event produced by a Decoder.
type Event struct {
	// Type is the event name. For the text/event-stream decoder it is the
	// value of the most recent "event:" field, or empty if none was sent.
	Type string
	// Data is the raw event payload. Stream decodes it as JSON for the
	// event types it recognizes.
	Data []byte
}
 49
// eventStreamDecoder is the default Decoder, used by NewDecoder for any
// content type without a registered decoder. It parses the
// text/event-stream format line by line.
type eventStreamDecoder struct {
	evt Event          // event produced by the most recent successful Next
	rc  io.ReadCloser  // response body; closed by Close
	scn *bufio.Scanner // line scanner over rc
	err error          // sticky error; once set, Next returns false
}
 57
 58func (s *eventStreamDecoder) Next() bool {
 59	if s.err != nil {
 60		return false
 61	}
 62
 63	event := ""
 64	data := bytes.NewBuffer(nil)
 65
 66	for s.scn.Scan() {
 67		txt := s.scn.Bytes()
 68
 69		// Dispatch event on an empty line
 70		if len(txt) == 0 {
 71			s.evt = Event{
 72				Type: event,
 73				Data: data.Bytes(),
 74			}
 75			return true
 76		}
 77
 78		// Split a string like "event: bar" into name="event" and value=" bar".
 79		name, value, _ := bytes.Cut(txt, []byte(":"))
 80
 81		// Consume an optional space after the colon if it exists.
 82		if len(value) > 0 && value[0] == ' ' {
 83			value = value[1:]
 84		}
 85
 86		switch string(name) {
 87		case "":
 88			// An empty line in the for ": something" is a comment and should be ignored.
 89			continue
 90		case "event":
 91			event = string(value)
 92		case "data":
 93			_, s.err = data.Write(value)
 94			if s.err != nil {
 95				break
 96			}
 97			_, s.err = data.WriteRune('\n')
 98			if s.err != nil {
 99				break
100			}
101		}
102	}
103
104	if s.scn.Err() != nil {
105		s.err = s.scn.Err()
106	}
107
108	return false
109}
110
111func (s *eventStreamDecoder) Event() Event {
112	return s.evt
113}
114
115func (s *eventStreamDecoder) Close() error {
116	return s.rc.Close()
117}
118
119func (s *eventStreamDecoder) Err() error {
120	return s.err
121}
122
// Stream decodes the events produced by a Decoder into values of type T.
// Create one with NewStream and iterate it with Next and Current.
type Stream[T any] struct {
	decoder Decoder // source of raw events; may be nil
	cur     T       // value decoded by the most recent successful Next
	err     error   // error from NewStream or decoding; once set, Next returns false
}
128
129func NewStream[T any](decoder Decoder, err error) *Stream[T] {
130	return &Stream[T]{
131		decoder: decoder,
132		err:     err,
133	}
134}
135
136// Next returns false if the stream has ended or an error occurred.
137// Call Stream.Current() to get the current value.
138// Call Stream.Err() to get the error.
139//
140//		for stream.Next() {
141//			data := stream.Current()
142//		}
143//
144//	 	if stream.Err() != nil {
145//			...
146//	 	}
147func (s *Stream[T]) Next() bool {
148	if s.err != nil {
149		return false
150	}
151
152	for s.decoder.Next() {
153		switch s.decoder.Event().Type {
154		case "completion":
155			var nxt T
156			s.err = json.Unmarshal(s.decoder.Event().Data, &nxt)
157			if s.err != nil {
158				return false
159			}
160			s.cur = nxt
161			return true
162		case "message_start", "message_delta", "message_stop", "content_block_start", "content_block_delta", "content_block_stop":
163			var nxt T
164			s.err = json.Unmarshal(s.decoder.Event().Data, &nxt)
165			if s.err != nil {
166				return false
167			}
168			s.cur = nxt
169			return true
170		case "ping":
171			continue
172		case "error":
173			s.err = fmt.Errorf("received error while streaming: %s", string(s.decoder.Event().Data))
174			return false
175		}
176	}
177
178	// decoder.Next() may be false because of an error
179	s.err = s.decoder.Err()
180
181	return false
182}
183
184func (s *Stream[T]) Current() T {
185	return s.cur
186}
187
188func (s *Stream[T]) Err() error {
189	return s.err
190}
191
192func (s *Stream[T]) Close() error {
193	if s.decoder == nil {
194		// already closed
195		return nil
196	}
197	return s.decoder.Close()
198}