{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.CodeGen.Scan
-- Copyright   : [2014..2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.CodeGen.Scan
  where

-- accelerate
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR                        ( IR )
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache

import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Generate
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )

import Control.Applicative
import Control.Monad
import Data.String                                                  ( fromString )
import Data.Coerce                                                  as Safe
import Prelude                                                      as P


-- Direction in which a scan traverses the innermost dimension. The @scanl@
-- family of generators in this module pass 'L' (indices are stepped upwards by
-- one), and the @scanr@ family pass 'R' (indices are stepped downwards by
-- one).
--
data Direction = L | R

-- Left-to-right exclusive scan in the style of 'Data.List.scanl'. The
-- combination function is required to be associative, which is what permits
-- an efficient parallel implementation.
--
-- > scanl (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [10,10,11,13,16,20,25,31,38,46,55]
--
-- For vectors the sequential, parallel and fill kernels are all generated
-- and combined with '+++'; for any other shape only the sequential and fill
-- kernels are generated.
--
mkScanl
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanl uid aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    -- one-dimensional input: the parallel kernels are applicable as well
    Just Refl -> do
      sequential <- mkScanS L uid aenv combine (Just seed) arr
      parallel   <- mkScanP L uid aenv combine (Just seed) arr
      fill       <- mkScanFill uid aenv seed
      return (sequential +++ (parallel +++ fill))

    -- multidimensional input: segments are scanned sequentially
    Nothing -> do
      sequential <- mkScanS L uid aenv combine (Just seed) arr
      fill       <- mkScanFill uid aenv seed
      return (sequential +++ fill)


-- Left-to-right inclusive scan in the style of 'Data.List.scanl1'. The
-- combination function is required to be associative, which is what permits
-- an efficient parallel implementation. The input array must not be empty.
--
-- > scanl1 (+) (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 10) [0,1,3,6,10,15,21,28,36,45]
--
-- No fill kernel is generated, since there is no seed element to fill with.
--
mkScanl1
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanl1 uid aenv combine arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    -- one-dimensional input: sequential and parallel kernels
    Just Refl -> do
      sequential <- mkScanS L uid aenv combine Nothing arr
      parallel   <- mkScanP L uid aenv combine Nothing arr
      return (sequential +++ parallel)

    -- multidimensional input: sequential kernel only
    Nothing -> mkScanS L uid aenv combine Nothing arr


-- Variant of 'scanl' where the final result is returned in a separate array.
--
-- > scanl' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [10,10,11,13,16,20,25,31,38,46]
--       , Array Z [55]
--       )
--
-- For vectors the sequential, parallel and fill kernels are all generated
-- and combined with '+++'; for any other shape only the sequential and fill
-- kernels are generated.
--
mkScanl'
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScanl' uid aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    -- one-dimensional input: the parallel kernels are applicable as well
    Just Refl -> do
      sequential <- mkScan'S L uid aenv combine seed arr
      parallel   <- mkScan'P L uid aenv combine seed arr
      fill       <- mkScan'Fill uid aenv seed
      return (sequential +++ (parallel +++ fill))

    -- multidimensional input: segments are scanned sequentially
    Nothing -> do
      sequential <- mkScan'S L uid aenv combine seed arr
      fill       <- mkScan'Fill uid aenv seed
      return (sequential +++ fill)


-- Right-to-left exclusive scan in the style of 'Data.List.scanr'. The
-- combination function is required to be associative, which is what permits
-- an efficient parallel implementation.
--
-- > scanr (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 11) [55,55,54,52,49,45,40,34,27,19,10]
--
-- For vectors the sequential, parallel and fill kernels are all generated
-- and combined with '+++'; for any other shape only the sequential and fill
-- kernels are generated.
--
mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanr uid aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    -- one-dimensional input: the parallel kernels are applicable as well
    Just Refl -> do
      sequential <- mkScanS R uid aenv combine (Just seed) arr
      parallel   <- mkScanP R uid aenv combine (Just seed) arr
      fill       <- mkScanFill uid aenv seed
      return (sequential +++ (parallel +++ fill))

    -- multidimensional input: segments are scanned sequentially
    Nothing -> do
      sequential <- mkScanS R uid aenv combine (Just seed) arr
      fill       <- mkScanFill uid aenv seed
      return (sequential +++ fill)


-- Right-to-left inclusive scan in the style of 'Data.List.scanr1'. The
-- combination function is required to be associative, which is what permits
-- an efficient parallel implementation. The input array must not be empty.
--
-- > scanr1 (+) (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> Array (Z :. 10) [45,45,44,42,39,35,30,24,17,9]
--
-- No fill kernel is generated, since there is no seed element to fill with.
--
mkScanr1
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanr1 uid aenv combine arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    -- one-dimensional input: sequential and parallel kernels
    Just Refl -> do
      sequential <- mkScanS R uid aenv combine Nothing arr
      parallel   <- mkScanP R uid aenv combine Nothing arr
      return (sequential +++ parallel)

    -- multidimensional input: sequential kernel only
    Nothing -> mkScanS R uid aenv combine Nothing arr


-- Variant of 'scanr' where the final result is returned in a separate array.
--
-- > scanr' (+) 10 (use $ fromList (Z :. 10) [0..])
-- >
-- > ==> ( Array (Z :. 10) [55,54,52,49,45,40,34,27,19,10]
--       , Array Z [55]
--       )
--
-- For vectors the sequential, parallel and fill kernels are all generated
-- and combined with '+++'; for any other shape only the sequential and fill
-- kernels are generated.
--
mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScanr' uid aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    -- one-dimensional input: the parallel kernels are applicable as well
    Just Refl -> do
      sequential <- mkScan'S R uid aenv combine seed arr
      parallel   <- mkScan'P R uid aenv combine seed arr
      fill       <- mkScan'Fill uid aenv seed
      return (sequential +++ (parallel +++ fill))

    -- multidimensional input: segments are scanned sequentially
    Nothing -> do
      sequential <- mkScan'S R uid aenv combine seed arr
      fill       <- mkScan'Fill uid aenv seed
      return (sequential +++ fill)


-- When the innermost dimension of an exclusive scan is empty there is nothing
-- to combine, so the result consists solely of the seed element. This is a
-- 'mkGenerate' kernel whose element function ignores its index and yields the
-- seed.
--
mkScanFill
    :: (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRExp Native aenv e
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkScanFill uid aenv seed =
  mkGenerate uid aenv $ IRFun1 $ \_ -> seed

-- Fill kernel for the scan' variants. The kernel itself is generated by
-- 'mkScanFill' at the type of the reduction array (@Array sh e@), and the
-- result is then coerced to carry the pair type that scan' returns.
--
mkScan'Fill
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRExp Native aenv e
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScan'Fill uid aenv seed = fmap Safe.coerce fill
  where
    fill :: CodeGen (IROpenAcc Native aenv (Array sh e))
    fill = mkScanFill uid aenv seed


-- Sequential scan: one thread walks an entire innermost dimension from its
-- first element (in scan direction) to its last. Inclusive scans (no seed)
-- take their initial value from the input, so for those the innermost
-- dimension is assumed to contain at least one element.
--
-- The kernel is usable both with a single thread, and in parallel where
-- threads are scheduled over the outer dimensions (segments): the gang range
-- [start,end) enumerates the segments to process.
--
mkScanS
    :: forall aenv sh e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanS dir uid aenv combine mseed IRDelayed{..} =
  makeOpenAcc uid "scanS" (paramGang ++ paramOut ++ paramEnv) $ do

    -- size of the innermost dimension, and that size plus/minus one
    width    <- indexHead <$> delayedExtent
    widthOut <- A.add numType width (lift 1)
    lastOff  <- A.sub numType width (lift 1)

    -- visit every lower-dimensional index (segment) in the gang range
    imapFromTo start end $ \seg -> do

      -- Linear index of the first input element to read: the beginning of the
      -- segment when scanning left-to-right, its final element otherwise.
      segIn  <- A.mul numType width seg
      readIx <- case dir of
                  L -> return segIn
                  R -> A.add numType lastOff segIn

      -- Linear index of the first output element to write. Without a seed the
      -- input and output line up exactly, so the read index is reused. With a
      -- seed each output segment is one element longer than the input.
      writeIx <- case mseed of
                   Nothing -> return readIx
                   Just _  -> do
                     segOut <- A.mul numType widthOut seg
                     case dir of
                       L -> return segOut
                       R -> A.add numType segOut width

      -- Initial value: either the seed (nothing consumed from the input yet),
      -- or the first input element (in which case the read index advances).
      (v0, readIx') <- case mseed of
                         Just seed -> do v <- seed
                                         return (v, readIx)
                         Nothing   -> do v <- app1 delayedLinearIndex readIx
                                         i <- next readIx
                                         return (v, i)

      -- Store the initial value and step past it
      writeArray arrOut writeIx v0
      writeIx' <- next writeIx

      -- One step beyond the last input element of this segment
      stop <- case dir of
                L -> A.add numType readIx width
                R -> A.sub numType readIx width

      let more i = case dir of
                     L -> A.lt singleType i stop
                     R -> A.gt singleType i stop

          -- loop state is (read index, write index, running value)
          body (A.untrip -> (i, j, acc)) = do
            x    <- app1 delayedLinearIndex i
            acc' <- accumulate acc x
            writeArray arrOut j acc'
            i'   <- next i
            j'   <- next j
            return (A.trip i' j' acc')

      _ <- while (more . A.fst3) body (A.trip readIx' writeIx' v0)
      return ()

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh:.Int) e))
    paramEnv                = envParam aenv

    -- step an index by one element in the direction of the scan
    next i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- fold the next input element into the running value, keeping the
    -- argument order of the combination function consistent with the
    -- direction of travel
    accumulate acc x = case dir of
                         L -> app2 combine acc x
                         R -> app2 combine x acc


-- Sequential scan' kernel: one thread walks an entire innermost dimension,
-- writing the exclusive scan into "out" and the final reduction value of each
-- segment into "sum". The gang range [start,end) enumerates the segments to
-- process.
--
mkScan'S
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScan'S dir uid aenv combine seed IRDelayed{..} =
  makeOpenAcc uid "scanS" (paramGang ++ paramOut ++ paramSum ++ paramEnv) $ do

    -- size of the innermost dimension, and the offset of its last element
    width   <- indexHead <$> delayedExtent
    lastOff <- A.sub numType width (lift 1)

    -- visit every lower-dimensional index (segment) in the gang range
    imapFromTo start end $ \seg -> do

      -- Linear index of the first element to process: the beginning of the
      -- segment when scanning left-to-right, its final element otherwise.
      origin <- case dir of
                  L -> A.mul numType seg width
                  R -> do base <- A.mul numType width seg
                          A.add numType base lastOff

      -- the running value starts out as the seed
      v0 <- seed

      -- one step beyond the last element of this segment
      stop <- case dir of
                L -> A.add numType origin width
                R -> A.sub numType origin width

      let more i = case dir of
                     L -> A.lt singleType i stop
                     R -> A.gt singleType i stop

          -- Loop state is (index, running value). The carry-in value (i.e.
          -- the value from the previous iteration) is written to the output
          -- only at the top of the loop body, so nothing at all is written to
          -- "out" when the input is empty.
          body (A.unpair -> (i, acc)) = do
            writeArray arrOut i acc
            x    <- app1 delayedLinearIndex i
            acc' <- accumulate acc x
            i'   <- next i
            return (A.pair i' acc')

      final <- while (more . A.fst) body (A.pair origin v0)

      -- the value left over after the loop is the reduction of the segment
      writeArray arrSum seg (A.snd final)

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh:.Int) e))
    (arrSum, paramSum)      = mutableArray ("sum" :: Name (Array sh e))
    paramEnv                = envParam aenv

    -- step an index by one element in the direction of the scan
    next i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- fold the next input element into the running value, keeping the
    -- argument order of the combination function consistent with the
    -- direction of travel
    accumulate acc x = case dir of
                         L -> app2 combine acc x
                         R -> app2 combine x acc


-- Parallel scan of a vector, assembled from three kernels generated in order
-- (steps 1, 2 and 3) and combined with '+++'.
--
mkScanP
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP dir uid aenv combine mseed arr = do
  step1 <- mkScanP1 dir uid aenv combine mseed arr
  step2 <- mkScanP2 dir uid aenv combine
  step3 <- mkScanP3 dir uid aenv combine mseed
  return (step1 +++ (step2 +++ step3))

-- Parallel scan, step 1.
--
-- Threads scan a stripe of the input into a temporary array, incorporating the
-- initial element and any fused functions on the way. The final reduction
-- result of this chunk is written to a separate array.
--
-- Besides the gang and environment parameters, the generated kernel "scanP1"
-- takes the scalars "ix.stride" and "ix.steps" and the arrays "out" and "tmp".
-- The first gang parameter is used as the index of the chunk to process; chunk
-- @c@ covers input indices @[c * stride, min ((c+1) * stride) len)@. The partial
-- scan is written to "out" and the chunk's reduction to @tmp[c]@.
--
mkScanP1
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP1 dir uid aenv combine mseed IRDelayed{..} =
  let
      -- only the first gang parameter is used: the index of this chunk
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      steps                     = local           scalarType ("ix.steps"  :: Name Int)
      paramSteps                = scalarParameter scalarType ("ix.steps"  :: Name Int)
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      --
      -- step an index by one element in the direction of the scan
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
      -- the chunk that the scan starts from: chunk zero when going
      -- left-to-right, chunk "ix.steps" when going right-to-left
      firstChunk                = case dir of
                                    L -> lift 0
                                    R -> steps
  in
  makeOpenAcc uid "scanP1" (paramGang ++ paramStride : paramSteps : paramOut ++ paramTmp ++ paramEnv) $ do

    len <- indexHead <$> delayedExtent

    -- A thread scans a non-empty stripe of the input, storing the final
    -- reduction result into a separate array.
    --
    -- For exclusive scans the first chunk must incorporate the initial element
    -- into the input and output, while all other chunks increment their output
    -- index by one.
    inf <- A.mul numType    chunk stride
    a   <- A.add numType    inf   stride
    sup <- A.min singleType a     len

    -- index i* is the index that we read data from. Recall that the supremum
    -- index is exclusive
    i0  <- case dir of
             L -> return inf
             R -> next sup

    -- index j* is the index that we write to. Recall that for exclusive scan
    -- the output array is one larger than the input; the first chunk uses
    -- this spot to write the initial element, all other chunks shift by one.
    --
    -- Note that 'if' here is rebound (RebindableSyntax), so the branches
    -- operate on 'CodeGen' actions rather than on a Haskell 'Bool'.
    j0  <- case mseed of
             Nothing -> return i0
             Just _  -> case dir of
                          L -> if A.eq singleType chunk firstChunk
                                 then return i0
                                 else next i0
                          R -> if A.eq singleType chunk firstChunk
                                 then return sup
                                 else return i0

    -- Evaluate/read the initial element for this chunk. Update the read-from
    -- index appropriately: the seed does not consume an input element, whereas
    -- reading the first element of the stripe does.
    (v0,i1) <- A.unpair <$> case mseed of
                 Just seed -> if A.eq singleType chunk firstChunk
                                then A.pair <$> seed                       <*> pure i0
                                else A.pair <$> app1 delayedLinearIndex i0 <*> next i0
                 Nothing   ->        A.pair <$> app1 delayedLinearIndex i0 <*> next i0

    -- Write first element
    writeArray arrOut j0 v0
    j1  <- next j0

    -- Continue looping through the rest of the input
    let cont i =
           case dir of
             L -> A.lt  singleType i sup
             R -> A.gte singleType i inf

    -- loop state is (read index, write index, running value)
    r   <- while (cont . A.fst3)
                 (\(A.untrip -> (i,j,v)) -> do
                     u  <- app1 delayedLinearIndex i
                     v' <- case dir of
                             L -> app2 combine v u
                             R -> app2 combine u v
                     writeArray arrOut j v'
                     A.trip <$> next i <*> next j <*> pure v')
                 (A.trip i1 j1 v0)

    -- Final reduction result of this chunk
    writeArray arrTmp chunk (A.thd3 r)

    return_


-- Parallel scan, step 2.
--
-- A single thread runs an in-place inclusive scan over the per-chunk
-- reduction values stored in "tmp", covering the gang range [start,end) in
-- the direction of the scan. Afterwards each slot holds the carry-in value
-- that step 3 combines into the partial results of a stripe.
--
mkScanP2
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP2 dir uid aenv combine =
  makeOpenAcc uid "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do

    -- first slot in scan direction; recall that 'end' is exclusive
    origin <- case dir of
                L -> return start
                R -> next end

    -- the first slot is left as is, and seeds the running value
    v0   <- readArray arrTmp origin
    from <- next origin

    _ <- while (more . A.fst) body (A.pair from v0)

    return_
  where
    (start, end, paramGang) = gangParam
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramEnv                = envParam aenv

    -- are there still slots left to process?
    more i = case dir of
               L -> A.lt  singleType i end
               R -> A.gte singleType i start

    -- step an index by one element in the direction of the scan
    next i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- loop state is (index, running value); each slot is overwritten with
    -- the running value that includes it
    body (A.unpair -> (i, acc)) = do
      x    <- readArray arrTmp i
      i'   <- next i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      writeArray arrTmp i acc'
      return (A.pair i' acc')


-- Parallel scan, step 3.
--
-- Threads combine every element of the partial block results with the carry-in
-- value computed from step 2.
--
-- Note that we launch (chunks-1) threads, because the first chunk does not need
-- extra processing (has no carry-in value).
--
-- Besides the gang and environment parameters, the generated kernel "scanP3"
-- takes the scalar "ix.stride" and the arrays "out" and "tmp". The first gang
-- parameter selects the chunk: for a left-to-right scan thread @c@ updates
-- chunk @c+1@ using the carry @tmp[c]@, and for a right-to-left scan thread
-- @c@ updates chunk @c@ using the carry @tmp[c+1]@.
--
-- The range of output elements to update is always clamped to the length of
-- the output array. For a seeded left-to-right scan the range is additionally
-- shifted up by one (the results of every chunk but the first sit one slot to
-- the right of the input they were computed from); the clamp is applied
-- /after/ that shift, so that the (possibly partial) last chunk never reads
-- or writes the slot one past the end of the output array.
--
mkScanP3
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP3 dir uid aenv combine mseed =
  let
      -- only the first gang parameter is used: the index of this thread
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      --
      -- step one element/chunk along, respectively against, the direction of
      -- the scan
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
      prev i                    = case dir of
                                    L -> A.sub numType i (lift 1)
                                    R -> A.add numType i (lift 1)
  in
  makeOpenAcc uid "scanP3" (paramGang ++ paramStride : paramOut ++ paramTmp ++ paramEnv) $ do

    -- Determine which chunk we will be carrying in values for, and the
    -- (unclamped) range of input indices [b,c) that this chunk covers.
    a     <- case dir of
               L -> next chunk
               R -> pure chunk

    b     <- A.mul numType    a stride
    c     <- A.add numType    b stride

    let outLen = indexHead (irArrayShape arrOut)

    -- Range of output indices [inf,sup) to update. Shift first, clamp second:
    -- clamping first and then adding one would let 'sup' exceed the length of
    -- the output array whenever the last chunk is only partially filled.
    (inf,sup) <- case (dir,mseed) of
                   (L,Just _) -> do lo <- next b
                                    c1 <- next c
                                    hi <- A.min singleType c1 outLen
                                    return (lo, hi)
                   _          -> do hi <- A.min singleType c outLen
                                    return (b, hi)

    -- Carry in value from the previous chunk
    e     <- case dir of
               L -> pure chunk
               R -> prev chunk
    carry <- readArray arrTmp e

    imapFromTo inf sup $ \i -> do
      x <- readArray arrOut i
      y <- case dir of
             L -> app2 combine carry x
             R -> app2 combine x carry
      writeArray arrOut i y

    return_


-- Parallel scan' of a vector, assembled from three kernels generated in order
-- (steps 1, 2 and 3) and combined with '+++'.
--
mkScan'P
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P dir uid aenv combine seed arr = do
  step1 <- mkScan'P1 dir uid aenv combine seed arr
  step2 <- mkScan'P2 dir uid aenv combine
  step3 <- mkScan'P3 dir uid aenv combine
  return (step1 +++ (step2 +++ step3))

-- Parallel scan', step 1
--
-- Threads scan a stripe of the input into a temporary array. Similar to
-- exclusive scan, but since the size of the output array is the same as the
-- input, input and output indices are shifted by one.
--
-- Besides the gang and environment parameters, the generated kernel "scanP1"
-- takes the scalars "ix.stride" and "ix.steps" and the arrays "out" and "tmp".
-- The first gang parameter is used as the index of the chunk to process; chunk
-- @c@ covers input indices @[c * stride, min ((c+1) * stride) len)@. The partial
-- scan is written to "out" and the chunk's reduction to @tmp[c]@.
--
-- NOTE(review): throughout the loop the write index runs one position ahead of
-- the read index (in scan direction), so the last value written by a chunk
-- lands at index @sup@ (left-to-right) or @inf-1@ (right-to-left). For the
-- chunk at the far end of the array that would be index @len@, respectively
-- @-1@ -- confirm how the caller/allocation of "out" accounts for this.
--
mkScan'P1
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P1 dir uid aenv combine seed IRDelayed{..} =
  let
      -- only the first gang parameter is used: the index of this chunk
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      steps                     = local           scalarType ("ix.steps"  :: Name Int)
      paramSteps                = scalarParameter scalarType ("ix.steps"  :: Name Int)
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      --
      -- step an index by one element in the direction of the scan
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)

      -- the chunk that the scan starts from: chunk zero when going
      -- left-to-right, chunk "ix.steps" when going right-to-left
      firstChunk                = case dir of
                                    L -> lift 0
                                    R -> steps
  in
  makeOpenAcc uid "scanP1" (paramGang ++ paramStride : paramSteps : paramOut ++ paramTmp ++ paramEnv) $ do

    -- Compute the start and end indices for this non-empty chunk of the input.
    --
    len <- indexHead <$> delayedExtent
    inf <- A.mul numType    chunk stride
    a   <- A.add numType    inf   stride
    sup <- A.min singleType a     len

    -- index i* is the index that we pull data from.
    i0 <- case dir of
            L -> return inf
            R -> next sup

    -- index j* is the index that we write results to. The first chunk needs to
    -- include the initial element, and all other chunks shift their results
    -- across by one to make space.
    --
    -- Note that 'if' here is rebound (RebindableSyntax), so the branches
    -- operate on 'CodeGen' actions rather than on a Haskell 'Bool'.
    j0      <- if A.eq singleType chunk firstChunk
                 then pure i0
                 else next i0

    -- Evaluate/read the initial element. Update the read-from index
    -- appropriately. For chunks other than the first, 'j0' is exactly
    -- 'next i0', so it doubles as the advanced read index.
    (v0,i1) <- A.unpair <$> if A.eq singleType chunk firstChunk
                              then A.pair <$> seed                       <*> pure i0
                              else A.pair <$> app1 delayedLinearIndex i0 <*> pure j0

    -- Write the first element
    writeArray arrOut j0 v0
    j1 <- next j0

    -- Continue looping through the rest of the input
    let cont i =
           case dir of
             L -> A.lt  singleType i sup
             R -> A.gte singleType i inf

    -- loop state is (read index, write index, running value)
    r  <- while (cont . A.fst3)
                (\(A.untrip-> (i,j,v)) -> do
                    u  <- app1 delayedLinearIndex i
                    v' <- case dir of
                            L -> app2 combine v u
                            R -> app2 combine u v
                    writeArray arrOut j v'
                    A.trip <$> next i <*> next j <*> pure v')
                (A.trip i1 j1 v0)

    -- Write the final reduction result of this chunk
    writeArray arrTmp chunk (A.thd3 r)

    return_


-- Parallel scan', step 2
--
-- The same in-place inclusive scan over the per-chunk reduction values in
-- "tmp" as performed by mkScanP2, except that the total scan result is kept
-- (rather than discarded) and stored into the separate array "sum".
--
mkScan'P2
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P2 dir uid aenv combine =
  makeOpenAcc uid "scanP2" (paramGang ++ paramSum ++ paramTmp ++ paramEnv) $ do

    -- first slot in scan direction; recall that 'end' is exclusive
    origin <- case dir of
                L -> return start
                R -> next end

    -- the first slot is left as is, and seeds the running value
    v0   <- readArray arrTmp origin
    from <- next origin

    total <- while (more . A.fst) body (A.pair from v0)

    -- the running value after the last slot is the result of the whole scan
    writeArray arrSum (lift 0 :: IR Int) (A.snd total)

    return_
  where
    (start, end, paramGang) = gangParam
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    (arrSum, paramSum)      = mutableArray ("sum" :: Name (Scalar e))
    paramEnv                = envParam aenv

    -- are there still slots left to process?
    more i = case dir of
               L -> A.lt  singleType i end
               R -> A.gte singleType i start

    -- step an index by one element in the direction of the scan
    next i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- loop state is (index, running value); each slot is overwritten with
    -- the running value that includes it
    body (A.unpair -> (i, acc)) = do
      x    <- readArray arrTmp i
      i'   <- next i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      writeArray arrTmp i acc'
      return (A.pair i' acc')


-- Parallel scan', step 3
--
-- Similar to mkScanP3, except that indices are shifted by one since the output
-- array is the same size as the input (despite being an exclusive scan).
--
-- Launch (chunks-1) threads, because the first chunk does not need extra
-- processing.
--
-- Besides the gang and environment parameters, the generated kernel "scanP3"
-- takes the scalar "ix.stride" and the arrays "out" and "tmp". The first gang
-- parameter selects the chunk: for a left-to-right scan thread @c@ updates
-- chunk @c+1@ using the carry @tmp[c]@, and for a right-to-left scan thread
-- @c@ updates chunk @c@ using the carry @tmp[c+1]@.
--
-- The range of output elements to update is the chunk's input range shifted by
-- one element in the direction of the scan. The shifted range is kept inside
-- the bounds of the output array: the upper bound is clamped to the array
-- length /after/ shifting (left-to-right), and the lower bound is never moved
-- below zero (right-to-left). Without this the chunk at the far end of the
-- array would read and write one slot outside of "out".
--
mkScan'P3
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P3 dir uid aenv combine =
  let
      -- only the first gang parameter is used: the index of this thread
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      --
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      --
      -- step one element/chunk along, respectively against, the direction of
      -- the scan
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
      prev i                    = case dir of
                                    L -> A.sub numType i (lift 1)
                                    R -> A.add numType i (lift 1)
  in
  makeOpenAcc uid "scanP3" (paramGang ++ paramStride : paramOut ++ paramTmp ++ paramEnv) $ do

    -- Determine which chunk we will be carrying in the values of, and the
    -- (unclamped) range of input indices [b,c) that this chunk covers.
    a     <- case dir of
               L -> next chunk
               R -> pure chunk

    b     <- A.mul numType a stride
    c     <- A.add numType b stride

    let outLen = indexHead (irArrayShape arrOut)

    -- Range of output indices [inf,sup) to update.
    (inf,sup) <- case dir of
                   -- shifted up by one: shift first, then clamp to the array
                   -- length, so that 'sup' can never exceed it
                   L -> do lo <- next b
                           c1 <- next c
                           hi <- A.min singleType c1 outLen
                           return (lo, hi)

                   -- shifted down by one: the lowest chunk starts at index
                   -- zero, in which case there is no slot below it to include.
                   -- ('if' is rebound and selects between 'CodeGen' actions.)
                   R -> do d  <- A.min singleType c outLen
                           lo <- if A.eq singleType b (lift 0)
                                   then return b
                                   else next b
                           hi <- next d
                           return (lo, hi)

    -- Carry-value from the previous chunk
    e     <- case dir of
               L -> pure chunk
               R -> prev chunk

    carry <- readArray arrTmp e

    imapFromTo inf sup $ \i -> do
      x <- readArray arrOut i
      y <- case dir of
             L -> app2 combine carry x
             R -> app2 combine x carry
      writeArray arrOut i y

    return_