0
Fork 0
mirror of https://github.com/ninenines/cowboy.git synced 2025-07-14 12:20:24 +00:00

Don't discard data following a Websocket upgrade request

While the protocol does not allow sending data before
receiving a successful Websocket upgrade response, we
do not want to discard that data if it does come in.
This commit is contained in:
Loïc Hoguin 2019-10-05 13:04:21 +02:00
parent 618c001291
commit c50d6aa09c
No known key found for this signature in database
GPG key ID: 8A9DF795F6FED764
4 changed files with 81 additions and 60 deletions

View file

@ -111,6 +111,7 @@
transport :: module(), transport :: module(),
proxy_header :: undefined | ranch_proxy_header:proxy_info(), proxy_header :: undefined | ranch_proxy_header:proxy_info(),
opts = #{} :: cowboy:opts(), opts = #{} :: cowboy:opts(),
buffer = <<>> :: binary(),
%% Some options may be overriden for the current stream. %% Some options may be overriden for the current stream.
overriden_opts = #{} :: cowboy:opts(), overriden_opts = #{} :: cowboy:opts(),
@ -175,7 +176,7 @@ init(Parent, Ref, Socket, Transport, ProxyHeader, Opts) ->
parent=Parent, ref=Ref, socket=Socket, parent=Parent, ref=Ref, socket=Socket,
transport=Transport, proxy_header=ProxyHeader, opts=Opts, transport=Transport, proxy_header=ProxyHeader, opts=Opts,
peer=Peer, sock=Sock, cert=Cert, peer=Peer, sock=Sock, cert=Cert,
last_streamid=LastStreamID}), <<>>); last_streamid=LastStreamID}));
{{error, Reason}, _, _} -> {{error, Reason}, _, _} ->
terminate(undefined, {socket_error, Reason, terminate(undefined, {socket_error, Reason,
'A socket error occurred when retrieving the peer name.'}); 'A socket error occurred when retrieving the peer name.'});
@ -187,22 +188,22 @@ init(Parent, Ref, Socket, Transport, ProxyHeader, Opts) ->
'A socket error occurred when retrieving the client TLS certificate.'}) 'A socket error occurred when retrieving the client TLS certificate.'})
end. end.
before_loop(State=#state{socket=Socket, transport=Transport}, Buffer) -> before_loop(State=#state{socket=Socket, transport=Transport}) ->
%% @todo disable this when we get to the body, until the stream asks for it? %% @todo disable this when we get to the body, until the stream asks for it?
%% Perhaps have a threshold for how much we're willing to read before waiting. %% Perhaps have a threshold for how much we're willing to read before waiting.
Transport:setopts(Socket, [{active, once}]), Transport:setopts(Socket, [{active, once}]),
loop(State, Buffer). loop(State).
loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts, loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts,
timer=TimerRef, children=Children, in_streamid=InStreamID, buffer=Buffer, timer=TimerRef, children=Children, in_streamid=InStreamID,
last_streamid=LastStreamID, streams=Streams}, Buffer) -> last_streamid=LastStreamID, streams=Streams}) ->
Messages = Transport:messages(), Messages = Transport:messages(),
InactivityTimeout = maps:get(inactivity_timeout, Opts, 300000), InactivityTimeout = maps:get(inactivity_timeout, Opts, 300000),
receive receive
%% Discard data coming in after the last request %% Discard data coming in after the last request
%% we want to process was received fully. %% we want to process was received fully.
{OK, Socket, _} when OK =:= element(1, Messages), InStreamID > LastStreamID -> {OK, Socket, _} when OK =:= element(1, Messages), InStreamID > LastStreamID ->
before_loop(State, Buffer); before_loop(State);
%% Socket messages. %% Socket messages.
{OK, Socket, Data} when OK =:= element(1, Messages) -> {OK, Socket, Data} when OK =:= element(1, Messages) ->
%% Only reset the timeout if it is idle_timeout (active streams). %% Only reset the timeout if it is idle_timeout (active streams).
@ -218,30 +219,30 @@ loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts,
%% Timeouts. %% Timeouts.
{timeout, Ref, {shutdown, Pid}} -> {timeout, Ref, {shutdown, Pid}} ->
cowboy_children:shutdown_timeout(Children, Ref, Pid), cowboy_children:shutdown_timeout(Children, Ref, Pid),
loop(State, Buffer); loop(State);
{timeout, TimerRef, Reason} -> {timeout, TimerRef, Reason} ->
timeout(State, Reason); timeout(State, Reason);
{timeout, _, _} -> {timeout, _, _} ->
loop(State, Buffer); loop(State);
%% System messages. %% System messages.
{'EXIT', Parent, Reason} -> {'EXIT', Parent, Reason} ->
terminate(State, {stop, {exit, Reason}, 'Parent process terminated.'}); terminate(State, {stop, {exit, Reason}, 'Parent process terminated.'});
{system, From, Request} -> {system, From, Request} ->
sys:handle_system_msg(Request, From, Parent, ?MODULE, [], {State, Buffer}); sys:handle_system_msg(Request, From, Parent, ?MODULE, [], State);
%% Messages pertaining to a stream. %% Messages pertaining to a stream.
{{Pid, StreamID}, Msg} when Pid =:= self() -> {{Pid, StreamID}, Msg} when Pid =:= self() ->
loop(info(State, StreamID, Msg), Buffer); loop(info(State, StreamID, Msg));
%% Exit signal from children. %% Exit signal from children.
Msg = {'EXIT', Pid, _} -> Msg = {'EXIT', Pid, _} ->
loop(down(State, Pid, Msg), Buffer); loop(down(State, Pid, Msg));
%% Calls from supervisor module. %% Calls from supervisor module.
{'$gen_call', From, Call} -> {'$gen_call', From, Call} ->
cowboy_children:handle_supervisor_call(Call, From, Children, ?MODULE), cowboy_children:handle_supervisor_call(Call, From, Children, ?MODULE),
loop(State, Buffer); loop(State);
%% Unknown messages. %% Unknown messages.
Msg -> Msg ->
cowboy:log(warning, "Received stray message ~p.~n", [Msg], Opts), cowboy:log(warning, "Received stray message ~p.~n", [Msg], Opts),
loop(State, Buffer) loop(State)
after InactivityTimeout -> after InactivityTimeout ->
terminate(State, {internal_error, timeout, 'No message or data received before timeout.'}) terminate(State, {internal_error, timeout, 'No message or data received before timeout.'})
end. end.
@ -293,12 +294,12 @@ timeout(State, idle_timeout) ->
'Connection idle longer than configuration allows.'}). 'Connection idle longer than configuration allows.'}).
parse(<<>>, State) -> parse(<<>>, State) ->
before_loop(State, <<>>); before_loop(State#state{buffer= <<>>});
%% Do not process requests that come in after the last request %% Do not process requests that come in after the last request
%% and discard the buffer if any to save memory. %% and discard the buffer if any to save memory.
parse(_, State=#state{in_streamid=InStreamID, in_state=#ps_request_line{}, parse(_, State=#state{in_streamid=InStreamID, in_state=#ps_request_line{},
last_streamid=LastStreamID}) when InStreamID > LastStreamID -> last_streamid=LastStreamID}) when InStreamID > LastStreamID ->
before_loop(State, <<>>); before_loop(State#state{buffer= <<>>});
parse(Buffer, State=#state{in_state=#ps_request_line{empty_lines=EmptyLines}}) -> parse(Buffer, State=#state{in_state=#ps_request_line{empty_lines=EmptyLines}}) ->
after_parse(parse_request(Buffer, State, EmptyLines)); after_parse(parse_request(Buffer, State, EmptyLines));
parse(Buffer, State=#state{in_state=PS=#ps_header{headers=Headers, name=undefined}}) -> parse(Buffer, State=#state{in_state=PS=#ps_header{headers=Headers, name=undefined}}) ->
@ -317,7 +318,7 @@ parse(Buffer, State=#state{in_state=#ps_body{}}) ->
after_parse({request, Req=#{streamid := StreamID, method := Method, after_parse({request, Req=#{streamid := StreamID, method := Method,
headers := Headers, version := Version}, headers := Headers, version := Version},
State0=#state{opts=Opts, streams=Streams0}, Buffer}) -> State0=#state{opts=Opts, buffer=Buffer, streams=Streams0}}) ->
try cowboy_stream:init(StreamID, Req, Opts) of try cowboy_stream:init(StreamID, Req, Opts) of
{Commands, StreamState} -> {Commands, StreamState} ->
TE = maps:get(<<"te">>, Headers, undefined), TE = maps:get(<<"te">>, Headers, undefined),
@ -339,8 +340,8 @@ after_parse({request, Req=#{streamid := StreamID, method := Method,
end; end;
%% Streams are sequential so the body is always about the last stream created %% Streams are sequential so the body is always about the last stream created
%% unless that stream has terminated. %% unless that stream has terminated.
after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts, after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts, buffer=Buffer,
streams=Streams0=[Stream=#stream{id=StreamID, state=StreamState0}|_]}, Buffer}) -> streams=Streams0=[Stream=#stream{id=StreamID, state=StreamState0}|_]}}) ->
try cowboy_stream:data(StreamID, IsFin, Data, StreamState0) of try cowboy_stream:data(StreamID, IsFin, Data, StreamState0) of
{Commands, StreamState} -> {Commands, StreamState} ->
Streams = lists:keyreplace(StreamID, #stream.id, Streams0, Streams = lists:keyreplace(StreamID, #stream.id, Streams0,
@ -355,17 +356,17 @@ after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts,
end; end;
%% No corresponding stream. We must skip the body of the previous request %% No corresponding stream. We must skip the body of the previous request
%% in order to process the next one. %% in order to process the next one.
after_parse({data, _, _, _, State, Buffer}) -> after_parse({data, _, _, _, State}) ->
before_loop(State, Buffer); before_loop(State);
after_parse({more, State, Buffer}) -> after_parse({more, State}) ->
before_loop(State, Buffer). before_loop(State).
%% Request-line. %% Request-line.
-spec parse_request(Buffer, State, non_neg_integer()) -spec parse_request(Buffer, State, non_neg_integer())
-> {request, cowboy_req:req(), State, Buffer} -> {request, cowboy_req:req(), State}
| {data, cowboy_stream:streamid(), cowboy_stream:fin(), binary(), State, Buffer} | {data, cowboy_stream:streamid(), cowboy_stream:fin(), binary(), State}
| {more, State, Buffer} | {more, State}
when Buffer::binary(), State::#state{}. when Buffer::binary(), State::#state{}.
%% Empty lines must be using \r\n. %% Empty lines must be using \r\n.
parse_request(<< $\n, _/bits >>, State, _) -> parse_request(<< $\n, _/bits >>, State, _) ->
@ -384,7 +385,7 @@ parse_request(Buffer, State=#state{opts=Opts, in_streamid=InStreamID}, EmptyLine
error_terminate(414, State, {connection_error, limit_reached, error_terminate(414, State, {connection_error, limit_reached,
'The request-line length is larger than configuration allows. (RFC7230 3.1.1)'}); 'The request-line length is larger than configuration allows. (RFC7230 3.1.1)'});
nomatch -> nomatch ->
{more, State#state{in_state=#ps_request_line{empty_lines=EmptyLines}}, Buffer}; {more, State#state{buffer=Buffer, in_state=#ps_request_line{empty_lines=EmptyLines}}};
1 when EmptyLines =:= MaxEmptyLines -> 1 when EmptyLines =:= MaxEmptyLines ->
error_terminate(400, State, {connection_error, limit_reached, error_terminate(400, State, {connection_error, limit_reached,
'More empty lines were received than configuration allows. (RFC7230 3.5)'}); 'More empty lines were received than configuration allows. (RFC7230 3.5)'});
@ -527,7 +528,7 @@ before_parse_headers(Rest, State, M, A, P, Q, V) ->
%% We need two or more bytes in the buffer to continue. %% We need two or more bytes in the buffer to continue.
parse_header(Rest, State=#state{in_state=PS}, Headers) when byte_size(Rest) < 2 -> parse_header(Rest, State=#state{in_state=PS}, Headers) when byte_size(Rest) < 2 ->
{more, State#state{in_state=PS#ps_header{headers=Headers}}, Rest}; {more, State#state{buffer=Rest, in_state=PS#ps_header{headers=Headers}}};
parse_header(<< $\r, $\n, Rest/bits >>, S, Headers) -> parse_header(<< $\r, $\n, Rest/bits >>, S, Headers) ->
request(Rest, S, Headers); request(Rest, S, Headers);
parse_header(Buffer, State=#state{opts=Opts, in_state=PS}, Headers) -> parse_header(Buffer, State=#state{opts=Opts, in_state=PS}, Headers) ->
@ -554,7 +555,7 @@ parse_header_colon(Buffer, State=#state{opts=Opts, in_state=PS}, Headers) ->
%% so check if we have an LF and abort with an error if we do. %% so check if we have an LF and abort with an error if we do.
case match_eol(Buffer, 0) of case match_eol(Buffer, 0) of
nomatch -> nomatch ->
{more, State#state{in_state=PS#ps_header{headers=Headers}}, Buffer}; {more, State#state{buffer=Buffer, in_state=PS#ps_header{headers=Headers}}};
_ -> _ ->
error_terminate(400, State#state{in_state=PS#ps_header{headers=Headers}}, error_terminate(400, State#state{in_state=PS#ps_header{headers=Headers}},
{connection_error, protocol_error, {connection_error, protocol_error,
@ -596,7 +597,7 @@ parse_hd_before_value(Buffer, State=#state{opts=Opts, in_state=PS}, H, N) ->
{connection_error, limit_reached, {connection_error, limit_reached,
'A header value is larger than configuration allows. (RFC7230 3.2.5, RFC6585 5)'}); 'A header value is larger than configuration allows. (RFC7230 3.2.5, RFC6585 5)'});
nomatch -> nomatch ->
{more, State#state{in_state=PS#ps_header{headers=H, name=N}}, Buffer}; {more, State#state{buffer=Buffer, in_state=PS#ps_header{headers=H, name=N}}};
_ -> _ ->
parse_hd_value(Buffer, State, H, N, <<>>) parse_hd_value(Buffer, State, H, N, <<>>)
end. end.
@ -766,7 +767,7 @@ request(Buffer, State0=#state{ref=Ref, transport=Transport, peer=Peer, sock=Sock
false -> false ->
State0#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}} State0#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}
end, end,
{request, Req, State, Buffer}; {request, Req, State#state{buffer=Buffer}};
{true, HTTP2Settings} -> {true, HTTP2Settings} ->
%% We save the headers in case the upgrade will fail %% We save the headers in case the upgrade will fail
%% and we need to pass them to cowboy_stream:early_error. %% and we need to pass them to cowboy_stream:early_error.
@ -835,28 +836,28 @@ parse_body(Buffer, State=#state{in_streamid=StreamID, in_state=
try TDecode(Buffer, TState0) of try TDecode(Buffer, TState0) of
more -> more ->
%% @todo Asks for 0 or more bytes. %% @todo Asks for 0 or more bytes.
{more, State, Buffer}; {more, State#state{buffer=Buffer}};
{more, Data, TState} -> {more, Data, TState} ->
%% @todo Asks for 0 or more bytes. %% @todo Asks for 0 or more bytes.
{data, StreamID, nofin, Data, State#state{in_state= {data, StreamID, nofin, Data, State#state{buffer= <<>>,
PS#ps_body{received=Received + byte_size(Data), in_state=PS#ps_body{received=Received + byte_size(Data),
transfer_decode_state=TState}}, <<>>}; transfer_decode_state=TState}}};
{more, Data, _Length, TState} when is_integer(_Length) -> {more, Data, _Length, TState} when is_integer(_Length) ->
%% @todo Asks for Length more bytes. %% @todo Asks for Length more bytes.
{data, StreamID, nofin, Data, State#state{in_state= {data, StreamID, nofin, Data, State#state{buffer= <<>>,
PS#ps_body{received=Received + byte_size(Data), in_state=PS#ps_body{received=Received + byte_size(Data),
transfer_decode_state=TState}}, <<>>}; transfer_decode_state=TState}}};
{more, Data, Rest, TState} -> {more, Data, Rest, TState} ->
%% @todo Asks for 0 or more bytes. %% @todo Asks for 0 or more bytes.
{data, StreamID, nofin, Data, State#state{in_state= {data, StreamID, nofin, Data, State#state{buffer=Rest,
PS#ps_body{received=Received + byte_size(Data), in_state=PS#ps_body{received=Received + byte_size(Data),
transfer_decode_state=TState}}, Rest}; transfer_decode_state=TState}}};
{done, _HasTrailers, Rest} -> {done, _HasTrailers, Rest} ->
{data, StreamID, fin, <<>>, set_timeout( {data, StreamID, fin, <<>>, set_timeout(
State#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}), Rest}; State#state{buffer=Rest, in_streamid=StreamID + 1, in_state=#ps_request_line{}})};
{done, Data, _HasTrailers, Rest} -> {done, Data, _HasTrailers, Rest} ->
{data, StreamID, fin, Data, set_timeout( {data, StreamID, fin, Data, set_timeout(
State#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}), Rest} State#state{buffer=Rest, in_streamid=StreamID + 1, in_state=#ps_request_line{}})}
catch _:_ -> catch _:_ ->
Reason = {connection_error, protocol_error, Reason = {connection_error, protocol_error,
'Failure to decode the content. (RFC7230 4)'}, 'Failure to decode the content. (RFC7230 4)'},
@ -1094,7 +1095,7 @@ commands(State=#state{socket=Socket, transport=Transport, streams=Streams, out_s
commands(State#state{out_state=done}, StreamID, Tail); commands(State#state{out_state=done}, StreamID, Tail);
%% Protocol takeover. %% Protocol takeover.
commands(State0=#state{ref=Ref, parent=Parent, socket=Socket, transport=Transport, commands(State0=#state{ref=Ref, parent=Parent, socket=Socket, transport=Transport,
out_state=OutState, opts=Opts, children=Children}, StreamID, out_state=OutState, opts=Opts, buffer=Buffer, children=Children}, StreamID,
[{switch_protocol, Headers, Protocol, InitialState}|_Tail]) -> [{switch_protocol, Headers, Protocol, InitialState}|_Tail]) ->
%% @todo This should be the last stream running otherwise we need to wait before switching. %% @todo This should be the last stream running otherwise we need to wait before switching.
%% @todo If there's streams opened after this one, fail instead of 101. %% @todo If there's streams opened after this one, fail instead of 101.
@ -1117,10 +1118,7 @@ commands(State0=#state{ref=Ref, parent=Parent, socket=Socket, transport=Transpor
%% Terminate children processes and flush any remaining messages from the mailbox. %% Terminate children processes and flush any remaining messages from the mailbox.
cowboy_children:terminate(Children), cowboy_children:terminate(Children),
flush(Parent), flush(Parent),
%% @todo This is no good because commands return a state normally and here it doesn't Protocol:takeover(Parent, Ref, Socket, Transport, Opts, Buffer, InitialState);
%% we need to let this module go entirely. Perhaps it should be handled directly in
%% cowboy_clear/cowboy_tls?
Protocol:takeover(Parent, Ref, Socket, Transport, Opts, <<>>, InitialState);
%% Set options dynamically. %% Set options dynamically.
commands(State0=#state{overriden_opts=Opts}, commands(State0=#state{overriden_opts=Opts},
StreamID, [{set_options, SetOpts}|Tail]) -> StreamID, [{set_options, SetOpts}|Tail]) ->
@ -1446,12 +1444,12 @@ terminate_linger_loop(State=#state{socket=Socket, transport=Transport}, TimerRef
%% System callbacks. %% System callbacks.
-spec system_continue(_, _, {#state{}, binary()}) -> ok. -spec system_continue(_, _, #state{}) -> ok.
system_continue(_, _, {State, Buffer}) -> system_continue(_, _, State) ->
loop(State, Buffer). loop(State).
-spec system_terminate(any(), _, _, {#state{}, binary()}) -> no_return(). -spec system_terminate(any(), _, _, {#state{}, binary()}) -> no_return().
system_terminate(Reason, _, _, {State, _}) -> system_terminate(Reason, _, _, State) ->
terminate(State, {stop, {exit, Reason}, 'sys:terminate/2,3 was called.'}). terminate(State, {stop, {exit, Reason}, 'sys:terminate/2,3 was called.'}).
-spec system_code_change(Misc, _, _, _) -> {ok, Misc} when Misc::{#state{}, binary()}. -spec system_code_change(Misc, _, _, _) -> {ok, Misc} when Misc::{#state{}, binary()}.

View file

@ -291,10 +291,14 @@ takeover(Parent, Ref, Socket, Transport, _Opts, Buffer,
State = loop_timeout(State0#state{parent=Parent, State = loop_timeout(State0#state{parent=Parent,
ref=Ref, socket=Socket, transport=Transport, ref=Ref, socket=Socket, transport=Transport,
key=undefined, messages=Messages}), key=undefined, messages=Messages}),
%% We call parse_header/3 immediately because there might be
%% some data in the buffer that was sent along with the handshake.
%% While it is not allowed by the protocol to send frames immediately,
%% we still want to process that data if any.
case erlang:function_exported(Handler, websocket_init, 1) of case erlang:function_exported(Handler, websocket_init, 1) of
true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer}, true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer},
websocket_init, undefined, fun before_loop/3); websocket_init, undefined, fun parse_header/3);
false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer}) false -> parse_header(State, HandlerState, #ps_header{buffer=Buffer})
end. end.
before_loop(State=#state{active=false}, HandlerState, ParseState) -> before_loop(State=#state{active=false}, HandlerState, ParseState) ->

View file

@ -602,9 +602,8 @@ sys_get_state_h1(Config) ->
{ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []), {ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []),
timer:sleep(100), timer:sleep(100),
Pid = get_remote_pid_tcp(Socket), Pid = get_remote_pid_tcp(Socket),
{State, Buffer} = sys:get_state(Pid), State = sys:get_state(Pid),
state = element(1, State), state = element(1, State),
true = is_binary(Buffer),
ok. ok.
sys_get_state_h2(Config) -> sys_get_state_h2(Config) ->
@ -726,9 +725,8 @@ sys_replace_state_h1(Config) ->
{ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []), {ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []),
timer:sleep(100), timer:sleep(100),
Pid = get_remote_pid_tcp(Socket), Pid = get_remote_pid_tcp(Socket),
{State, Buffer} = sys:replace_state(Pid, fun(S) -> S end), State = sys:replace_state(Pid, fun(S) -> S end),
state = element(1, State), state = element(1, State),
true = is_binary(Buffer),
ok. ok.
sys_replace_state_h2(Config) -> sys_replace_state_h2(Config) ->

View file

@ -304,6 +304,18 @@ do_ws_deflate_opts_z(Path, Config) ->
{error, closed} = gen_tcp:recv(Socket, 0, 6000), {error, closed} = gen_tcp:recv(Socket, 0, 6000),
ok. ok.
ws_first_frame_with_handshake(Config) ->
doc("Client sends the first frame immediately with the handshake. "
"This is invalid according to the protocol but we still want "
"to accept it if the handshake is successful."),
Mask = 16#37fa213d,
MaskedHello = do_mask(<<"Hello">>, Mask, <<>>),
{ok, Socket, _} = do_handshake("/ws_echo", "",
<<1:1, 0:3, 1:4, 1:1, 5:7, Mask:32, MaskedHello/binary>>,
Config),
{ok, <<1:1, 0:3, 1:4, 0:1, 5:7, "Hello">>} = gen_tcp:recv(Socket, 0, 6000),
ok.
ws_init_return_ok(Config) -> ws_init_return_ok(Config) ->
doc("Handler does nothing."), doc("Handler does nothing."),
{ok, Socket, _} = do_handshake("/ws_init?ok", Config), {ok, Socket, _} = do_handshake("/ws_init?ok", Config),
@ -636,9 +648,12 @@ ws_webkit_deflate_single_bytes(Config) ->
%% Internal. %% Internal.
do_handshake(Path, Config) -> do_handshake(Path, Config) ->
do_handshake(Path, "", Config). do_handshake(Path, "", "", Config).
do_handshake(Path, ExtraHeaders, Config) -> do_handshake(Path, ExtraHeaders, Config) ->
do_handshake(Path, ExtraHeaders, "", Config).
do_handshake(Path, ExtraHeaders, ExtraData, Config) ->
{ok, Socket} = gen_tcp:connect("localhost", config(port, Config), {ok, Socket} = gen_tcp:connect("localhost", config(port, Config),
[binary, {active, false}]), [binary, {active, false}]),
ok = gen_tcp:send(Socket, [ ok = gen_tcp:send(Socket, [
@ -650,10 +665,16 @@ do_handshake(Path, ExtraHeaders, Config) ->
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Upgrade: websocket\r\n", "Upgrade: websocket\r\n",
ExtraHeaders, ExtraHeaders,
"\r\n"]), "\r\n",
ExtraData]),
{ok, Handshake} = gen_tcp:recv(Socket, 0, 6000), {ok, Handshake} = gen_tcp:recv(Socket, 0, 6000),
{ok, {http_response, {1, 1}, 101, _}, Rest} = erlang:decode_packet(http, Handshake, []), {ok, {http_response, {1, 1}, 101, _}, Rest} = erlang:decode_packet(http, Handshake, []),
[Headers, <<>>] = do_decode_headers(erlang:decode_packet(httph, Rest, []), []), [Headers, Data] = do_decode_headers(erlang:decode_packet(httph, Rest, []), []),
%% Queue extra data back, if any. We don't want to receive it yet.
case Data of
<<>> -> ok;
_ -> gen_tcp:unrecv(Socket, Data)
end,
{_, "Upgrade"} = lists:keyfind('Connection', 1, Headers), {_, "Upgrade"} = lists:keyfind('Connection', 1, Headers),
{_, "websocket"} = lists:keyfind('Upgrade', 1, Headers), {_, "websocket"} = lists:keyfind('Upgrade', 1, Headers),
{_, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="} = lists:keyfind("sec-websocket-accept", 1, Headers), {_, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="} = lists:keyfind("sec-websocket-accept", 1, Headers),