{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.FoldSeg
where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Exp ( indexHead )
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Fold
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target ( Native )
import Control.Applicative
import Control.Monad
import Prelude as P
-- | Generate the code for a segmented reduction that starts from an explicit
-- seed value.
--
-- Both the @foldSegS@ and the @foldSegP@ kernels are generated (in that order)
-- from the same arguments, with the seed wrapped in 'Just', and the two
-- results are joined with '+++'.
--
mkFoldSeg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFoldSeg uid aenv combine seed arr seg = do
  sequential <- mkFoldSegS uid aenv combine (Just seed) arr seg
  parallel   <- mkFoldSegP uid aenv combine (Just seed) arr seg
  return (sequential +++ parallel)
-- | Generate the code for a segmented reduction without a seed value.
--
-- Identical to the seeded variant except that 'Nothing' is passed as the
-- seed to both kernel generators, so each segment is reduced with
-- 'reduce1FromTo' instead of 'reduceFromTo'. The @foldSegS@ kernel is
-- generated first, then @foldSegP@, and the two are joined with '+++'.
--
mkFold1Seg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFold1Seg uid aenv combine arr seg = do
  sequential <- mkFoldSegS uid aenv combine Nothing arr seg
  parallel   <- mkFoldSegP uid aenv combine Nothing arr seg
  return (sequential +++ parallel)
-- | Generate the @foldSegS@ kernel: a single loop that walks the output
-- indices @[start, end)@ of the gang parameters one at a time.
--
-- The loop state is a pair @(output index, input offset)@. For every output
-- index the corresponding entry of the segment descriptor is read and used as
-- a /length/: the segment covers the input from the current offset up to
-- @offset + length@, and that upper bound becomes the offset for the next
-- iteration.
--
-- NOTE(review): the input offset always starts at zero, independent of
-- @start@, so this kernel presumably expects to be invoked over the whole
-- index range in one go -- confirm against the caller.
--
mkFoldSegS
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFoldSegS uid aenv combine mseed arr seg =
  makeOpenAcc uid "foldSegS" (paramGang ++ paramOut ++ paramEnv) $ do
    -- Extent of the segment descriptor. Only consulted by 'wrap' when the
    -- outer shape has non-zero rank.
    nseg <- indexHead <$> delayedExtent seg
    void $ while inRange (step nseg) (A.pair start (lift 0))
    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
    paramEnv                = envParam aenv

    -- Keep iterating while the output index is below the end of the range.
    inRange state = A.lt singleType (A.fst state) end

    -- Element access into the (delayed) input array by linear index.
    readInput = app1 (delayedLinearIndex arr)

    -- Map an output index onto an index into the segment descriptor. When the
    -- outer shape has rank zero the index is used as is; otherwise it is
    -- taken modulo the descriptor extent, so the descriptor is reused for
    -- every outer index.
    wrap :: IR Int -> IR Int -> CodeGen (IR Int)
    wrap nseg ix =
      case rank (undefined::sh) of
        0 -> return ix
        _ -> A.rem integralType ix nseg

    -- Reduce the input between the two bounds, with or without a seed.
    reduceRange :: IR Int -> IR Int -> CodeGen (IR e)
    reduceRange lo hi =
      case mseed of
        Nothing   -> reduce1FromTo lo hi (app2 combine) readInput
        Just seed -> seed >>= \z -> reduceFromTo lo hi (app2 combine) z readInput

    -- One loop iteration: reduce the segment for output index 'ix', store the
    -- result, and advance both the output index and the input offset.
    step :: IR Int -> IR (Int,Int) -> CodeGen (IR (Int,Int))
    step nseg (A.unpair -> (ix, lo)) = do
      segIx <- wrap nseg ix
      len   <- app1 (delayedLinearIndex seg) segIx >>= A.fromIntegral integralType numType
      hi    <- A.add numType lo len
      acc   <- reduceRange lo hi
      writeArray arrOut ix acc
      next  <- A.add numType ix (lift 1)
      return (A.pair next hi)
-- | Generate the @foldSegP@ kernel: every output index in the gang range
-- @[start, end)@ is processed independently via 'imapFromTo'.
--
-- Unlike @foldSegS@, the segment descriptor is read at two adjacent
-- positions (@i@ and @i+1@) and the two values are used directly as the
-- lower and upper bound of the segment, i.e. the descriptor is treated as an
-- array of offsets. Accordingly the number of segments is taken to be one
-- less than the extent of the descriptor.
--
mkFoldSegP
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFoldSegP uid aenv combine mseed arr seg =
  makeOpenAcc uid "foldSegP" (paramGang ++ paramOut ++ paramEnv) $ do
    -- Size of the innermost dimension of the input, and the number of
    -- segments (descriptor extent minus one). Both are only consulted when
    -- the outer shape has non-zero rank.
    innerSz <- indexHead <$> delayedExtent arr
    nseg    <- delayedExtent seg >>= \ext -> A.sub numType (indexHead ext) (lift 1)

    imapFromTo start end $ \ix -> do
      segIx    <- wrap nseg ix
      segIx1   <- A.add numType segIx (lift 1)
      offLo    <- segOffset segIx
      offHi    <- segOffset segIx1
      (lo, hi) <- A.unpair <$> rebase innerSz nseg ix offLo offHi
      acc      <- reduceRange lo hi
      writeArray arrOut ix acc

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
    paramEnv                = envParam aenv

    -- Element access into the (delayed) input array by linear index.
    readInput = app1 (delayedLinearIndex arr)

    -- Read one entry of the segment descriptor and convert it to 'Int'.
    segOffset :: IR Int -> CodeGen (IR Int)
    segOffset k = app1 (delayedLinearIndex seg) k >>= A.fromIntegral integralType numType

    -- Map an output index onto an index into the segment descriptor: used
    -- unchanged when the outer shape has rank zero, otherwise taken modulo
    -- the number of segments.
    wrap :: IR Int -> IR Int -> CodeGen (IR Int)
    wrap nseg ix =
      case rank (undefined::sh) of
        0 -> return ix
        _ -> A.rem integralType ix nseg

    -- Turn the two descriptor offsets into bounds on the linear input index.
    -- With a rank-zero outer shape they are used as is; otherwise both are
    -- shifted by @(ix `quot` nseg) * innerSz@.
    rebase :: IR Int -> IR Int -> IR Int -> IR Int -> IR Int -> CodeGen (IR (Int,Int))
    rebase innerSz nseg ix offLo offHi =
      case rank (undefined::sh) of
        0 -> return (A.pair offLo offHi)
        _ -> do
          row  <- A.quot integralType ix nseg
          base <- A.mul numType row innerSz
          lo   <- A.add numType offLo base
          hi   <- A.add numType offHi base
          return (A.pair lo hi)

    -- Reduce the input between the two bounds, with or without a seed.
    reduceRange :: IR Int -> IR Int -> CodeGen (IR e)
    reduceRange lo hi =
      case mseed of
        Nothing   -> reduce1FromTo lo hi (app2 combine) readInput
        Just seed -> seed >>= \z -> reduceFromTo lo hi (app2 combine) z readInput