{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.Fold
where
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Generate
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target ( Native )
import Control.Applicative
import Prelude as P hiding ( length )
-- | Generate the kernels for a seeded fold over the innermost dimension.
--
-- If the result shape is 'Z' (the input is a vector), the reduction kernels
-- come from 'mkFoldAll'. Otherwise they come from 'mkFoldDim'. In both cases
-- a kernel from 'mkFoldFill', which generates an array filled with the seed,
-- is appended with '+++'.
--
-- NOTE(review): the fill kernel is presumably selected at run time when the
-- reduced dimension is empty; confirm against the executor.
mkFold
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFold uid aenv f z acc =
  (+++) <$> reduction <*> mkFoldFill uid aenv z
  where
    -- The reduction kernels, chosen by whether the output is a scalar.
    -- The explicit signature lets the 'sh ~ Z' equality from the first
    -- guard refine the result type.
    reduction :: CodeGen (IROpenAcc Native aenv (Array sh e))
    reduction
      | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
      = mkFoldAll uid aenv f (Just z) acc
      | otherwise
      = mkFoldDim uid aenv f (Just z) acc
-- | Generate the kernels for an unseeded fold (@fold1@) over the innermost
-- dimension.
--
-- A scalar result (vector input) uses 'mkFoldAll'; any other shape uses
-- 'mkFoldDim'. Both are passed 'Nothing' as the seed. No fill kernel is
-- generated, because there is no seed value to fill with.
mkFold1
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFold1 uid aenv f acc =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> mkFoldAll uid aenv f Nothing acc
    Nothing   -> mkFoldDim uid aenv f Nothing acc
-- | Generate the @fold@ kernel, which reduces along the innermost dimension.
--
-- Each index @seg@ of the gang range @[start, end)@ reduces the input
-- positions @seg * stride@ up to @seg * stride + stride@. Elements are read
-- through the delayed array's linear indexer, and the result goes to
-- @out[seg]@.
--
-- With a seed, the fold starts from the seed value. The seed expression is
-- emitted inside the loop body, i.e. once per segment. Without a seed, the
-- first element of the segment is the initial accumulator.
--
-- NOTE(review): @ix.stride@ is a kernel parameter; presumably the caller
-- passes the extent of the innermost dimension. Confirm.
mkFoldDim
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFoldDim uid aenv combine mseed IRDelayed{..} =
  makeOpenAcc uid "fold" params $ do
    imapFromTo start end $ \seg -> do
      lo  <- mul numType seg stride
      hi  <- add numType lo  stride
      res <- case mseed of
               Nothing   -> reduce1FromTo lo hi step fetch
               Just seed -> seed >>= \z -> reduceFromTo lo hi step z fetch
      writeArray arrOut seg res
    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array sh e))
    paramEnv                = envParam aenv
    paramStride             = scalarParameter scalarType ("ix.stride" :: Name Int)
    stride                  = local           scalarType ("ix.stride" :: Name Int)
    -- Kernel parameter order: gang range, stride, output array, environment.
    params                  = paramGang ++ paramStride : paramOut ++ paramEnv
    -- Combining function (accumulator first) and element accessor.
    step                    = app2 combine
    fetch                   = app1 delayedLinearIndex
-- | Generate every kernel needed to reduce a vector to a scalar, combined
-- with '+++'. The three kernels are:
--
--   * 'mkFoldAllS': a single sequential reduction;
--   * 'mkFoldAllP1': per-stripe partial reductions into a temporary array;
--   * 'mkFoldAllP2': a reduction of that temporary array into the result.
--
-- The kernels are generated in that order. The seed, if any, is passed only
-- to the sequential kernel and the final merge kernel.
mkFoldAll
    :: forall aenv e. Elt e
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAll uid aenv combine mseed arr = do
  sequential <- mkFoldAllS  uid aenv combine mseed arr
  partials   <- mkFoldAllP1 uid aenv combine       arr
  merge      <- mkFoldAllP2 uid aenv combine mseed
  -- Right-nested, matching a right fold over the three kernels.
  return (sequential +++ (partials +++ merge))
-- | Generate the @foldAllS@ kernel. It folds the delayed input over the gang
-- range @[start, end)@ in a single sequential loop and stores the result at
-- index 0 of the scalar output array.
--
-- With a seed, the fold starts from the seed. Without one, the element at
-- @start@ is the initial accumulator.
--
-- NOTE(review): presumably launched with the whole input extent as its gang
-- range; confirm with the executor.
mkFoldAllS
    :: forall aenv e. Elt e
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAllS uid aenv combine mseed IRDelayed{..} =
  makeOpenAcc uid "foldAllS" (paramGang ++ paramOut ++ paramEnv) $ do
    total <- maybe (reduce1FromTo start end step fetch)
                   (>>= (\z -> reduceFromTo start end step z fetch))
                   mseed
    writeArray arrOut outIx total
    return_
  where
    (start, end, paramGang) = gangParam
    paramEnv                = envParam aenv
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Scalar e))
    -- The scalar result lives at linear index 0.
    outIx                   = lift 0 :: IR Int
    step                    = app2 combine
    fetch                   = app1 delayedLinearIndex
-- | Generate the @foldAllP1@ kernel: the first phase of the parallel
-- all-reduction.
--
-- For each index @i@ in the gang range, it reduces the input stripe starting
-- at @i * stride@ and ending at @min (i * stride + stride) length@. The
-- partial result is written to @tmp[i]@.
--
-- Each stripe is reduced left-to-right without a seed. The first element of
-- a stripe is therefore always read.
--
-- NOTE(review): the caller presumably picks the gang range so that
-- @i * stride < length@ for every @i@, i.e. no stripe is empty. Confirm.
mkFoldAllP1
    :: forall aenv e. Elt e
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAllP1 uid aenv combine IRDelayed{..} =
  makeOpenAcc uid "foldAllP1" (paramGang ++ paramLength : paramStride : paramTmp ++ paramEnv) $ do
    imapFromTo start end $ \i -> do
      lo   <- A.mul numType i  stride
      past <- A.add numType lo stride
      -- Clamp the last stripe to the input length.
      hi   <- A.min singleType past len
      part <- reduce1FromTo lo hi (app2 combine) (app1 delayedLinearIndex)
      writeArray arrTmp i part
    return_
  where
    (start, end, paramGang) = gangParam
    paramEnv                = envParam aenv
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    len                     = local           scalarType ("ix.length" :: Name Int)
    stride                  = local           scalarType ("ix.stride" :: Name Int)
    paramLength             = scalarParameter scalarType ("ix.length" :: Name Int)
    paramStride             = scalarParameter scalarType ("ix.stride" :: Name Int)
-- | Generate the @foldAllP2@ kernel: the second phase of the parallel
-- all-reduction.
--
-- It sequentially folds the entries of the temporary array @tmp@ over the
-- gang range @[start, end)@ and writes the result to index 0 of the scalar
-- output.
--
-- With a seed, the fold starts from the seed. Without one, @tmp[start]@ is
-- the initial accumulator.
mkFoldAllP2
    :: forall aenv e. Elt e
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAllP2 uid aenv combine mseed =
  makeOpenAcc uid "foldAllP2" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do
    total <- case mseed of
               Nothing   -> reduce1FromTo start end step (readArray arrTmp)
               Just seed -> do
                 z <- seed
                 reduceFromTo start end step z (readArray arrTmp)
    writeArray arrOut outIx total
    return_
  where
    (start, end, paramGang) = gangParam
    paramEnv                = envParam aenv
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Scalar e))
    -- The scalar result lives at linear index 0.
    outIx                   = lift 0 :: IR Int
    step                    = app2 combine
-- | Generate a kernel, via 'mkGenerate', whose generator function ignores
-- the element index and returns the seed expression. Every element of the
-- resulting array therefore takes the seed's value.
mkFoldFill
    :: (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRExp Native aenv e
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFoldFill uid aenv seed =
  mkGenerate uid aenv $ IRFun1 $ \_ -> seed
-- | Emit a sequential fold of @get i@ over the index range from @m@ to @n@,
-- starting from the accumulator @z@.
--
-- Each step reads the element first and then emits @f acc x@. The running
-- accumulator is therefore always the left operand of @f@.
--
-- NOTE(review): the loop is 'iterFromTo' (defined elsewhere). This
-- description assumes it visits @m .. n-1@ in ascending order while
-- threading the accumulator. Confirm.
reduceFromTo
    :: Elt a
    => IR Int                               -- ^ first index
    -> IR Int                               -- ^ end index (passed to 'iterFromTo')
    -> (IR a -> IR a -> CodeGen (IR a))     -- ^ combining function, accumulator first
    -> IR a                                 -- ^ initial accumulator
    -> (IR Int -> CodeGen (IR a))           -- ^ element accessor
    -> CodeGen (IR a)
reduceFromTo m n f z get =
  -- 'get i' is emitted before the combine, exactly as a bind-then-apply.
  iterFromTo m n z $ \i acc -> f acc =<< get i
-- | Like 'reduceFromTo', but without a seed. The element at index @m@ becomes
-- the initial accumulator, and the fold then continues from @m + 1@ up to
-- @n@.
--
-- The element at @m@ is read unconditionally, so the range is expected to be
-- non-empty. This helper does not check that.
reduce1FromTo
    :: Elt a
    => IR Int
    -> IR Int
    -> (IR a -> IR a -> CodeGen (IR a))
    -> (IR Int -> CodeGen (IR a))
    -> CodeGen (IR a)
reduce1FromTo m n f get = do
  let one = ir numType (num numType 1) :: IR Int
  -- Read the first element before emitting the index increment, keeping the
  -- original instruction order.
  initial <- get m
  next    <- add numType m one
  reduceFromTo next n f initial get