server.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. package ws
import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/gobwas/httphead"
	"github.com/gobwas/pool/pbufio"
)
  14. // Constants used by ConnUpgrader.
  15. const (
  16. DefaultServerReadBufferSize = 4096
  17. DefaultServerWriteBufferSize = 512
  18. )
  19. // Errors used by both client and server when preparing WebSocket handshake.
  20. var (
  21. ErrHandshakeBadProtocol = RejectConnectionError(
  22. RejectionStatus(http.StatusHTTPVersionNotSupported),
  23. RejectionReason("handshake error: bad HTTP protocol version"),
  24. )
  25. ErrHandshakeBadMethod = RejectConnectionError(
  26. RejectionStatus(http.StatusMethodNotAllowed),
  27. RejectionReason("handshake error: bad HTTP request method"),
  28. )
  29. ErrHandshakeBadHost = RejectConnectionError(
  30. RejectionStatus(http.StatusBadRequest),
  31. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)),
  32. )
  33. ErrHandshakeBadUpgrade = RejectConnectionError(
  34. RejectionStatus(http.StatusBadRequest),
  35. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)),
  36. )
  37. ErrHandshakeBadConnection = RejectConnectionError(
  38. RejectionStatus(http.StatusBadRequest),
  39. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)),
  40. )
  41. ErrHandshakeBadSecAccept = RejectConnectionError(
  42. RejectionStatus(http.StatusBadRequest),
  43. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)),
  44. )
  45. ErrHandshakeBadSecKey = RejectConnectionError(
  46. RejectionStatus(http.StatusBadRequest),
  47. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)),
  48. )
  49. ErrHandshakeBadSecVersion = RejectConnectionError(
  50. RejectionStatus(http.StatusBadRequest),
  51. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
  52. )
  53. )
  54. // ErrMalformedResponse is returned by Dialer to indicate that server response
  55. // can not be parsed.
  56. var ErrMalformedResponse = fmt.Errorf("malformed HTTP response")
  57. // ErrMalformedRequest is returned when HTTP request can not be parsed.
  58. var ErrMalformedRequest = RejectConnectionError(
  59. RejectionStatus(http.StatusBadRequest),
  60. RejectionReason("malformed HTTP request"),
  61. )
  62. // ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that
  63. // connection is rejected because given WebSocket version is malformed.
  64. //
  65. // According to RFC6455:
  66. // If this version does not match a version understood by the server, the
  67. // server MUST abort the WebSocket handshake described in this section and
  68. // instead send an appropriate HTTP error code (such as 426 Upgrade Required)
  69. // and a |Sec-WebSocket-Version| header field indicating the version(s) the
  70. // server is capable of understanding.
  71. var ErrHandshakeUpgradeRequired = RejectConnectionError(
  72. RejectionStatus(http.StatusUpgradeRequired),
  73. RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")),
  74. RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
  75. )
  76. // ErrNotHijacker is an error returned when http.ResponseWriter does not
  77. // implement http.Hijacker interface.
  78. var ErrNotHijacker = RejectConnectionError(
  79. RejectionStatus(http.StatusInternalServerError),
  80. RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
  81. )
  82. // DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by
  83. // UpgradeHTTP function.
  84. var DefaultHTTPUpgrader HTTPUpgrader
  85. // UpgradeHTTP is like HTTPUpgrader{}.Upgrade().
  86. func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) {
  87. return DefaultHTTPUpgrader.Upgrade(r, w)
  88. }
  89. // DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade
  90. // function.
  91. var DefaultUpgrader Upgrader
  92. // Upgrade is like Upgrader{}.Upgrade().
  93. func Upgrade(conn io.ReadWriter) (Handshake, error) {
  94. return DefaultUpgrader.Upgrade(conn)
  95. }
  96. // HTTPUpgrader contains options for upgrading connection to websocket from
  97. // net/http Handler arguments.
  98. type HTTPUpgrader struct {
  99. // Timeout is the maximum amount of time an Upgrade() will spent while
  100. // writing handshake response.
  101. //
  102. // The default is no timeout.
  103. Timeout time.Duration
  104. // Header is an optional http.Header mapping that could be used to
  105. // write additional headers to the handshake response.
  106. //
  107. // Note that if present, it will be written in any result of handshake.
  108. Header http.Header
  109. // Protocol is the select function that is used to select subprotocol from
  110. // list requested by client. If this field is set, then the first matched
  111. // protocol is sent to a client as negotiated.
  112. Protocol func(string) bool
  113. // Extension is the select function that is used to select extensions from
  114. // list requested by client. If this field is set, then the all matched
  115. // extensions are sent to a client as negotiated.
  116. //
  117. // Deprecated: use Negotiate instead.
  118. Extension func(httphead.Option) bool
  119. // Negotiate is the callback that is used to negotiate extensions from
  120. // the client's offer. If this field is set, then the returned non-zero
  121. // extensions are sent to the client as accepted extensions in the
  122. // response.
  123. //
  124. // The argument is only valid until the Negotiate callback returns.
  125. //
  126. // If returned error is non-nil then connection is rejected and response is
  127. // sent with appropriate HTTP error code and body set to error message.
  128. //
  129. // RejectConnectionError could be used to get more control on response.
  130. Negotiate func(httphead.Option) (httphead.Option, error)
  131. }
  132. // Upgrade upgrades http connection to the websocket connection.
  133. //
  134. // It hijacks net.Conn from w and returns received net.Conn and
  135. // bufio.ReadWriter. On successful handshake it returns Handshake struct
  136. // describing handshake info.
  137. func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
  138. // Hijack connection first to get the ability to write rejection errors the
  139. // same way as in Upgrader.
  140. conn, rw, err = hijack(w)
  141. if err != nil {
  142. httpError(w, err.Error(), http.StatusInternalServerError)
  143. return conn, rw, hs, err
  144. }
  145. // See https://tools.ietf.org/html/rfc6455#section-4.1
  146. // The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
  147. var nonce string
  148. if r.Method != http.MethodGet {
  149. err = ErrHandshakeBadMethod
  150. } else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
  151. err = ErrHandshakeBadProtocol
  152. } else if r.Host == "" {
  153. err = ErrHandshakeBadHost
  154. } else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
  155. err = ErrHandshakeBadUpgrade
  156. } else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
  157. err = ErrHandshakeBadConnection
  158. } else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
  159. err = ErrHandshakeBadSecKey
  160. } else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
  161. // According to RFC6455:
  162. //
  163. // If this version does not match a version understood by the server,
  164. // the server MUST abort the WebSocket handshake described in this
  165. // section and instead send an appropriate HTTP error code (such as 426
  166. // Upgrade Required) and a |Sec-WebSocket-Version| header field
  167. // indicating the version(s) the server is capable of understanding.
  168. //
  169. // So we branching here cause empty or not present version does not
  170. // meet the ABNF rules of RFC6455:
  171. //
  172. // version = DIGIT | (NZDIGIT DIGIT) |
  173. // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
  174. // ; Limited to 0-255 range, with no leading zeros
  175. //
  176. // That is, if version is really invalid – we sent 426 status, if it
  177. // not present or empty – it is 400.
  178. if v != "" {
  179. err = ErrHandshakeUpgradeRequired
  180. } else {
  181. err = ErrHandshakeBadSecVersion
  182. }
  183. }
  184. if check := u.Protocol; err == nil && check != nil {
  185. ps := r.Header[headerSecProtocolCanonical]
  186. for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
  187. var ok bool
  188. hs.Protocol, ok = strSelectProtocol(ps[i], check)
  189. if !ok {
  190. err = ErrMalformedRequest
  191. }
  192. }
  193. }
  194. if f := u.Negotiate; err == nil && f != nil {
  195. for _, h := range r.Header[headerSecExtensionsCanonical] {
  196. hs.Extensions, err = negotiateExtensions(strToBytes(h), hs.Extensions, f)
  197. if err != nil {
  198. break
  199. }
  200. }
  201. }
  202. // DEPRECATED path.
  203. if check := u.Extension; err == nil && check != nil && u.Negotiate == nil {
  204. xs := r.Header[headerSecExtensionsCanonical]
  205. for i := 0; i < len(xs) && err == nil; i++ {
  206. var ok bool
  207. hs.Extensions, ok = btsSelectExtensions(strToBytes(xs[i]), hs.Extensions, check)
  208. if !ok {
  209. err = ErrMalformedRequest
  210. }
  211. }
  212. }
  213. // Clear deadlines set by server.
  214. conn.SetDeadline(noDeadline)
  215. if t := u.Timeout; t != 0 {
  216. conn.SetWriteDeadline(time.Now().Add(t))
  217. defer conn.SetWriteDeadline(noDeadline)
  218. }
  219. var header handshakeHeader
  220. if h := u.Header; h != nil {
  221. header[0] = HandshakeHeaderHTTP(h)
  222. }
  223. if err == nil {
  224. httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo)
  225. err = rw.Writer.Flush()
  226. } else {
  227. var code int
  228. if rej, ok := err.(*ConnectionRejectedError); ok {
  229. code = rej.code
  230. header[1] = rej.header
  231. }
  232. if code == 0 {
  233. code = http.StatusInternalServerError
  234. }
  235. httpWriteResponseError(rw.Writer, err, code, header.WriteTo)
  236. // Do not store Flush() error to not override already existing one.
  237. _ = rw.Writer.Flush()
  238. }
  239. return conn, rw, hs, err
  240. }
  241. // Upgrader contains options for upgrading connection to websocket.
  242. type Upgrader struct {
  243. // ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
  244. // They used to read and write http data while upgrading to WebSocket.
  245. // Allocated buffers are pooled with sync.Pool to avoid extra allocations.
  246. //
  247. // If a size is zero then default value is used.
  248. //
  249. // Usually it is useful to set read buffer size bigger than write buffer
  250. // size because incoming request could contain long header values, such as
  251. // Cookie. Response, in other way, could be big only if user write multiple
  252. // custom headers. Usually response takes less than 256 bytes.
  253. ReadBufferSize, WriteBufferSize int
  254. // Protocol is a select function that is used to select subprotocol
  255. // from list requested by client. If this field is set, then the first matched
  256. // protocol is sent to a client as negotiated.
  257. //
  258. // The argument is only valid until the callback returns.
  259. Protocol func([]byte) bool
  260. // ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually.
  261. // Note that returned bytes must be valid until Upgrade returns.
  262. // If ProtocolCustom is set, it used instead of Protocol function.
  263. ProtocolCustom func([]byte) (string, bool)
  264. // Extension is a select function that is used to select extensions
  265. // from list requested by client. If this field is set, then the all matched
  266. // extensions are sent to a client as negotiated.
  267. //
  268. // Note that Extension may be called multiple times and implementations
  269. // must track uniqueness of accepted extensions manually.
  270. //
  271. // The argument is only valid until the callback returns.
  272. //
  273. // According to the RFC6455 order of extensions passed by a client is
  274. // significant. That is, returning true from this function means that no
  275. // other extension with the same name should be checked because server
  276. // accepted the most preferable extension right now:
  277. // "Note that the order of extensions is significant. Any interactions between
  278. // multiple extensions MAY be defined in the documents defining the extensions.
  279. // In the absence of such definitions, the interpretation is that the header
  280. // fields listed by the client in its request represent a preference of the
  281. // header fields it wishes to use, with the first options listed being most
  282. // preferable."
  283. //
  284. // Deprecated: use Negotiate instead.
  285. Extension func(httphead.Option) bool
  286. // ExtensionCustom allow user to parse Sec-WebSocket-Extensions header
  287. // manually.
  288. //
  289. // If ExtensionCustom() decides to accept received extension, it must
  290. // append appropriate option to the given slice of httphead.Option.
  291. // It returns results of append() to the given slice and a flag that
  292. // reports whether given header value is wellformed or not.
  293. //
  294. // Note that ExtensionCustom may be called multiple times and
  295. // implementations must track uniqueness of accepted extensions manually.
  296. //
  297. // Note that returned options should be valid until Upgrade returns.
  298. // If ExtensionCustom is set, it used instead of Extension function.
  299. ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
  300. // Negotiate is the callback that is used to negotiate extensions from
  301. // the client's offer. If this field is set, then the returned non-zero
  302. // extensions are sent to the client as accepted extensions in the
  303. // response.
  304. //
  305. // The argument is only valid until the Negotiate callback returns.
  306. //
  307. // If returned error is non-nil then connection is rejected and response is
  308. // sent with appropriate HTTP error code and body set to error message.
  309. //
  310. // RejectConnectionError could be used to get more control on response.
  311. Negotiate func(httphead.Option) (httphead.Option, error)
  312. // Header is an optional HandshakeHeader instance that could be used to
  313. // write additional headers to the handshake response.
  314. //
  315. // It used instead of any key-value mappings to avoid allocations in user
  316. // land.
  317. //
  318. // Note that if present, it will be written in any result of handshake.
  319. Header HandshakeHeader
  320. // OnRequest is a callback that will be called after request line
  321. // successful parsing.
  322. //
  323. // The arguments are only valid until the callback returns.
  324. //
  325. // If returned error is non-nil then connection is rejected and response is
  326. // sent with appropriate HTTP error code and body set to error message.
  327. //
  328. // RejectConnectionError could be used to get more control on response.
  329. OnRequest func(uri []byte) error
  330. // OnHost is a callback that will be called after "Host" header successful
  331. // parsing.
  332. //
  333. // It is separated from OnHeader callback because the Host header must be
  334. // present in each request since HTTP/1.1. Thus Host header is non-optional
  335. // and required for every WebSocket handshake.
  336. //
  337. // The arguments are only valid until the callback returns.
  338. //
  339. // If returned error is non-nil then connection is rejected and response is
  340. // sent with appropriate HTTP error code and body set to error message.
  341. //
  342. // RejectConnectionError could be used to get more control on response.
  343. OnHost func(host []byte) error
  344. // OnHeader is a callback that will be called after successful parsing of
  345. // header, that is not used during WebSocket handshake procedure. That is,
  346. // it will be called with non-websocket headers, which could be relevant
  347. // for application-level logic.
  348. //
  349. // The arguments are only valid until the callback returns.
  350. //
  351. // If returned error is non-nil then connection is rejected and response is
  352. // sent with appropriate HTTP error code and body set to error message.
  353. //
  354. // RejectConnectionError could be used to get more control on response.
  355. OnHeader func(key, value []byte) error
  356. // OnBeforeUpgrade is a callback that will be called before sending
  357. // successful upgrade response.
  358. //
  359. // Setting OnBeforeUpgrade allows user to make final application-level
  360. // checks and decide whether this connection is allowed to successfully
  361. // upgrade to WebSocket.
  362. //
  363. // It must return non-nil either HandshakeHeader or error and never both.
  364. //
  365. // If returned error is non-nil then connection is rejected and response is
  366. // sent with appropriate HTTP error code and body set to error message.
  367. //
  368. // RejectConnectionError could be used to get more control on response.
  369. OnBeforeUpgrade func() (header HandshakeHeader, err error)
  370. }
  371. // Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn
  372. // as connection with incoming HTTP Upgrade request.
  373. //
  374. // It is a caller responsibility to manage i/o timeouts on conn.
  375. //
  376. // Non-nil error means that request for the WebSocket upgrade is invalid or
  377. // malformed and usually connection should be closed.
  378. // Even when error is non-nil Upgrade will write appropriate response into
  379. // connection in compliance with RFC.
  380. func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
  381. // headerSeen constants helps to report whether or not some header was seen
  382. // during reading request bytes.
  383. const (
  384. headerSeenHost = 1 << iota
  385. headerSeenUpgrade
  386. headerSeenConnection
  387. headerSeenSecVersion
  388. headerSeenSecKey
  389. // headerSeenAll is the value that we expect to receive at the end of
  390. // headers read/parse loop.
  391. headerSeenAll = 0 |
  392. headerSeenHost |
  393. headerSeenUpgrade |
  394. headerSeenConnection |
  395. headerSeenSecVersion |
  396. headerSeenSecKey
  397. )
  398. // Prepare I/O buffers.
  399. // TODO(gobwas): make it configurable.
  400. br := pbufio.GetReader(conn,
  401. nonZero(u.ReadBufferSize, DefaultServerReadBufferSize),
  402. )
  403. bw := pbufio.GetWriter(conn,
  404. nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize),
  405. )
  406. defer func() {
  407. pbufio.PutReader(br)
  408. pbufio.PutWriter(bw)
  409. }()
  410. // Read HTTP request line like "GET /ws HTTP/1.1".
  411. rl, err := readLine(br)
  412. if err != nil {
  413. return hs, err
  414. }
  415. // Parse request line data like HTTP version, uri and method.
  416. req, err := httpParseRequestLine(rl)
  417. if err != nil {
  418. return hs, err
  419. }
  420. // Prepare stack-based handshake header list.
  421. header := handshakeHeader{
  422. 0: u.Header,
  423. }
  424. // Parse and check HTTP request.
  425. // As RFC6455 says:
  426. // The client's opening handshake consists of the following parts. If the
  427. // server, while reading the handshake, finds that the client did not
  428. // send a handshake that matches the description below (note that as per
  429. // [RFC2616], the order of the header fields is not important), including
  430. // but not limited to any violations of the ABNF grammar specified for
  431. // the components of the handshake, the server MUST stop processing the
  432. // client's handshake and return an HTTP response with an appropriate
  433. // error code (such as 400 Bad Request).
  434. //
  435. // See https://tools.ietf.org/html/rfc6455#section-4.2.1
  436. // An HTTP/1.1 or higher GET request, including a "Request-URI".
  437. //
  438. // Even if RFC says "1.1 or higher" without mentioning the part of the
  439. // version, we apply it only to minor part.
  440. switch {
  441. case req.major != 1 || req.minor < 1:
  442. // Abort processing the whole request because we do not even know how
  443. // to actually parse it.
  444. err = ErrHandshakeBadProtocol
  445. case btsToString(req.method) != http.MethodGet:
  446. err = ErrHandshakeBadMethod
  447. default:
  448. if onRequest := u.OnRequest; onRequest != nil {
  449. err = onRequest(req.uri)
  450. }
  451. }
  452. // Start headers read/parse loop.
  453. var (
  454. // headerSeen reports which header was seen by setting corresponding
  455. // bit on.
  456. headerSeen byte
  457. nonce = make([]byte, nonceSize)
  458. )
  459. for err == nil {
  460. line, e := readLine(br)
  461. if e != nil {
  462. return hs, e
  463. }
  464. if len(line) == 0 {
  465. // Blank line, no more lines to read.
  466. break
  467. }
  468. k, v, ok := httpParseHeaderLine(line)
  469. if !ok {
  470. err = ErrMalformedRequest
  471. break
  472. }
  473. switch btsToString(k) {
  474. case headerHostCanonical:
  475. headerSeen |= headerSeenHost
  476. if onHost := u.OnHost; onHost != nil {
  477. err = onHost(v)
  478. }
  479. case headerUpgradeCanonical:
  480. headerSeen |= headerSeenUpgrade
  481. if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
  482. err = ErrHandshakeBadUpgrade
  483. }
  484. case headerConnectionCanonical:
  485. headerSeen |= headerSeenConnection
  486. if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
  487. err = ErrHandshakeBadConnection
  488. }
  489. case headerSecVersionCanonical:
  490. headerSeen |= headerSeenSecVersion
  491. if !bytes.Equal(v, specHeaderValueSecVersion) {
  492. err = ErrHandshakeUpgradeRequired
  493. }
  494. case headerSecKeyCanonical:
  495. headerSeen |= headerSeenSecKey
  496. if len(v) != nonceSize {
  497. err = ErrHandshakeBadSecKey
  498. } else {
  499. copy(nonce, v)
  500. }
  501. case headerSecProtocolCanonical:
  502. if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
  503. var ok bool
  504. if custom != nil {
  505. hs.Protocol, ok = custom(v)
  506. } else {
  507. hs.Protocol, ok = btsSelectProtocol(v, check)
  508. }
  509. if !ok {
  510. err = ErrMalformedRequest
  511. }
  512. }
  513. case headerSecExtensionsCanonical:
  514. if f := u.Negotiate; err == nil && f != nil {
  515. hs.Extensions, err = negotiateExtensions(v, hs.Extensions, f)
  516. }
  517. // DEPRECATED path.
  518. if custom, check := u.ExtensionCustom, u.Extension; u.Negotiate == nil && (custom != nil || check != nil) {
  519. var ok bool
  520. if custom != nil {
  521. hs.Extensions, ok = custom(v, hs.Extensions)
  522. } else {
  523. hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check)
  524. }
  525. if !ok {
  526. err = ErrMalformedRequest
  527. }
  528. }
  529. default:
  530. if onHeader := u.OnHeader; onHeader != nil {
  531. err = onHeader(k, v)
  532. }
  533. }
  534. }
  535. switch {
  536. case err == nil && headerSeen != headerSeenAll:
  537. switch {
  538. case headerSeen&headerSeenHost == 0:
  539. // As RFC2616 says:
  540. // A client MUST include a Host header field in all HTTP/1.1
  541. // request messages. If the requested URI does not include an
  542. // Internet host name for the service being requested, then the
  543. // Host header field MUST be given with an empty value. An
  544. // HTTP/1.1 proxy MUST ensure that any request message it
  545. // forwards does contain an appropriate Host header field that
  546. // identifies the service being requested by the proxy. All
  547. // Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad
  548. // Request) status code to any HTTP/1.1 request message which
  549. // lacks a Host header field.
  550. err = ErrHandshakeBadHost
  551. case headerSeen&headerSeenUpgrade == 0:
  552. err = ErrHandshakeBadUpgrade
  553. case headerSeen&headerSeenConnection == 0:
  554. err = ErrHandshakeBadConnection
  555. case headerSeen&headerSeenSecVersion == 0:
  556. // In case of empty or not present version we do not send 426 status,
  557. // because it does not meet the ABNF rules of RFC6455:
  558. //
  559. // version = DIGIT | (NZDIGIT DIGIT) |
  560. // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
  561. // ; Limited to 0-255 range, with no leading zeros
  562. //
  563. // That is, if version is really invalid – we sent 426 status as above, if it
  564. // not present – it is 400.
  565. err = ErrHandshakeBadSecVersion
  566. case headerSeen&headerSeenSecKey == 0:
  567. err = ErrHandshakeBadSecKey
  568. default:
  569. panic("unknown headers state")
  570. }
  571. case err == nil && u.OnBeforeUpgrade != nil:
  572. header[1], err = u.OnBeforeUpgrade()
  573. }
  574. if err != nil {
  575. var code int
  576. if rej, ok := err.(*ConnectionRejectedError); ok {
  577. code = rej.code
  578. header[1] = rej.header
  579. }
  580. if code == 0 {
  581. code = http.StatusInternalServerError
  582. }
  583. httpWriteResponseError(bw, err, code, header.WriteTo)
  584. // Do not store Flush() error to not override already existing one.
  585. _ = bw.Flush()
  586. return hs, err
  587. }
  588. httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
  589. err = bw.Flush()
  590. return hs, err
  591. }
  592. type handshakeHeader [2]HandshakeHeader
  593. func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) {
  594. for i := 0; i < len(hs) && err == nil; i++ {
  595. if h := hs[i]; h != nil {
  596. var m int64
  597. m, err = h.WriteTo(w)
  598. n += m
  599. }
  600. }
  601. return n, err
  602. }