{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.Permute
where
import Data.Array.Accelerate.Array.Sugar ( Array, Vector, Shape, Elt, eltType )
import Data.Array.Accelerate.Error
import qualified Data.Array.Accelerate.Array.Sugar as S
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Permute
import Data.Array.Accelerate.LLVM.CodeGen.Ptr
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache
import Data.Array.Accelerate.LLVM.Native.Target ( Native )
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.RMW as RMW
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Representation
import Control.Applicative
import Control.Monad ( void )
import Data.Typeable
import Prelude
-- | Generate the code for a (forward) permutation.
--
-- Two kernel variants are generated, one by 'mkPermuteS' and one by
-- 'mkPermuteP', and joined into a single result with '+++' (the 'mkPermuteS'
-- output on the left).
--
mkPermute
    :: (Shape sh, Shape sh', Elt e)
    => UID
    -> Gamma aenv
    -> IRPermuteFun Native aenv (e -> e -> e)
    -> IRFun1 Native aenv (sh -> sh')
    -> IRDelayed Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermute uid aenv combine project arr = do
  kernelS <- mkPermuteS uid aenv combine project arr
  kernelP <- mkPermuteP uid aenv combine project arr
  return (kernelS +++ kernelP)
-- | Generate the @permuteS@ kernel.
--
-- For every linear index in the range supplied by 'gangParam', the
-- corresponding source index is passed through @project@. Unless the projected
-- index is the 'S.ignore' index (see 'ignore'), the source element is combined
-- with the value currently stored at the projected position of the output
-- array, and the result is written back to that position.
--
-- This function itself emits no lock and no explicit atomic instruction
-- around the read-combine-write sequence.
--
mkPermuteS
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => UID
    -> Gamma aenv
    -> IRPermuteFun Native aenv (e -> e -> e)
    -> IRFun1 Native aenv (sh -> sh')
    -> IRDelayed Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteS uid aenv IRPermuteFun{..} project IRDelayed{..} =
  makeOpenAcc uid "permuteS" kernelParams $ do
    shIn <- delayedExtent

    imapFromTo lower upper $ \k -> do
      ixIn  <- indexOfInt shIn k
      ixOut <- app1 project ixIn

      unless (ignore ixOut) $ do
        slot    <- intOfIndex (irArrayShape arrOut) ixOut

        -- fetch the source element and merge it with the current destination value
        fresh   <- app1 delayedLinearIndex k
        current <- readArray arrOut slot
        merged  <- app2 combine fresh current

        writeArray arrOut slot merged

    return_
  where
    (lower, upper, paramGang) = gangParam
    (arrOut, paramOut)        = mutableArray ("out" :: Name (Array sh' e))
    kernelParams              = paramGang ++ paramOut ++ envParam aenv
-- | Generate the @permuteP_*@ kernel, selecting the implementation from the
-- 'atomicRMW' field of the permutation function:
--
--   * 'Nothing': fall back to 'mkPermuteP_mutex', which guards the general
--     combination function with a per-element lock.
--
--   * @'Just' (rmw, f)@: use 'mkPermuteP_rmw' with the given read-modify-write
--     operation and its accompanying unary function.
--
mkPermuteP
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => UID
    -> Gamma aenv
    -> IRPermuteFun Native aenv (e -> e -> e)
    -> IRFun1 Native aenv (sh -> sh')
    -> IRDelayed Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteP uid aenv IRPermuteFun{..} project arr =
  maybe viaMutex viaRMW atomicRMW
  where
    viaMutex        = mkPermuteP_mutex uid aenv combine project arr
    viaRMW (rmw, f) = mkPermuteP_rmw uid aenv rmw f project arr
-- | Generate the @permuteP_rmw@ kernel, which updates the output array using
-- a read-modify-write operation rather than a lock.
--
-- For every linear index in the range supplied by 'gangParam', the source
-- index is passed through @project@. Unless the projected index is the
-- 'S.ignore' index (see 'ignore'), the source element is transformed by
-- @update@ and then applied to the destination slot according to @rmw@:
--
--   * 'Exchange': the value is stored with a plain 'writeArray'.
--
--   * Otherwise the element type must be a single scalar type. Integral types
--     use the 'AtomicRMW' instruction directly; other numeric types with
--     'RMW.Add' or 'RMW.Sub' go through 'atomicCAS_rmw'; 'RMW.Min' and
--     'RMW.Max' on the remaining cases go through 'atomicCAS_cmp'.
--     (NOTE(review): the two @atomicCAS_*@ helpers are imported; presumably
--     they emit compare-and-swap loops -- confirm against their definitions.)
--
-- Any other combination of operation and element type is an internal error.
--
mkPermuteP_rmw
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => UID
    -> Gamma aenv
    -> RMWOperation
    -> IRFun1 Native aenv (e -> e)
    -> IRFun1 Native aenv (sh -> sh')
    -> IRDelayed Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteP_rmw uid aenv rmw update project IRDelayed{..} =
  let
      (start, end, paramGang) = gangParam
      (arrOut, paramOut)      = mutableArray ("out" :: Name (Array sh' e))
      paramEnv                = envParam aenv
  in
  makeOpenAcc uid "permuteP_rmw" (paramGang ++ paramOut ++ paramEnv) $ do

    sh <- delayedExtent

    imapFromTo start end $ \i -> do

      ix  <- indexOfInt sh i
      ix' <- app1 project ix

      unless (ignore ix') $ do
        j <- intOfIndex (irArrayShape arrOut) ix'
        x <- app1 delayedLinearIndex i
        r <- app1 update x

        case rmw of
          Exchange
            -> writeArray arrOut j r

          -- The remaining operations work on the address of a single scalar
          -- element, so the array data and the new value are cast (via
          -- 'gcast') to the scalar representation witnessed by 's'.
          _ | TypeRscalar (SingleScalarType s) <- eltType (undefined::e)
            , Just adata                       <- gcast (irArrayData arrOut)
            , Just r'                          <- gcast r
            -> do
                  addr <- instr' $ GetElementPtr (asPtr defaultAddrSpace (op s adata)) [op integralType j]

                  case s of
                    NumSingleType (IntegralNumType t) -> void . instr' $ AtomicRMW t NonVolatile rmw addr (op t r') (CrossThread, AcquireRelease)
                    NumSingleType t | RMW.Add <- rmw  -> atomicCAS_rmw s (A.add t r') addr
                    NumSingleType t | RMW.Sub <- rmw  -> atomicCAS_rmw s (A.sub t r') addr
                    _ -> case rmw of
                           RMW.Min -> atomicCAS_cmp s A.lt addr (op s r')
                           RMW.Max -> atomicCAS_cmp s A.gt addr (op s r')
                           _       -> $internalError "mkPermuteP_rmw" "unexpected transition"

          -- Reached when the guards above fail, i.e. the element type is not a
          -- single scalar or one of the casts did not succeed.
          _ -> $internalError "mkPermuteP_rmw" "unexpected transition"

    return_
-- | Generate the @permuteP_mutex@ kernel, which protects each update of the
-- output array with a per-element lock.
--
-- In addition to the output array the kernel takes a @lock@ vector of 'Word8'.
-- For every linear index in the range supplied by 'gangParam', the source
-- index is passed through @project@. Unless the projected index is the
-- 'S.ignore' index (see 'ignore'), the read-combine-write of the destination
-- slot is executed inside 'atomically', using the lock at the same linear
-- position as the destination element.
--
mkPermuteP_mutex
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRFun1 Native aenv (sh -> sh')
    -> IRDelayed Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteP_mutex uid aenv combine project IRDelayed{..} =
  makeOpenAcc uid "permuteP_mutex" kernelParams $ do
    shIn <- delayedExtent

    imapFromTo lower upper $ \k -> do
      ixIn  <- indexOfInt shIn k
      ixOut <- app1 project ixIn

      unless (ignore ixOut) $ do
        slot  <- intOfIndex (irArrayShape arrOut) ixOut
        fresh <- app1 delayedLinearIndex k

        -- the destination slot is only read and written while holding its lock
        atomically arrLock slot $ do
          current <- readArray arrOut slot
          merged  <- app2 combine fresh current
          writeArray arrOut slot merged

    return_
  where
    (lower, upper, paramGang) = gangParam
    (arrOut, paramOut)        = mutableArray ("out"  :: Name (Array sh' e))
    (arrLock, paramLock)      = mutableArray ("lock" :: Name (Vector Word8))
    kernelParams              = paramGang ++ paramOut ++ paramLock ++ envParam aenv
-- | Run the given code generation action as a critical section guarded by a
-- spin lock.
--
-- The lock is the element at index @i@ of the supplied 'Word8' vector; the
-- value 0 means unlocked and 1 means locked. The generated code consists of
-- three new basic blocks:
--
--   * @spinlock.entry@: atomically exchange the lock slot with 1 (acquire
--     ordering). If the previous value was 0 the lock has been taken and
--     control continues to the critical section, otherwise it branches back
--     to this block and tries again.
--
--   * @spinlock.critical-section@: run the action, then atomically exchange
--     the lock slot with 0 (release ordering).
--
--   * @spinlock.exit@: the block that is current when this function returns.
--
-- The result of the action is returned.
--
atomically
    :: IRArray (Vector Word8)
    -> IR Int
    -> CodeGen a
    -> CodeGen a
atomically locks i action = do
  let
      held    = integral integralType 1
      free    = integral integralType 0
      wasFree = lift 0

  spinBlock <- newBlock "spinlock.entry"
  critBlock <- newBlock "spinlock.critical-section"
  exitBlock <- newBlock "spinlock.exit"

  -- address of the lock slot for element i
  slot <- instr' $ GetElementPtr (asPtr defaultAddrSpace (op integralType (irArrayData locks))) [op integralType i]
  void $ br spinBlock

  -- Try to take the lock: swap in the locked state and inspect what was there
  -- before. Anything other than the unlocked state means the attempt failed.
  setBlock spinBlock
  previous <- instr $ AtomicRMW integralType NonVolatile Exchange slot held (CrossThread, Acquire)
  acquired <- A.eq singleType previous wasFree
  void $ cbr acquired critBlock spinBlock

  -- Lock held: perform the critical section, then release the lock.
  setBlock critBlock
  result <- action
  void . instr $ AtomicRMW integralType NonVolatile Exchange slot free (CrossThread, Release)
  void $ br exitBlock

  setBlock exitBlock
  return result
-- | Generate code testing whether the given index is equal to the
-- distinguished 'S.ignore' index of its shape type.
--
-- The index is traversed structurally alongside the element representation of
-- 'S.ignore': each scalar component is compared for equality with the
-- corresponding component of 'S.ignore', pairs are combined with 'land'', and
-- the unit shape yields @True@. A vector scalar type in a shape is an internal
-- error.
--
ignore :: forall ix. Shape ix => IR ix -> CodeGen (IR Bool)
ignore (IR operands) = matches (S.eltType (undefined::ix)) (S.fromElt (S.ignore::ix)) operands
  where
    matches :: TupleType t -> t -> Operands t -> CodeGen (IR Bool)
    matches TypeRunit () OP_Unit =
      return (lift True)
    matches (TypeRpair tyL tyR) (magicL, magicR) (OP_Pair opL opR) = do
      okL <- matches tyL magicL opL
      okR <- matches tyR magicR opR
      land' okL okR
    matches (TypeRscalar (SingleScalarType t)) magic operand =
      A.eq t (ir t (single t magic)) (ir t (op' t operand))
    matches (TypeRscalar VectorScalarType{}) _ _ =
      $internalError "ignore" "unexpected shape type"