-------------------------------------------------------------------------------
-- network-wait
-- Copyright 2022 Michael B. Gale (github@michael-gale.co.uk)
-------------------------------------------------------------------------------

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | This module exports variants of the functions from "Network.Wait"
-- specialised for PostgreSQL servers. In addition to checking whether a
-- connection can be established, the functions in this module also check
-- whether the PostgreSQL server is ready to accept commands.
module Network.Wait.PostgreSQL (
    waitPostgreSql,
    waitPostgreSqlVerbose,
    waitPostgreSqlVerboseFormat,
    waitPostgreSqlWith
) where

-------------------------------------------------------------------------------

import Control.Monad
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Retry

import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.Internal

import Network.Wait

-------------------------------------------------------------------------------

-- | `waitPostgreSql` @retryPolicy connectInfo@ is a variant of
-- `waitPostgreSqlWith` which does not install any additional handlers.
--
-- The arguments are passed on unchanged: @retryPolicy@ decides whether (and
-- how often) a failed attempt is retried and @connectInfo@ identifies the
-- PostgreSQL server. The result is whatever `waitPostgreSqlWith` returns
-- for an empty list of extra handlers.
waitPostgreSql
    :: (MonadIO m, MonadMask m)
    => RetryPolicyM m -> ConnectInfo -> m Connection
waitPostgreSql = waitPostgreSqlWith []

-- | `waitPostgreSqlVerbose` @outputHandler retryPolicy connectInfo@ is a variant
-- of `waitPostgreSqlVerboseFormat` which catches all exceptions derived from
-- `SomeException` and formats retry attempt information using `defaultLogMsg`
-- before passing the resulting `String` to @outputHandler@.
waitPostgreSqlVerbose
    :: (MonadIO m, MonadMask m)
    => (String -> m ()) -> RetryPolicyM m -> ConnectInfo -> m Connection
waitPostgreSqlVerbose out =
    -- The type application pins the exception type of the installed handler
    -- to `SomeException`; it cannot be inferred from @out@, which only ever
    -- sees the already formatted message.
    waitPostgreSqlVerboseFormat @SomeException $
    \b ex st -> out $ defaultLogMsg b ex st

-- | `waitPostgreSqlVerboseFormat` @outputHandler retryPolicy connectInfo@ is a
-- variant of `waitPostgreSqlWith` which installs an extra handler based on
-- `logRetries` which passes status information for each retry attempt
-- to @outputHandler@.
--
-- The exception type @e@ selects which exceptions the extra handler applies
-- to; it is the first type variable so that it can be chosen with a type
-- application.
waitPostgreSqlVerboseFormat
    :: forall e m . (MonadIO m, MonadMask m, Exception e)
    => (Bool -> e -> RetryStatus -> m ())
    -> RetryPolicyM m
    -> ConnectInfo
    -> m Connection
waitPostgreSqlVerboseFormat out = waitPostgreSqlWith [h]
    where
        -- The predicate given to `logRetries` ignores the exception and
        -- always answers `True`, so this handler never vetoes a retry; it
        -- only reports the attempt through @out@.
        h :: RetryStatus -> Handler m Bool
        h = logRetries (const $ pure True) out

-- | `waitPostgreSqlWith` @extraHandlers retryPolicy connectInfo@ will attempt
-- to connect to the PostgreSQL server using @connectInfo@ and check that the
-- server is ready to accept commands. If this check fails, @retryPolicy@ is
-- used to determine whether (and how often) this function should attempt to
-- retry establishing the connection. By default, this function will retry
-- after all exceptions (except for those given by `skipAsyncExceptions`).
-- This behaviour may be customised with @extraHandlers@ which are installed
-- after `skipAsyncExceptions`, but before the default exception handler. The
-- @extraHandlers@ may also be used to report retry attempts to e.g. the
-- standard output or a logger.
--
-- The readiness check runs @SELECT 1;@ and expects exactly one row holding
-- the single value @1@; any other result is raised as a `SqlError`.
--
-- On success the connection that passed the check is returned while it is
-- still open, so the caller is responsible for closing it. If an attempt
-- fails, the connection opened for that attempt is closed before the
-- exception propagates.
waitPostgreSqlWith
    :: (MonadIO m, MonadMask m)
    => [RetryStatus -> Handler m Bool] -> RetryPolicyM m -> ConnectInfo
    -> m Connection
waitPostgreSqlWith hs policy info =
    recoveringWith hs policy $
    liftIO $
    -- `bracketOnError` rather than `bracket`: the connection must only be
    -- released when the check throws. `bracket` would also close it on the
    -- success path and hand a dead connection back to the caller.
    bracketOnError (connect info) close $ \con -> do
        rs <- query_ @[Int] con "SELECT 1;"
        unless (rs == [[1]]) $ throwM $
            fatalError "Unexpected result for SELECT 1."
        pure con

-------------------------------------------------------------------------------