{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}

module Main where

import GHC.Prim
import GHC.Types

-- | Lifted wrapper around the primitive 'PromptTag#', giving prompt tags an
-- ordinary data type that can appear in regular function signatures.
data PromptTag a where
  PromptTag :: PromptTag# a -> PromptTag a

-- | Create a new prompt tag by calling the 'newPromptTag#' primop and boxing
-- the result in 'PromptTag'.
--
-- The state token produced by the primop (@s'@) is the one handed back to the
-- caller, so the allocation stays properly sequenced with the rest of the
-- surrounding 'IO' computation.
newPromptTag :: IO (PromptTag a)
newPromptTag = IO (\s -> case newPromptTag# s of
  (# s', tag #) -> (# s', PromptTag tag #))

-- | Run an 'IO' action under the 'prompt#' primop for the given tag.
--
-- Both arguments are unwrapped to their primitive representations, passed to
-- 'prompt#', and the result is wrapped back up as an 'IO' action.
prompt :: PromptTag a -> IO a -> IO a
prompt promptTag action =
  case promptTag of
    PromptTag tag# -> case action of
      IO run -> IO (\s -> prompt# tag# run s)

-- | 'IO'-level wrapper around the 'control0#' primop.
--
-- The primitive continuation supplied by 'control0#' works on raw state-token
-- functions; it is re-wrapped here as a function from @IO b@ to @IO a@ before
-- being passed to the caller's body, and the body's resulting 'IO' action is
-- unwrapped again for the primop.
control0 :: PromptTag a -> ((IO b -> IO a) -> IO a) -> IO b
control0 (PromptTag tag#) body = IO (control0# tag# withContinuation)
  where
    -- Receives the primitive continuation and runs the body with its
    -- IO-wrapped form.
    withContinuation k# =
      case body (\(IO arg) -> IO (k# arg)) of
        IO run -> run

-- | Synonym for 'prompt', provided under the name used alongside 'shift'.
reset :: PromptTag a -> IO a -> IO a
reset tag action = prompt tag action

-- | @shift@ expressed through 'control0' and 'reset'.
--
-- The body is executed under @reset tag@, and the continuation it receives
-- wraps every resumption in @reset tag@ as well.
shift :: PromptTag a -> ((IO b -> IO a) -> IO a) -> IO b
shift tag body = control0 tag $ \k ->
  let delimited = reset tag . k
  in reset tag (body delimited)

-- | A prompt tag paired with the handler function for operations of type @f@.
--
-- The answer type @a@ of the prompt does not appear in @HandlerTag f@, so it
-- is hidden from code that only holds the tag.
data HandlerTag f where
  -- | Pair a prompt tag with a function that receives an operation @f b@
  -- together with a function @b -> IO a@, and produces the @IO a@ result.
  HandlerTag :: PromptTag a
             -> (forall b. f b -> (b -> IO a) -> IO a)
             -> HandlerTag f

-- | Perform an operation against the handler stored in a 'HandlerTag'.
--
-- The continuation is captured with 'control0' and the stored handler is
-- invoked with the operation and a resumption function. Resuming with a value
-- feeds @pure value@ into the captured continuation, run under @prompt tag@.
send :: HandlerTag f -> f b -> IO b
send (HandlerTag tag handler) op =
  control0 tag (\k -> handler op (\result -> prompt tag (k (pure result))))

-- | Run a computation with a handler installed.
--
-- A new prompt tag is allocated and packaged with the handler function into a
-- 'HandlerTag'. The computation is given that tag and executed under a
-- 'prompt' for it.
handle :: (HandlerTag f -> IO a)
       -> (forall b. f b -> (b -> IO a) -> IO a)
       -> IO a
handle body handler =
  newPromptTag >>= \tag -> prompt tag (body (HandlerTag tag handler))

-- | Operation type handled by 'handleNonDet' and sent by 'amb'.
data NonDet a where
  -- | The single operation; its result type is 'Bool'.
  Choice :: NonDet Bool

-- | Run a computation that may send 'Choice', collecting results in a list.
--
-- A plain result is wrapped in a singleton list. When 'Choice' is sent, the
-- rest of the computation is resumed first with 'True' and then with 'False',
-- and the two result lists are concatenated in that order.
handleNonDet :: (HandlerTag NonDet -> IO a) -> IO [a]
handleNonDet body = handle (\tag -> (: []) <$> body tag) $ \Choice resume -> do
  whenTrue  <- resume True
  whenFalse <- resume False
  pure (whenTrue ++ whenFalse)

-- | Choose between two values by sending 'Choice' to the given handler.
--
-- Yields the first value when the handler answers 'True' and the second when
-- it answers 'False'.
amb :: HandlerTag NonDet -> a -> a -> IO a
amb tag a b = do
  takeFirst <- send tag Choice
  if takeFirst then pure a else pure b

-- | Two nested 'handleNonDet' handlers: the number is chosen through the
-- inner handler's tag and the character through the outer handler's tag, and
-- the pair of both choices is the result of the inner computation.
example :: IO [[(Integer, Char)]]
example =
  handleNonDet $ \outer ->
    handleNonDet $ \inner ->
      amb inner 1 2 >>= \x ->
        amb outer 'a' 'b' >>= \y ->
          pure (x, y)

-- | Run 'example' and print its result.
main :: IO ()
main = do
  results <- example
  print results