0
Fork 0
mirror of https://github.com/ninenines/cowboy.git synced 2025-07-14 12:20:24 +00:00

Merge the two separate receive loops in cowboy_websocket

Also rename a bunch of functions to make the code easier to read.
This commit is contained in:
Loïc Hoguin 2018-03-23 16:32:53 +01:00
parent 31092b546c
commit 21c9c66971
No known key found for this signature in database
GPG key ID: 8A9DF795F6FED764
2 changed files with 109 additions and 103 deletions

View file

@ -20,7 +20,7 @@
-export([upgrade/4]). -export([upgrade/4]).
-export([upgrade/5]). -export([upgrade/5]).
-export([takeover/7]). -export([takeover/7]).
-export([handler_loop/3]). -export([loop/3]).
-export([system_continue/3]). -export([system_continue/3]).
-export([system_terminate/4]). -export([system_terminate/4]).
@ -202,53 +202,64 @@ websocket_handshake(State=#state{key=Key},
%% Connection process. %% Connection process.
%% @todo Keep parent and handle system messages. -record(ps_header, {
buffer = <<>> :: binary()
}).
-record(ps_payload, {
type :: cow_ws:frame_type(),
len :: non_neg_integer(),
mask_key :: cow_ws:mask_key(),
rsv :: cow_ws:rsv(),
close_code = undefined :: undefined | cow_ws:close_code(),
unmasked = <<>> :: binary(),
unmasked_len = 0 :: non_neg_integer(),
buffer = <<>> :: binary()
}).
-type parse_state() :: #ps_header{} | #ps_payload{}.
-spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(), -spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(),
{#state{}, any()}) -> ok. {#state{}, any()}) -> no_return().
takeover(Parent, Ref, Socket, Transport, _Opts, Buffer, takeover(Parent, Ref, Socket, Transport, _Opts, Buffer,
{State0=#state{handler=Handler}, HandlerState}) -> {State0=#state{handler=Handler}, HandlerState}) ->
%% @todo We should have an option to disable this behavior. %% @todo We should have an option to disable this behavior.
ranch:remove_connection(Ref), ranch:remove_connection(Ref),
State1 = handler_loop_timeout(State0#state{parent=Parent, State = loop_timeout(State0#state{parent=Parent,
ref=Ref, socket=Socket, transport=Transport}), ref=Ref, socket=Socket, transport=Transport,
State = State1#state{key=undefined, messages=Transport:messages()}, key=undefined, messages=Transport:messages()}),
case erlang:function_exported(Handler, websocket_init, 1) of case erlang:function_exported(Handler, websocket_init, 1) of
true -> handler_call(State, HandlerState, Buffer, websocket_init, undefined, fun handler_before_loop/3); true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer},
false -> handler_before_loop(State, HandlerState, Buffer) websocket_init, undefined, fun before_loop/3);
false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer})
end. end.
-spec handler_before_loop(#state{}, any(), binary()) before_loop(State=#state{socket=Socket, transport=Transport, hibernate=true},
%% @todo Yeah not env. HandlerState, ParseState) ->
-> {ok, cowboy_middleware:env()}.
handler_before_loop(State=#state{
socket=Socket, transport=Transport, hibernate=true},
HandlerState, SoFar) ->
Transport:setopts(Socket, [{active, once}]), Transport:setopts(Socket, [{active, once}]),
proc_lib:hibernate(?MODULE, handler_loop, proc_lib:hibernate(?MODULE, loop,
[State#state{hibernate=false}, HandlerState, SoFar]); [State#state{hibernate=false}, HandlerState, ParseState]);
handler_before_loop(State=#state{socket=Socket, transport=Transport}, before_loop(State=#state{socket=Socket, transport=Transport},
HandlerState, SoFar) -> HandlerState, ParseState) ->
Transport:setopts(Socket, [{active, once}]), Transport:setopts(Socket, [{active, once}]),
handler_loop(State, HandlerState, SoFar). loop(State, HandlerState, ParseState).
-spec handler_loop_timeout(#state{}) -> #state{}. -spec loop_timeout(#state{}) -> #state{}.
handler_loop_timeout(State=#state{timeout=infinity}) -> loop_timeout(State=#state{timeout=infinity}) ->
State#state{timeout_ref=undefined}; State#state{timeout_ref=undefined};
handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) -> loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) ->
_ = case PrevRef of undefined -> ignore; PrevRef -> _ = case PrevRef of undefined -> ignore; PrevRef ->
erlang:cancel_timer(PrevRef) end, erlang:cancel_timer(PrevRef) end,
TRef = erlang:start_timer(Timeout, self(), ?MODULE), TRef = erlang:start_timer(Timeout, self(), ?MODULE),
State#state{timeout_ref=TRef}. State#state{timeout_ref=TRef}.
-spec handler_loop(#state{}, any(), binary()) -spec loop(#state{}, any(), parse_state()) -> no_return().
-> {ok, cowboy_middleware:env()}. loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error},
handler_loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error}, timeout_ref=TRef}, HandlerState, ParseState) ->
timeout_ref=TRef}, HandlerState, SoFar) ->
receive receive
{OK, Socket, Data} -> {OK, Socket, Data} ->
State2 = handler_loop_timeout(State), State2 = loop_timeout(State),
websocket_data(State2, HandlerState, parse(State2, HandlerState, ParseState, Data);
<< SoFar/binary, Data/binary >>);
{Closed, Socket} -> {Closed, Socket} ->
terminate(State, HandlerState, {error, closed}); terminate(State, HandlerState, {error, closed});
{Error, Socket, Reason} -> {Error, Socket, Reason} ->
@ -256,124 +267,121 @@ handler_loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Er
{timeout, TRef, ?MODULE} -> {timeout, TRef, ?MODULE} ->
websocket_close(State, HandlerState, timeout); websocket_close(State, HandlerState, timeout);
{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
handler_loop(State, HandlerState, SoFar); loop(State, HandlerState, ParseState);
%% System messages. %% System messages.
{'EXIT', Parent, Reason} -> {'EXIT', Parent, Reason} ->
%% @todo We should exit gracefully. %% @todo We should exit gracefully.
exit(Reason); exit(Reason);
{system, From, Request} -> {system, From, Request} ->
sys:handle_system_msg(Request, From, Parent, ?MODULE, [], sys:handle_system_msg(Request, From, Parent, ?MODULE, [],
{State, HandlerState, SoFar}); {State, HandlerState, ParseState});
%% Calls from supervisor module. %% Calls from supervisor module.
{'$gen_call', From, Call} -> {'$gen_call', From, Call} ->
cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE), cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE),
handler_loop(State, HandlerState, SoFar); loop(State, HandlerState, ParseState);
Message -> Message ->
handler_call(State, HandlerState, handler_call(State, HandlerState, ParseState,
SoFar, websocket_info, Message, fun handler_before_loop/3) websocket_info, Message, fun before_loop/3)
end. end.
-spec websocket_data(#state{}, any(), binary()) parse(State, HandlerState, PS=#ps_header{buffer=Buffer}, Data) ->
-> {ok, cowboy_middleware:env()}. parse_header(State, HandlerState, PS#ps_header{
websocket_data(State=#state{frag_state=FragState, extensions=Extensions}, HandlerState, Data) -> buffer= <<Buffer/binary, Data/binary>>});
parse(State, HandlerState, PS=#ps_payload{buffer=Buffer}, Data) ->
parse_payload(State, HandlerState, PS#ps_payload{buffer= <<>>},
<<Buffer/binary, Data/binary>>).
parse_header(State=#state{frag_state=FragState, extensions=Extensions}, HandlerState,
ParseState=#ps_header{buffer=Data}) ->
case cow_ws:parse_header(Data, Extensions, FragState) of case cow_ws:parse_header(Data, Extensions, FragState) of
%% All frames sent from the client to the server are masked. %% All frames sent from the client to the server are masked.
{_, _, _, _, undefined, _} -> {_, _, _, _, undefined, _} ->
websocket_close(State, HandlerState, {error, badframe}); websocket_close(State, HandlerState, {error, badframe});
{Type, FragState2, Rsv, Len, MaskKey, Rest} -> {Type, FragState2, Rsv, Len, MaskKey, Rest} ->
websocket_payload(State#state{frag_state=FragState2}, HandlerState, Type, Len, MaskKey, Rsv, undefined, <<>>, 0, Rest); parse_payload(State#state{frag_state=FragState2}, HandlerState,
#ps_payload{type=Type, len=Len, mask_key=MaskKey, rsv=Rsv}, Rest);
more -> more ->
handler_before_loop(State, HandlerState, Data); before_loop(State, HandlerState, ParseState);
error -> error ->
websocket_close(State, HandlerState, {error, badframe}) websocket_close(State, HandlerState, {error, badframe})
end. end.
websocket_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions}, parse_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions},
HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data) -> HandlerState, ParseState=#ps_payload{
case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, Type, Len, FragState, Extensions, Rsv) of type=Type, len=Len, mask_key=MaskKey, rsv=Rsv,
{ok, CloseCode2, Payload, Utf8State, Rest} -> unmasked=Unmasked, unmasked_len=UnmaskedLen}, Data) ->
websocket_dispatch(State#state{utf8_state=Utf8State}, case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen,
HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode2, Rest); Type, Len, FragState, Extensions, Rsv) of
{ok, CloseCode, Payload, Utf8State, Rest} ->
dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState,
ParseState#ps_payload{unmasked= <<Unmasked/binary, Payload/binary>>,
close_code=CloseCode}, Rest);
{ok, Payload, Utf8State, Rest} -> {ok, Payload, Utf8State, Rest} ->
websocket_dispatch(State#state{utf8_state=Utf8State}, dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState,
HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode, Rest); ParseState#ps_payload{unmasked= <<Unmasked/binary, Payload/binary>>},
{more, CloseCode2, Payload, Utf8State} -> Rest);
websocket_payload_loop(State#state{utf8_state=Utf8State}, {more, CloseCode, Payload, Utf8State} ->
HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode2, before_loop(State#state{utf8_state=Utf8State}, HandlerState,
<< Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data)); ParseState#ps_payload{len=Len - byte_size(Data), close_code=CloseCode,
unmasked= <<Unmasked/binary, Payload/binary>>,
unmasked_len=UnmaskedLen + byte_size(Data)});
{more, Payload, Utf8State} -> {more, Payload, Utf8State} ->
websocket_payload_loop(State#state{utf8_state=Utf8State}, before_loop(State#state{utf8_state=Utf8State}, HandlerState,
HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode, ParseState#ps_payload{len=Len - byte_size(Data),
<< Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data)); unmasked= <<Unmasked/binary, Payload/binary>>,
unmasked_len=UnmaskedLen + byte_size(Data)});
Error = {error, _Reason} -> Error = {error, _Reason} ->
websocket_close(State, HandlerState, Error) websocket_close(State, HandlerState, Error)
end. end.
websocket_payload_loop(State=#state{socket=Socket, transport=Transport, dispatch_frame(State=#state{socket=Socket, transport=Transport,
messages={OK, Closed, Error}, timeout_ref=TRef}, frag_state=FragState, frag_buffer=SoFar, extensions=Extensions},
HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) -> HandlerState, #ps_payload{type=Type0, unmasked=Payload0, close_code=CloseCode0},
Transport:setopts(Socket, [{active, once}]), RemainingData) ->
receive
{OK, Socket, Data} ->
State2 = handler_loop_timeout(State),
websocket_payload(State2, HandlerState,
Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data);
{Closed, Socket} ->
terminate(State, HandlerState, {error, closed});
{Error, Socket, Reason} ->
terminate(State, HandlerState, {error, Reason});
{timeout, TRef, ?MODULE} ->
websocket_close(State, HandlerState, timeout);
{timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) ->
websocket_payload_loop(State, HandlerState,
Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen);
Message ->
handler_call(State, HandlerState,
<<>>, websocket_info, Message,
fun (State2, HandlerState2, _) ->
websocket_payload_loop(State2, HandlerState2,
Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen)
end)
end.
websocket_dispatch(State=#state{socket=Socket, transport=Transport, frag_state=FragState, frag_buffer=SoFar, extensions=Extensions},
HandlerState, Type0, Payload0, CloseCode0, RemainingData) ->
case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of
%% @todo Allow receiving fragments. %% @todo Allow receiving fragments.
{fragment, nofin, _, Payload} -> {fragment, nofin, _, Payload} ->
websocket_data(State#state{frag_buffer= << SoFar/binary, Payload/binary >>}, HandlerState, RemainingData); parse_header(State#state{frag_buffer= << SoFar/binary, Payload/binary >>},
HandlerState, #ps_header{buffer=RemainingData});
{fragment, fin, Type, Payload} -> {fragment, fin, Type, Payload} ->
handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState, RemainingData, handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState,
websocket_handle, {Type, << SoFar/binary, Payload/binary >>}, fun websocket_data/3); #ps_header{buffer=RemainingData},
websocket_handle, {Type, << SoFar/binary, Payload/binary >>},
fun parse_header/3);
close -> close ->
websocket_close(State, HandlerState, remote); websocket_close(State, HandlerState, remote);
{close, CloseCode, Payload} -> {close, CloseCode, Payload} ->
websocket_close(State, HandlerState, {remote, CloseCode, Payload}); websocket_close(State, HandlerState, {remote, CloseCode, Payload});
Frame = ping -> Frame = ping ->
Transport:send(Socket, cow_ws:frame(pong, Extensions)), Transport:send(Socket, cow_ws:frame(pong, Extensions)),
handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3); handler_call(State, HandlerState,
#ps_header{buffer=RemainingData},
websocket_handle, Frame, fun parse_header/3);
Frame = {ping, Payload} -> Frame = {ping, Payload} ->
Transport:send(Socket, cow_ws:frame({pong, Payload}, Extensions)), Transport:send(Socket, cow_ws:frame({pong, Payload}, Extensions)),
handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3); handler_call(State, HandlerState,
#ps_header{buffer=RemainingData},
websocket_handle, Frame, fun parse_header/3);
Frame -> Frame ->
handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3) handler_call(State, HandlerState,
#ps_header{buffer=RemainingData},
websocket_handle, Frame, fun parse_header/3)
end. end.
-spec handler_call(#state{}, any(), binary(), atom(), any(), fun()) -> no_return().
handler_call(State=#state{handler=Handler}, HandlerState, handler_call(State=#state{handler=Handler}, HandlerState,
RemainingData, Callback, Message, NextState) -> ParseState, Callback, Message, NextState) ->
try case Callback of try case Callback of
websocket_init -> Handler:websocket_init(HandlerState); websocket_init -> Handler:websocket_init(HandlerState);
_ -> Handler:Callback(Message, HandlerState) _ -> Handler:Callback(Message, HandlerState)
end of end of
{ok, HandlerState2} -> {ok, HandlerState2} ->
NextState(State, HandlerState2, RemainingData); NextState(State, HandlerState2, ParseState);
{ok, HandlerState2, hibernate} -> {ok, HandlerState2, hibernate} ->
NextState(State#state{hibernate=true}, HandlerState2, RemainingData); NextState(State#state{hibernate=true}, HandlerState2, ParseState);
{reply, Payload, HandlerState2} -> {reply, Payload, HandlerState2} ->
case websocket_send(Payload, State) of case websocket_send(Payload, State) of
ok -> ok ->
NextState(State, HandlerState2, RemainingData); NextState(State, HandlerState2, ParseState);
stop -> stop ->
terminate(State, HandlerState2, stop); terminate(State, HandlerState2, stop);
Error = {error, _} -> Error = {error, _} ->
@ -383,7 +391,7 @@ handler_call(State=#state{handler=Handler}, HandlerState,
case websocket_send(Payload, State) of case websocket_send(Payload, State) of
ok -> ok ->
NextState(State#state{hibernate=true}, NextState(State#state{hibernate=true},
HandlerState2, RemainingData); HandlerState2, ParseState);
stop -> stop ->
terminate(State, HandlerState2, stop); terminate(State, HandlerState2, stop);
Error = {error, _} -> Error = {error, _} ->
@ -458,15 +466,16 @@ handler_terminate(#state{handler=Handler, req=Req}, HandlerState, Reason) ->
%% System callbacks. %% System callbacks.
-spec system_continue(_, _, {#state{}, any(), binary()}) -> ok. -spec system_continue(_, _, {#state{}, any(), parse_state()}) -> no_return().
system_continue(_, _, {State, HandlerState, SoFar}) -> system_continue(_, _, {State, HandlerState, ParseState}) ->
handler_loop(State, HandlerState, SoFar). loop(State, HandlerState, ParseState).
-spec system_terminate(any(), _, _, {#state{}, any(), binary()}) -> no_return(). -spec system_terminate(any(), _, _, {#state{}, any(), parse_state()}) -> no_return().
system_terminate(Reason, _, _, {State, HandlerState, _}) -> system_terminate(Reason, _, _, {State, HandlerState, _}) ->
%% @todo We should exit gracefully, if possible. %% @todo We should exit gracefully, if possible.
terminate(State, HandlerState, Reason). terminate(State, HandlerState, Reason).
-spec system_code_change(Misc, _, _, _) -> {ok, Misc} when Misc::{#state{}, any(), binary()}. -spec system_code_change(Misc, _, _, _)
-> {ok, Misc} when Misc::{#state{}, any(), parse_state()}.
system_code_change(Misc, _, _, _) -> system_code_change(Misc, _, _, _) ->
{ok, Misc}. {ok, Misc}.

View file

@ -112,9 +112,6 @@ proc_lib_initial_call_tls(Config) ->
%% so that it doesn't eat up system messages. It should only %% so that it doesn't eat up system messages. It should only
%% flush messages that are specific to cowboy_http. %% flush messages that are specific to cowboy_http.
%% @todo The cowboy_websocket module needs to have the functions
%% handler_loop and websocket_payload_loop merged into one.
bad_system_from_h1(Config) -> bad_system_from_h1(Config) ->
doc("h1: Sending a system message with a bad From value results in a process crash."), doc("h1: Sending a system message with a bad From value results in a process crash."),
{ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), [{active, false}]), {ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), [{active, false}]),