diff --git a/lib/postgrex/replication_connection.ex b/lib/postgrex/replication_connection.ex index e9edf768..fb812157 100644 --- a/lib/postgrex/replication_connection.ex +++ b/lib/postgrex/replication_connection.ex @@ -141,12 +141,13 @@ defmodule Postgrex.ReplicationConnection do `GenServer`. Read more about them in the `GenServer` docs. """ - use Connection require Logger import Bitwise alias Postgrex.Protocol + @behaviour :gen_statem + @doc false defstruct protocol: nil, state: nil, @@ -156,7 +157,7 @@ defmodule Postgrex.ReplicationConnection do ## PUBLIC API ## - @type server :: GenServer.server() + @type server :: :gen_statem.server() @type state :: term @type ack :: iodata @type query :: iodata @@ -240,7 +241,7 @@ defmodule Postgrex.ReplicationConnection do been replied to should eventually do so. One simple approach is to reply to any pending commands on `c:handle_disconnect/1`. """ - @callback handle_call(term, GenServer.from(), state) :: + @callback handle_call(term, :gen_statem.from(), state) :: {:noreply, state} | {:noreply, ack, state} | {:query, query, state} @@ -250,7 +251,7 @@ defmodule Postgrex.ReplicationConnection do Callback for `:query` outputs. If any callback returns `{:query, iodata, state}`, - then this callback will be immediatelly called with + then this callback will be immediately called with the result of the query. Please note that even though replication connections use the simple query protocol, Postgres currently limits them to single command queries. @@ -274,13 +275,13 @@ defmodule Postgrex.ReplicationConnection do @doc """ Replies to the given `call/3`. """ - defdelegate reply(client, reply), to: GenServer + defdelegate reply(client, reply), to: :gen_statem @doc """ Calls the given replication server. """ def call(server, message, timeout \\ 5000) do - with {__MODULE__, reason} <- GenServer.call(server, message, timeout) do + with {__MODULE__, reason} <- :gen_statem.call(server, message, timeout) do exit({reason, {__MODULE__, :call, [server, message, timeout]}}) end end @@ -339,10 +340,34 @@ defmodule Postgrex.ReplicationConnection do @spec start_link(module(), term(), Keyword.t()) :: {:ok, pid} | {:error, Postgrex.Error.t() | term} def start_link(module, arg, opts) do - {server_opts, opts} = Keyword.split(opts, [:name]) + {name, opts} = Keyword.pop(opts, :name) opts = Keyword.put_new(opts, :sync_connect, true) connection_opts = Postgrex.Utils.default_opts(opts) - Connection.start_link(__MODULE__, {module, arg, connection_opts}, server_opts) + start_args = {module, arg, connection_opts} + + case name do + nil -> + :gen_statem.start_link(__MODULE__, start_args, []) + + atom when is_atom(atom) -> + :gen_statem.start_link({:local, atom}, __MODULE__, start_args, []) + + {:global, _term} = tuple -> + :gen_statem.start_link(tuple, __MODULE__, start_args, []) + + {:via, via_module, _term} = tuple when is_atom(via_module) -> + :gen_statem.start_link(tuple, __MODULE__, start_args, []) + + other -> + raise ArgumentError, """ + expected :name option to be one of the following: + * nil + * atom + * {:global, term} + * {:via, module, term} + Got: #{inspect(other)} + """ + end end @doc """ @@ -409,7 +434,14 @@ defmodule Postgrex.ReplicationConnection do ## CALLBACKS ## + @state :no_state + + @doc false + @impl :gen_statem + def callback_mode, do: :handle_event_function + @doc false + @impl :gen_statem def init({mod, arg, opts}) do case mod.init(arg) do {:ok, mod_state} -> @@ -433,42 +465,44 @@ defmodule Postgrex.ReplicationConnection do put_opts(opts) if opts[:sync_connect] do - case connect(:init, state) do - {:ok, _} = ok -> ok - {:backoff, _, _} = backoff -> backoff - {:stop, reason, _} -> {:stop, reason} + case handle_event(:internal, {:connect, :init}, @state, state) do + {:keep_state, state} -> {:ok, @state, state} + {:keep_state, state, actions} -> {:ok, @state, state, actions} + {:stop, _reason, _state} = stop -> stop end else - {:connect, :init, state} + {:ok, @state, state, {:next_event, :internal, {:connect, :init}}} end end end @doc false - def connect(_, %{state: {mod, mod_state}} = s) do + @impl :gen_statem + def handle_event(type, content, state, s) + + def handle_event({:timeout, :backoff}, nil, @state, s) do + {:keep_state, s, {:next_event, :internal, {:connect, :backoff}}} + end + + def handle_event(:internal, {:connect, _info}, @state, %{state: {mod, mod_state}} = s) do case Protocol.connect(opts()) do {:ok, protocol} -> - s = %{s | protocol: protocol} - - with {:noreply, s} <- maybe_handle(mod, :handle_connect, [mod_state], s) do - {:ok, s} - end + maybe_handle(mod, :handle_connect, [mod_state], %{s | protocol: protocol}) {:error, reason} -> if s.auto_reconnect do - {:backoff, s.reconnect_backoff, s} + {:keep_state, s, {{:timeout, :backoff}, s.reconnect_backoff, nil}} else {:stop, reason, s} end end end - def handle_call(msg, from, %{state: {mod, mod_state}} = s) do + def handle_event({:call, from}, msg, @state, %{state: {mod, mod_state}} = s) do handle(mod, :handle_call, [msg, from, mod_state], from, s) end - @doc false - def handle_info(msg, %{protocol: protocol, streaming: streaming} = s) do + def handle_event(:info, msg, @state, %{protocol: protocol, streaming: streaming} = s) do case Protocol.handle_copy_recv(msg, streaming, protocol) do {:ok, copies, protocol} -> handle_data(copies, %{s | protocol: protocol}) @@ -482,16 +516,18 @@ defmodule Postgrex.ReplicationConnection do end end - defp handle_data([], s), do: {:noreply, s} + ## Helpers + + defp handle_data([], s), do: {:keep_state, s} defp handle_data([:copy_done | copies], %{state: {mod, mod_state}} = s) do - with {:noreply, s} <- handle(mod, :handle_data, [:done, mod_state], nil, s) do + with {:keep_state, s} <- handle(mod, :handle_data, [:done, mod_state], nil, s) do handle_data(copies, %{s | streaming: nil}) end end defp handle_data([copy | copies], %{state: {mod, mod_state}} = s) do - with {:noreply, s} <- handle(mod, :handle_data, [copy, mod_state], nil, s) do + with {:keep_state, s} <- handle(mod, :handle_data, [copy, mod_state], nil, s) do handle_data(copies, s) end end @@ -500,20 +536,20 @@ defmodule Postgrex.ReplicationConnection do if function_exported?(mod, fun, length(args)) do handle(mod, fun, args, nil, s) else - {:noreply, s} + {:keep_state, s} end end defp handle(mod, fun, args, from, %{streaming: streaming} = s) do case apply(mod, fun, args) do {:noreply, mod_state} -> - {:noreply, %{s | state: {mod, mod_state}}} + {:keep_state, %{s | state: {mod, mod_state}}} {:noreply, replies, mod_state} -> s = %{s | state: {mod, mod_state}} case Protocol.handle_copy_send(replies, s.protocol) do - :ok -> {:noreply, s} + :ok -> {:keep_state, s} {error, reason, protocol} -> reconnect_or_stop(error, reason, protocol, s) end @@ -523,7 +559,7 @@ defmodule Postgrex.ReplicationConnection do with {:ok, protocol} <- Protocol.handle_streaming(query, s.protocol), {:ok, protocol} <- Protocol.checkin(protocol) do - {:noreply, %{s | protocol: protocol, streaming: max_messages}} + {:keep_state, %{s | protocol: protocol, streaming: max_messages}} else {error_or_disconnect, reason, protocol} -> reconnect_or_stop(error_or_disconnect, reason, protocol, s) @@ -552,21 +588,24 @@ defmodule Postgrex.ReplicationConnection do defp stream_in_progress(command, mod, mod_state, from, s) do Logger.warning("received #{command} while stream is already in progress") from && reply(from, {__MODULE__, :stream_in_progress}) - {:noreply, %{s | state: {mod, mod_state}}} + {:keep_state, %{s | state: {mod, mod_state}}} end defp reconnect_or_stop(error, reason, protocol, %{auto_reconnect: false} = s) when error in [:error, :disconnect] do %{state: {mod, mod_state}} = s - {:noreply, s} = maybe_handle(mod, :handle_disconnect, [mod_state], %{s | protocol: protocol}) + + {:keep_state, s} = + maybe_handle(mod, :handle_disconnect, [mod_state], %{s | protocol: protocol}) + {:stop, reason, s} end defp reconnect_or_stop(error, _reason, _protocol, %{auto_reconnect: true} = s) when error in [:error, :disconnect] do %{state: {mod, mod_state}} = s - {:noreply, s} = maybe_handle(mod, :handle_disconnect, [mod_state], s) - {:connect, :reconnect, %{s | streaming: nil}} + {:keep_state, s} = maybe_handle(mod, :handle_disconnect, [mod_state], s) + {:keep_state, %{s | streaming: nil}, {:next_event, :internal, {:connect, :reconnect}}} end defp opts(), do: Process.get(__MODULE__) diff --git a/test/replication_connection_test.exs b/test/replication_connection_test.exs index af9271e4..7d0441f3 100644 --- a/test/replication_connection_test.exs +++ b/test/replication_connection_test.exs @@ -288,7 +288,8 @@ defmodule ReplicationTest do end defp disconnect(repl) do - {:gen_tcp, sock} = :sys.get_state(repl).mod_state.protocol.sock + {_, state} = :sys.get_state(repl) + {:gen_tcp, sock} = state.protocol.sock :gen_tcp.shutdown(sock, :read_write) end end