writer.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. package wsutil
  2. import (
  3. "fmt"
  4. "io"
  5. "github.com/gobwas/pool"
  6. "github.com/gobwas/pool/pbytes"
  7. "github.com/gobwas/ws"
  8. )
  // DefaultWriteBuffer is the size in bytes of the buffer allocated for a Writer
  // when no usable size is requested. It is used by the Writer constructor
  // functions: NewWriter always uses it, and NewWriterBufferSize falls back to it
  // when the requested size is not greater than ws.MinHeaderSize.
  var DefaultWriteBuffer = 4096
  var (
  	// ErrNotEmpty is returned by Writer.WriteThrough() to indicate that the
  	// buffer is not empty and a write-through could not be done. That is, the
  	// caller should call Writer.FlushFragment() (or Writer.Flush()) first to
  	// make the buffer empty.
  	ErrNotEmpty = fmt.Errorf("writer not empty")

  	// ErrControlOverflow is returned by ControlWriter.Write() to indicate that
  	// no more data can be accepted because the write would exceed the
  	// ControlWriter's payload limit (at most ws.MaxControlFramePayloadSize).
  	ErrControlOverflow = fmt.Errorf("control frame payload overflow")
  )
  // Upper bounds of the payload length ranges covered by each of the three
  // frame length encodings (7-bit, 16-bit and 64-bit).
  const (
  	len7  int64 = 125       // 126 and 127 are reserved values
  	len16 int64 = 1<<16 - 1 // largest value of an unsigned 16-bit integer
  	len64 int64 = 1<<63 - 1 // largest value of a signed 64-bit integer
  )
  // ControlWriter is a wrapper around Writer that contains some guards for
  // buffered writes of control frames.
  type ControlWriter struct {
  	// w is the wrapped Writer that buffers and emits the frame.
  	w *Writer
  	// limit is the maximum number of payload bytes Write() accepts before
  	// it reports ErrControlOverflow.
  	limit int
  	// n is a counter of accepted payload bytes.
  	// NOTE(review): nothing in this file ever assigns it a non-zero value.
  	n int
  }
  35. // NewControlWriter contains ControlWriter with Writer inside whose buffer size
  36. // is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize.
  37. func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter {
  38. return &ControlWriter{
  39. w: NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize),
  40. limit: ws.MaxControlFramePayloadSize,
  41. }
  42. }
  43. // NewControlWriterBuffer returns a new ControlWriter with buf as a buffer.
  44. //
  45. // Note that it reserves x bytes of buf for header data, where x could be
  46. // ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most
  47. // (ws.MaxControlFramePayloadSize + x) bytes of buf will be used.
  48. //
  49. // It panics if len(buf) <= ws.MinHeaderSize + x.
  50. func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter {
  51. max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize)
  52. if len(buf) > max {
  53. buf = buf[:max]
  54. }
  55. w := NewWriterBuffer(dest, state, op, buf)
  56. return &ControlWriter{
  57. w: w,
  58. limit: len(w.buf),
  59. }
  60. }
  61. // Write implements io.Writer. It writes to the underlying Writer until it
  62. // returns error or until ControlWriter write limit will be exceeded.
  63. func (c *ControlWriter) Write(p []byte) (n int, err error) {
  64. if c.n+len(p) > c.limit {
  65. return 0, ErrControlOverflow
  66. }
  67. return c.w.Write(p)
  68. }
  69. // Flush flushes all buffered data to the underlying io.Writer.
  70. func (c *ControlWriter) Flush() error {
  71. return c.w.Flush()
  72. }
  // writers is the pool of reusable Writers used by GetWriter() and PutWriter().
  // NOTE(review): the arguments presumably bound the pooled buffer sizes to the
  // range [128, 65536] -- confirm against the gobwas/pool documentation.
  var writers = pool.New(128, 65536)
  74. // GetWriter tries to reuse Writer getting it from the pool.
  75. //
  76. // This function is intended for memory consumption optimizations, because
  77. // NewWriter*() functions make allocations for inner buffer.
  78. //
  79. // Note the it ceils n to the power of two.
  80. //
  81. // If you have your own bytes buffer pool you could use NewWriterBuffer to use
  82. // pooled bytes in writer.
  83. func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
  84. x, m := writers.Get(n)
  85. if x != nil {
  86. w := x.(*Writer)
  87. w.Reset(dest, state, op)
  88. return w
  89. }
  90. // NOTE: we use m instead of n, because m is an attempt to reuse w of such
  91. // size in the future.
  92. return NewWriterBufferSize(dest, state, op, m)
  93. }
  94. // PutWriter puts w for future reuse by GetWriter().
  95. func PutWriter(w *Writer) {
  96. w.Reset(nil, 0, 0)
  97. writers.Put(w, w.Size())
  98. }
  // Writer contains logic of buffering output data into WebSocket fragments.
  // It is much the same as bufio.Writer, except that it works with
  // WebSocket frames, not the raw data.
  //
  // Writer writes frames with the specified OpCode.
  // It uses ws.State to decide whether the output frames must be masked.
  //
  // Note that it does not check control frame size or other RFC rules.
  // That is, it must be used with special care to write control frames without
  // violation of the RFC. You could use ControlWriter that wraps Writer and
  // contains some guards for writing control frames.
  //
  // If an error occurs writing to a Writer, no more data will be accepted and
  // all subsequent writes will return the error.
  //
  // After all data has been written, the client should call the Flush() method
  // to guarantee all data has been forwarded to the underlying io.Writer.
  type Writer struct {
  	// dest specifies a destination of buffer flushes.
  	dest io.Writer

  	// op specifies the WebSocket operation code used in flushed frames.
  	op ws.OpCode

  	// state specifies the state of the Writer.
  	state ws.State

  	// extensions is a list of negotiated extensions for writer Dest.
  	// It is used to meet the specs and set appropriate bits in fragment
  	// header RSV segment.
  	extensions []SendExtension

  	// noFlush reports whether buffer must grow instead of being flushed.
  	noFlush bool

  	// Raw representation of the buffer, including reserved header bytes.
  	raw []byte

  	// Writeable part of buffer, without reserved header bytes.
  	// Resetting this to nil will not result in reallocation if raw is not nil.
  	// And vice versa: if buf is not nil, then Writer is assumed as ready and
  	// initialized.
  	buf []byte

  	// Buffered bytes counter.
  	n int

  	// dirty is set by Write(), WriteThrough() and (on EOF) ReadFrom(), and
  	// cleared by Flush(). It lets Flush() emit a final frame even when the
  	// buffer holds no bytes.
  	dirty bool

  	// fseq counts the fragments already sent for the current message. While
  	// it is greater than zero, frames are sent with ws.OpContinuation.
  	fseq int

  	// err is the sticky error of the last failed write to dest; while it is
  	// non-nil, writes and flushes return it without doing any work.
  	err error
  }
  142. // NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size.
  143. func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer {
  144. return NewWriterBufferSize(dest, state, op, 0)
  145. }
  146. // NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize.
  147. // That is, output frames payload length could be up to n, except the case when
  148. // Write() is called on empty Writer with len(p) > n.
  149. //
  150. // If n <= 0 then the default buffer size is used as Writer's buffer size.
  151. func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
  152. if n > 0 {
  153. n += headerSize(state, n)
  154. }
  155. return NewWriterBufferSize(dest, state, op, n)
  156. }
  157. // NewWriterBufferSize returns a new Writer whose buffer size is equal to n.
  158. // If n <= ws.MinHeaderSize then the default buffer size is used.
  159. //
  160. // Note that Writer will reserve x bytes for header data, where x is in range
  161. // [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer
  162. // will not have payload length equal to n, except the case when Write() is
  163. // called on empty Writer with len(p) > n.
  164. func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
  165. if n <= ws.MinHeaderSize {
  166. n = DefaultWriteBuffer
  167. }
  168. return NewWriterBuffer(dest, state, op, make([]byte, n))
  169. }
  170. // NewWriterBuffer returns a new Writer with buf as a buffer.
  171. //
  172. // Note that it reserves x bytes of buf for header data, where x is in range
  173. // [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size).
  174. //
  175. // You could use ws.HeaderSize() to calculate number of bytes needed to store
  176. // header data.
  177. //
  178. // It panics if len(buf) is too small to fit header and payload data.
  179. func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer {
  180. w := &Writer{
  181. dest: dest,
  182. state: state,
  183. op: op,
  184. raw: buf,
  185. }
  186. w.initBuf()
  187. return w
  188. }
  189. func (w *Writer) initBuf() {
  190. offset := reserve(w.state, len(w.raw))
  191. if len(w.raw) <= offset {
  192. panic("wsutil: writer buffer is too small")
  193. }
  194. w.buf = w.raw[offset:]
  195. }
  196. // Reset resets Writer as it was created by New() methods.
  197. // Note that Reset does reset extensions and other options was set after
  198. // Writer initialization.
  199. func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) {
  200. w.dest = dest
  201. w.state = state
  202. w.op = op
  203. w.initBuf()
  204. w.n = 0
  205. w.dirty = false
  206. w.fseq = 0
  207. w.extensions = w.extensions[:0]
  208. w.noFlush = false
  209. }
  210. // ResetOp is an quick version of Reset().
  211. // ResetOp does reset unwritten fragments and does not reset results of
  212. // SetExtensions() or DisableFlush() methods.
  213. func (w *Writer) ResetOp(op ws.OpCode) {
  214. w.op = op
  215. w.n = 0
  216. w.dirty = false
  217. w.fseq = 0
  218. }
  // SetExtensions sets xs as the extensions to be used during writes, replacing
  // any previously set list. Each extension's SetBits() is applied to the header
  // of every frame the Writer sends.
  func (w *Writer) SetExtensions(xs ...SendExtension) {
  	w.extensions = xs
  }
  // DisableFlush denies Writer to write intermediate fragments: when the buffer
  // is full, Write() and ReadFrom() grow it instead of flushing. The setting is
  // cleared by Reset().
  func (w *Writer) DisableFlush() {
  	w.noFlush = true
  }
  // Size returns the size of the underlying buffer in bytes (not including
  // the bytes reserved for the WebSocket header).
  func (w *Writer) Size() int {
  	return len(w.buf)
  }
  232. // Available returns how many bytes are unused in the buffer.
  233. func (w *Writer) Available() int {
  234. return len(w.buf) - w.n
  235. }
  // Buffered returns the number of payload bytes that have been written into the
  // current buffer and not yet flushed.
  func (w *Writer) Buffered() int {
  	return w.n
  }
  241. // Write implements io.Writer.
  242. //
  243. // Note that even if the Writer was created to have N-sized buffer, Write()
  244. // with payload of N bytes will not fit into that buffer. Writer reserves some
  245. // space to fit WebSocket header data.
  246. func (w *Writer) Write(p []byte) (n int, err error) {
  247. // Even empty p may make a sense.
  248. w.dirty = true
  249. var nn int
  250. for len(p) > w.Available() && w.err == nil {
  251. if w.noFlush {
  252. w.Grow(len(p))
  253. continue
  254. }
  255. if w.Buffered() == 0 {
  256. // Large write, empty buffer. Write directly from p to avoid copy.
  257. // Trade off here is that we make additional Write() to underlying
  258. // io.Writer when writing frame header.
  259. //
  260. // On large buffers additional write is better than copying.
  261. nn, _ = w.WriteThrough(p)
  262. } else {
  263. nn = copy(w.buf[w.n:], p)
  264. w.n += nn
  265. w.FlushFragment()
  266. }
  267. n += nn
  268. p = p[nn:]
  269. }
  270. if w.err != nil {
  271. return n, w.err
  272. }
  273. nn = copy(w.buf[w.n:], p)
  274. w.n += nn
  275. n += nn
  276. // Even if w.Available() == 0 we will not flush buffer preventively because
  277. // this could bring unwanted fragmentation. That is, user could create
  278. // buffer with size that fits exactly all further Write() call, and then
  279. // call Flush(), excepting that single and not fragmented frame will be
  280. // sent. With preemptive flush this case will produce two frames – last one
  281. // will be empty and just to set fin = true.
  282. return n, w.err
  283. }
  // ceilPowerOfTwo returns the smallest power of two that is greater than or
  // equal to n. Values of n below 1 yield 1.
  //
  // The result is unspecified when no such power of two is representable as int.
  func ceilPowerOfTwo(n int) int {
  	if n <= 1 {
  		return 1
  	}
  	// Decrement first so that an exact power of two maps to itself instead of
  	// to the next one; then smear the highest set bit into all lower bits.
  	n--
  	n |= n >> 1
  	n |= n >> 2
  	n |= n >> 4
  	n |= n >> 8
  	n |= n >> 16
  	n |= n >> 32
  	return n + 1
  }
  294. // Grow grows Writer's internal buffer capacity to guarantee space for another
  295. // n bytes of _payload_ -- that is, frame header is not included in n.
  296. func (w *Writer) Grow(n int) {
  297. // NOTE: we must respect the possibility of header reserved bytes grow.
  298. var (
  299. size = len(w.raw)
  300. prevOffset = len(w.raw) - len(w.buf)
  301. nextOffset = len(w.raw) - len(w.buf)
  302. buffered = w.Buffered()
  303. )
  304. for cap := size - nextOffset - buffered; cap < n; {
  305. // This loop runs twice only at split cases, when reservation of raw
  306. // buffer space for the header shrinks capacity of new buffer such that
  307. // it still less than n.
  308. //
  309. // Loop is safe here because:
  310. // - (offset + buffered + n) is greater than size, otherwise (cap < n)
  311. // would be false:
  312. // size = offset + buffered + freeSpace (cap)
  313. // size' = offset + buffered + wantSpace (n)
  314. // Since (cap < n) is true in the loop condition, size' is guaranteed
  315. // to be greater => no infinite loop.
  316. size = ceilPowerOfTwo(nextOffset + buffered + n)
  317. nextOffset = reserve(w.state, size)
  318. cap = size - nextOffset - buffered
  319. }
  320. if size < len(w.raw) {
  321. panic("wsutil: buffer grow leads to its reduce")
  322. }
  323. if size == len(w.raw) {
  324. return
  325. }
  326. p := make([]byte, size)
  327. copy(p[nextOffset-prevOffset:], w.raw[:prevOffset+buffered])
  328. w.raw = p
  329. w.buf = w.raw[nextOffset:]
  330. }
  331. // WriteThrough writes data bypassing the buffer.
  332. // Note that Writer's buffer must be empty before calling WriteThrough().
  333. func (w *Writer) WriteThrough(p []byte) (n int, err error) {
  334. if w.err != nil {
  335. return 0, w.err
  336. }
  337. if w.Buffered() != 0 {
  338. return 0, ErrNotEmpty
  339. }
  340. var frame ws.Frame
  341. frame.Header = ws.Header{
  342. OpCode: w.opCode(),
  343. Fin: false,
  344. Length: int64(len(p)),
  345. }
  346. for _, x := range w.extensions {
  347. frame.Header, err = x.SetBits(frame.Header)
  348. if err != nil {
  349. return 0, err
  350. }
  351. }
  352. if w.state.ClientSide() {
  353. // Should copy bytes to prevent corruption of caller data.
  354. payload := pbytes.GetLen(len(p))
  355. defer pbytes.Put(payload)
  356. copy(payload, p)
  357. frame.Payload = payload
  358. frame = ws.MaskFrameInPlace(frame)
  359. } else {
  360. frame.Payload = p
  361. }
  362. w.err = ws.WriteFrame(w.dest, frame)
  363. if w.err == nil {
  364. n = len(p)
  365. }
  366. w.dirty = true
  367. w.fseq++
  368. return n, w.err
  369. }
  370. // ReadFrom implements io.ReaderFrom.
  371. func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
  372. var nn int
  373. for err == nil {
  374. if w.Available() == 0 {
  375. if w.noFlush {
  376. w.Grow(w.Buffered()) // Twice bigger.
  377. } else {
  378. err = w.FlushFragment()
  379. }
  380. continue
  381. }
  382. // We copy the behavior of bufio.Writer here.
  383. // Also, from the docs on io.ReaderFrom:
  384. // ReadFrom reads data from r until EOF or error.
  385. //
  386. // See https://codereview.appspot.com/76400048/#ps1
  387. const maxEmptyReads = 100
  388. var nr int
  389. for nr < maxEmptyReads {
  390. nn, err = src.Read(w.buf[w.n:])
  391. if nn != 0 || err != nil {
  392. break
  393. }
  394. nr++
  395. }
  396. if nr == maxEmptyReads {
  397. return n, io.ErrNoProgress
  398. }
  399. w.n += nn
  400. n += int64(nn)
  401. }
  402. if err == io.EOF {
  403. // NOTE: Do not flush preemptively.
  404. // See the Write() sources for more info.
  405. err = nil
  406. w.dirty = true
  407. }
  408. return n, err
  409. }
  410. // Flush writes any buffered data to the underlying io.Writer.
  411. // It sends the frame with "fin" flag set to true.
  412. //
  413. // If no Write() or ReadFrom() was made, then Flush() does nothing.
  414. func (w *Writer) Flush() error {
  415. if (!w.dirty && w.Buffered() == 0) || w.err != nil {
  416. return w.err
  417. }
  418. w.err = w.flushFragment(true)
  419. w.n = 0
  420. w.dirty = false
  421. w.fseq = 0
  422. return w.err
  423. }
  424. // FlushFragment writes any buffered data to the underlying io.Writer.
  425. // It sends the frame with "fin" flag set to false.
  426. func (w *Writer) FlushFragment() error {
  427. if w.Buffered() == 0 || w.err != nil {
  428. return w.err
  429. }
  430. w.err = w.flushFragment(false)
  431. w.n = 0
  432. w.fseq++
  433. return w.err
  434. }
  435. func (w *Writer) flushFragment(fin bool) (err error) {
  436. var (
  437. payload = w.buf[:w.n]
  438. header = ws.Header{
  439. OpCode: w.opCode(),
  440. Fin: fin,
  441. Length: int64(len(payload)),
  442. }
  443. )
  444. for _, ext := range w.extensions {
  445. header, err = ext.SetBits(header)
  446. if err != nil {
  447. return err
  448. }
  449. }
  450. if w.state.ClientSide() {
  451. header.Masked = true
  452. header.Mask = ws.NewMask()
  453. ws.Cipher(payload, header.Mask, 0)
  454. }
  455. // Write header to the header segment of the raw buffer.
  456. var (
  457. offset = len(w.raw) - len(w.buf)
  458. skip = offset - ws.HeaderSize(header)
  459. )
  460. buf := bytesWriter{
  461. buf: w.raw[skip:offset],
  462. }
  463. if err := ws.WriteHeader(&buf, header); err != nil {
  464. // Must never be reached.
  465. panic("dump header error: " + err.Error())
  466. }
  467. _, err = w.dest.Write(w.raw[skip : offset+w.n])
  468. return err
  469. }
  470. func (w *Writer) opCode() ws.OpCode {
  471. if w.fseq > 0 {
  472. return ws.OpContinuation
  473. }
  474. return w.op
  475. }
  // errNoSpace is returned by bytesWriter.Write when p does not fit entirely
  // into the remaining part of the buffer.
  var errNoSpace = fmt.Errorf("not enough buffer space")

  // bytesWriter is an io.Writer that fills a fixed-size byte slice.
  type bytesWriter struct {
  	buf []byte // destination bytes
  	pos int    // number of bytes of buf already written
  }

  // Write copies as much of p as fits after the bytes already written. It
  // returns the number of bytes copied, and errNoSpace if p was truncated.
  func (w *bytesWriter) Write(p []byte) (int, error) {
  	written := copy(w.buf[w.pos:], p)
  	w.pos += written
  	if written < len(p) {
  		return written, errNoSpace
  	}
  	return written, nil
  }
  489. func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
  490. var frame ws.Frame
  491. if s.ClientSide() {
  492. // Should copy bytes to prevent corruption of caller data.
  493. payload := pbytes.GetLen(len(p))
  494. defer pbytes.Put(payload)
  495. copy(payload, p)
  496. frame = ws.NewFrame(op, fin, payload)
  497. frame = ws.MaskFrameInPlace(frame)
  498. } else {
  499. frame = ws.NewFrame(op, fin, p)
  500. }
  501. return ws.WriteFrame(w, frame)
  502. }
  503. // reserve calculates number of bytes need to be reserved for frame header.
  504. //
  505. // Note that instead of ws.HeaderSize() it does calculation based on the buffer
  506. // size, not the payload size.
  507. func reserve(state ws.State, n int) (offset int) {
  508. var mask int
  509. if state.ClientSide() {
  510. mask = 4
  511. }
  512. switch {
  513. case n <= int(len7)+mask+2:
  514. return mask + 2
  515. case n <= int(len16)+mask+4:
  516. return mask + 4
  517. default:
  518. return mask + 10
  519. }
  520. }
  521. // headerSize returns number of bytes needed to encode header of a frame with
  522. // given state and length.
  523. func headerSize(s ws.State, n int) int {
  524. return ws.HeaderSize(ws.Header{
  525. Length: int64(n),
  526. Masked: s.ClientSide(),
  527. })
  528. }