{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.Scan
where
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR ( IR )
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Generate
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target ( Native )
import Control.Applicative
import Control.Monad
import Data.String ( fromString )
import Data.Coerce as Safe
import Prelude as P
-- | Traversal direction of a scan. In the kernels of this module, 'L' steps
-- indices upwards by one and combines as @combine acc x@, whereas 'R' steps
-- indices downwards by one and combines as @combine x acc@.
--
data Direction = L | R
-- | Generate the kernels for a seeded scan in direction 'L'.
--
-- The sequential kernel ('mkScanS') and the fill kernel ('mkScanFill') are
-- always generated. When the outer shape @sh@ matches 'Z' (checked with
-- 'matchShapeType'), the kernels produced by 'mkScanP' are generated as well
-- and placed between the two. All results are joined with '+++'.
--
mkScanl
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanl uid aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = (\s p f -> s +++ (p +++ f))
      <$> sequential
      <*> mkScanP L uid aenv combine (Just seed) arr
      <*> fill
  | otherwise
  = (+++) <$> sequential <*> fill
  where
    sequential, fill :: CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
    sequential = mkScanS L uid aenv combine (Just seed) arr
    fill       = mkScanFill uid aenv seed
-- | Generate the kernels for an unseeded scan in direction 'L'.
--
-- The sequential kernel ('mkScanS' with no seed) is always generated. When
-- the outer shape @sh@ matches 'Z', the kernels from 'mkScanP' are appended
-- to it with '+++'.
--
mkScanl1
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanl1 uid aenv combine arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = (+++) <$> sequential
          <*> mkScanP L uid aenv combine Nothing arr
  | otherwise
  = sequential
  where
    sequential :: CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
    sequential = mkScanS L uid aenv combine Nothing arr
-- | Generate the kernels for a seeded scan in direction 'L' that yields the
-- scanned array together with a separate array of per-segment totals.
--
-- The sequential kernel ('mkScan'S') and the fill kernel ('mkScan'Fill') are
-- always generated. When the outer shape @sh@ matches 'Z', the kernels from
-- 'mkScan'P' are generated too and placed between them. All results are
-- joined with '+++'.
--
mkScanl'
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScanl' uid aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = (\s p f -> s +++ (p +++ f))
      <$> sequential
      <*> mkScan'P L uid aenv combine seed arr
      <*> fill
  | otherwise
  = (+++) <$> sequential <*> fill
  where
    sequential, fill :: CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
    sequential = mkScan'S L uid aenv combine seed arr
    fill       = mkScan'Fill uid aenv seed
-- | Generate the kernels for a seeded scan in direction 'R'.
--
-- The sequential kernel ('mkScanS') and the fill kernel ('mkScanFill') are
-- always generated. When the outer shape @sh@ matches 'Z' (checked with
-- 'matchShapeType'), the kernels produced by 'mkScanP' are generated as well
-- and placed between the two. All results are joined with '+++'.
--
mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanr uid aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = (\s p f -> s +++ (p +++ f))
      <$> sequential
      <*> mkScanP R uid aenv combine (Just seed) arr
      <*> fill
  | otherwise
  = (+++) <$> sequential <*> fill
  where
    sequential, fill :: CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
    sequential = mkScanS R uid aenv combine (Just seed) arr
    fill       = mkScanFill uid aenv seed
-- | Generate the kernels for an unseeded scan in direction 'R'.
--
-- The sequential kernel ('mkScanS' with no seed) is always generated. When
-- the outer shape @sh@ matches 'Z', the kernels from 'mkScanP' are appended
-- to it with '+++'.
--
mkScanr1
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanr1 uid aenv combine arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = (+++) <$> sequential
          <*> mkScanP R uid aenv combine Nothing arr
  | otherwise
  = sequential
  where
    sequential :: CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
    sequential = mkScanS R uid aenv combine Nothing arr
-- | Generate the kernels for a seeded scan in direction 'R' that yields the
-- scanned array together with a separate array of per-segment totals.
--
-- The sequential kernel ('mkScan'S') and the fill kernel ('mkScan'Fill') are
-- always generated. When the outer shape @sh@ matches 'Z', the kernels from
-- 'mkScan'P' are generated too and placed between them. All results are
-- joined with '+++'.
--
mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScanr' uid aenv combine seed arr
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = (\s p f -> s +++ (p +++ f))
      <$> sequential
      <*> mkScan'P R uid aenv combine seed arr
      <*> fill
  | otherwise
  = (+++) <$> sequential <*> fill
  where
    sequential, fill :: CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
    sequential = mkScan'S R uid aenv combine seed arr
    fill       = mkScan'Fill uid aenv seed
-- | Generate a kernel, via 'mkGenerate', whose element function ignores the
-- index it is given and always evaluates the seed expression.
--
mkScanFill
    :: (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRExp Native aenv e
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkScanFill uid aenv seed = mkGenerate uid aenv constant
  where
    -- every index maps to the same seed expression
    constant = IRFun1 (\_ix -> seed)
-- | Fill kernel for the scan variants that return a pair of arrays.
--
-- The kernel itself is produced by 'mkScanFill' at type @Array sh e@; the
-- resulting 'IROpenAcc' is then coerced to the pair-of-arrays result type.
--
mkScan'Fill
    :: forall aenv sh e. (Shape sh, Elt e)
    => UID
    -> Gamma aenv
    -> IRExp Native aenv e
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScan'Fill uid aenv seed = fmap Safe.coerce fill
  where
    fill :: CodeGen (IROpenAcc Native aenv (Array sh e))
    fill = mkScanFill uid aenv seed
-- | Sequential scan kernel (named @scanS@) over the innermost dimension.
--
-- For every segment index handed out by @imapFromTo start end@, the kernel
-- walks the @sz@ input elements of that segment in the given direction,
-- where @sz@ is the innermost extent of the delayed input. The running value
-- is written to the @out@ array after every combination.
--
-- When a seed is supplied it becomes the first value written, and output
-- segments are addressed with a stride of @sz + 1@. Without a seed the first
-- input element is used instead, and the read and write indices coincide.
--
mkScanS
    :: forall aenv sh e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanS dir uid aenv combine mseed IRDelayed{..} =
  makeOpenAcc uid "scanS" (paramGang ++ paramOut ++ paramEnv) $ do

    sz   <- indexHead <$> delayedExtent
    szp1 <- A.add numType sz (lift 1)
    szm1 <- A.sub numType sz (lift 1)

    imapFromTo start end $ \seg -> do

      -- i*: linear index the input is read from
      i0 <- case dir of
              L -> A.mul numType sz seg
              R -> A.mul numType sz seg >>= A.add numType szm1

      -- j*: linear index the output is written to
      j0 <- case (mseed, dir) of
              (Nothing, _) -> return i0
              (Just{},  L) -> A.mul numType szp1 seg
              (Just{},  R) -> A.mul numType szp1 seg >>= \x -> A.add numType x sz

      -- initial value, plus the index of the first input element that still
      -- has to be folded in
      (v0, i1) <- case mseed of
                    Just seed -> do v <- seed
                                    return (v, i0)
                    Nothing   -> do v <- app1 delayedLinearIndex i0
                                    i <- step i0
                                    return (v, i)

      writeArray arrOut j0 v0
      j1 <- step j0

      -- one step beyond the last input index of this segment
      iz <- case dir of
              L -> A.add numType i0 sz
              R -> A.sub numType i0 sz

      let inSegment i = case dir of
                          L -> A.lt singleType i iz
                          R -> A.gt singleType i iz

      void $ while (inSegment . A.fst3) body (A.trip i1 j1 v0)

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh:.Int) e))
    paramEnv                = envParam aenv

    -- move an index one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- loop state: (read index, write index, running value)
    body (A.untrip -> (i, j, acc)) = do
      x    <- app1 delayedLinearIndex i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      writeArray arrOut j acc'
      A.trip <$> step i <*> step j <*> pure acc'
-- | Sequential scan kernel (named @scanS@) for the scan variants that return
-- the segment totals in a separate array.
--
-- For every segment index handed out by @imapFromTo start end@, the kernel
-- starts from the seed and walks the @sz@ input elements of that segment in
-- the given direction. At each step the running value is written to @out@
-- /before/ the current input element is combined in. The value left after
-- the loop is written to @sum@ at the segment index.
--
mkScan'S
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScan'S dir uid aenv combine seed IRDelayed{..} =
  makeOpenAcc uid "scanS" (paramGang ++ paramOut ++ paramSum ++ paramEnv) $ do

    sz   <- indexHead <$> delayedExtent
    szm1 <- A.sub numType sz (lift 1)

    imapFromTo start end $ \seg -> do

      -- first linear index of this segment, in scan order
      i0 <- case dir of
              L -> A.mul numType seg sz
              R -> A.mul numType sz seg >>= \x -> A.add numType x szm1

      v0 <- seed

      -- one step beyond the last index of this segment
      iz <- case dir of
              L -> A.add numType i0 sz
              R -> A.sub numType i0 sz

      let inSegment i = case dir of
                          L -> A.lt singleType i iz
                          R -> A.gt singleType i iz

      r <- while (inSegment . A.fst) body (A.pair i0 v0)

      writeArray arrSum seg (A.snd r)

    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh:.Int) e))
    (arrSum, paramSum)      = mutableArray ("sum" :: Name (Array sh e))
    paramEnv                = envParam aenv

    -- move an index one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- loop state: (index, running value)
    body (A.unpair -> (i, acc)) = do
      writeArray arrOut i acc
      x    <- app1 delayedLinearIndex i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      i'   <- step i
      return (A.pair i' acc')
-- | Kernels for the parallel scan of a vector: 'mkScanP1', 'mkScanP2' and
-- 'mkScanP3' are generated in that order and joined with '+++'.
--
mkScanP
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP dir uid aenv combine mseed arr =
  (\p1 p2 p3 -> p1 +++ (p2 +++ p3))
    <$> mkScanP1 dir uid aenv combine mseed arr
    <*> mkScanP2 dir uid aenv combine
    <*> mkScanP3 dir uid aenv combine mseed
-- | Parallel scan, kernel @scanP1@.
--
-- The gang index @chunk@ selects the input range @[inf, sup)@, with
-- @inf = chunk * stride@ and @sup = min (inf + stride) len@, where @len@ is
-- the length of the delayed input. The range is scanned in the given
-- direction; every intermediate value is written to @out@ and the value left
-- after the loop is written to @tmp@ at index @chunk@.
--
-- When a seed is supplied, the first chunk (index @0@ for 'L', the value of
-- the @ix.steps@ parameter for 'R') starts from the seed, and the write
-- positions are offset relative to the read positions as selected below.
--
mkScanP1
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP1 dir uid aenv combine mseed IRDelayed{..} =
  makeOpenAcc uid "scanP1" (paramGang ++ paramStride : paramSteps : paramOut ++ paramTmp ++ paramEnv) $ do

    len <- indexHead <$> delayedExtent
    inf <- A.mul numType chunk stride
    a   <- A.add numType inf stride
    sup <- A.min singleType a len

    -- i*: index the input is read from
    i0 <- case dir of
            L -> return inf
            R -> step sup

    -- j*: index the output is written to
    j0 <- case (mseed, dir) of
            (Nothing, _) -> return i0
            (Just _,  L) -> if isFirstChunk
                              then return i0
                              else step i0
            (Just _,  R) -> if isFirstChunk
                              then return sup
                              else return i0

    -- take the first input element as the initial value and advance the
    -- read index past it
    let readFirst = A.pair <$> app1 delayedLinearIndex i0 <*> step i0

    (v0, i1) <- A.unpair <$> case mseed of
                  Nothing   -> readFirst
                  Just seed -> if isFirstChunk
                                 then A.pair <$> seed <*> pure i0
                                 else readFirst

    writeArray arrOut j0 v0
    j1 <- step j0

    let inChunk i = case dir of
                      L -> A.lt  singleType i sup
                      R -> A.gte singleType i inf

    r <- while (inChunk . A.fst3) body (A.trip i1 j1 v0)

    -- total of this chunk
    writeArray arrTmp chunk (A.thd3 r)

    return_
  where
    (chunk, _, paramGang) = gangParam
    (arrOut, paramOut)    = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)    = mutableArray ("tmp" :: Name (Vector e))
    paramEnv              = envParam aenv

    steps       = local           scalarType ("ix.steps"  :: Name Int)
    paramSteps  = scalarParameter scalarType ("ix.steps"  :: Name Int)
    stride      = local           scalarType ("ix.stride" :: Name Int)
    paramStride = scalarParameter scalarType ("ix.stride" :: Name Int)

    -- move an index one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    firstChunk = case dir of
                   L -> lift 0
                   R -> steps

    -- comparison is emitted afresh at every use site
    isFirstChunk = A.eq singleType chunk firstChunk

    -- loop state: (read index, write index, running value)
    body (A.untrip -> (i, j, acc)) = do
      x    <- app1 delayedLinearIndex i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      writeArray arrOut j acc'
      A.trip <$> step i <*> step j <*> pure acc'
-- | Parallel scan, kernel @scanP2@.
--
-- Scans the @tmp@ array in place between the gang bounds @start@ and @end@:
-- starting from the first element in scan order (@start@ for 'L', @end - 1@
-- for 'R'), each subsequent element is replaced by its combination with the
-- running value.
--
mkScanP2
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP2 dir uid aenv combine =
  makeOpenAcc uid "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do

    i0 <- case dir of
            L -> return start
            R -> step end

    v0 <- readArray arrTmp i0
    i1 <- step i0

    void $ while (inRange . A.fst) body (A.pair i1 v0)

    return_
  where
    (start, end, paramGang) = gangParam
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramEnv                = envParam aenv

    inRange i = case dir of
                  L -> A.lt  singleType i end
                  R -> A.gte singleType i start

    -- move an index one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- loop state: (index, running value)
    body (A.unpair -> (i, acc)) = do
      x    <- readArray arrTmp i
      i'   <- step i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      writeArray arrTmp i acc'
      return (A.pair i' acc')
-- | Parallel scan, kernel @scanP3@.
--
-- Reads a carry value from @tmp@ and combines it into a range of @out@. The
-- updated chunk is @chunk + 1@ for 'L' and @chunk@ for 'R'. The carry is read
-- from @tmp@ at @chunk@ for 'L' and @chunk + 1@ for 'R'. The range is
-- @[b, d)@ with @b = target * stride@ and @d = min (b + stride) n@, where @n@
-- is the length of @out@. For a seeded 'L' scan both bounds are shifted up by
-- one.
--
mkScanP3
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP3 dir uid aenv combine mseed =
  makeOpenAcc uid "scanP3" (paramGang ++ paramStride : paramOut ++ paramTmp ++ paramEnv) $ do

    -- chunk whose output gets the carry applied
    target <- case dir of
                L -> step chunk
                R -> pure chunk

    lo    <- A.mul numType target stride
    upper <- A.add numType lo stride
    hi    <- A.min singleType upper (indexHead (irArrayShape arrOut))

    (inf, sup) <- case (dir, mseed) of
                    (L, Just _) -> (,) <$> step lo <*> step hi
                    _           -> pure (lo, hi)

    -- chunk whose total is carried in
    source <- case dir of
                L -> pure chunk
                R -> back chunk
    carry  <- readArray arrTmp source

    imapFromTo inf sup $ \i -> do
      x <- readArray arrOut i
      y <- case dir of
             L -> app2 combine carry x
             R -> app2 combine x carry
      writeArray arrOut i y

    return_
  where
    (chunk, _, paramGang) = gangParam
    (arrOut, paramOut)    = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)    = mutableArray ("tmp" :: Name (Vector e))
    paramEnv              = envParam aenv

    stride      = local           scalarType ("ix.stride" :: Name Int)
    paramStride = scalarParameter scalarType ("ix.stride" :: Name Int)

    -- one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- one position against the scan direction
    back i = case dir of
               L -> A.sub numType i (lift 1)
               R -> A.add numType i (lift 1)
-- | Kernels for the parallel scan of a vector that also returns the overall
-- total: 'mkScan'P1', 'mkScan'P2' and 'mkScan'P3' are generated in that order
-- and joined with '+++'.
--
mkScan'P
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P dir uid aenv combine seed arr =
  (\p1 p2 p3 -> p1 +++ (p2 +++ p3))
    <$> mkScan'P1 dir uid aenv combine seed arr
    <*> mkScan'P2 dir uid aenv combine
    <*> mkScan'P3 dir uid aenv combine
-- | Parallel scan with separate total, kernel @scanP1@.
--
-- The gang index @chunk@ selects the input range @[inf, sup)@, with
-- @inf = chunk * stride@ and @sup = min (inf + stride) len@, where @len@ is
-- the length of the delayed input. The first chunk (index @0@ for 'L', the
-- value of the @ix.steps@ parameter for 'R') starts from the seed. Every
-- other chunk starts from its first input element and writes one position
-- further along in the scan direction. Intermediate values go to @out@; the
-- value left after the loop is written to @tmp@ at index @chunk@.
--
-- NOTE(review): in the loop the write index is always one step ahead of the
-- read index, so the last iteration writes at @sup@ ('L') or @inf - 1@ ('R').
-- Confirm that this slot is valid for the output buffer the caller supplies.
--
mkScan'P1
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P1 dir uid aenv combine seed IRDelayed{..} =
  makeOpenAcc uid "scanP1" (paramGang ++ paramStride : paramSteps : paramOut ++ paramTmp ++ paramEnv) $ do

    len <- indexHead <$> delayedExtent
    inf <- A.mul numType chunk stride
    a   <- A.add numType inf stride
    sup <- A.min singleType a len

    -- i*: index the input is read from
    i0 <- case dir of
            L -> return inf
            R -> step sup

    -- j*: index the output is written to
    j0 <- if isFirstChunk
            then pure i0
            else step i0

    -- initial value and the next index to read; for chunks other than the
    -- first, j0 is already one step past i0 and is reused as that index
    (v0, i1) <- A.unpair <$>
                  if isFirstChunk
                    then A.pair <$> seed <*> pure i0
                    else A.pair <$> app1 delayedLinearIndex i0 <*> pure j0

    writeArray arrOut j0 v0
    j1 <- step j0

    let inChunk i = case dir of
                      L -> A.lt  singleType i sup
                      R -> A.gte singleType i inf

    r <- while (inChunk . A.fst3) body (A.trip i1 j1 v0)

    -- total of this chunk
    writeArray arrTmp chunk (A.thd3 r)

    return_
  where
    (chunk, _, paramGang) = gangParam
    (arrOut, paramOut)    = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)    = mutableArray ("tmp" :: Name (Vector e))
    paramEnv              = envParam aenv

    steps       = local           scalarType ("ix.steps"  :: Name Int)
    paramSteps  = scalarParameter scalarType ("ix.steps"  :: Name Int)
    stride      = local           scalarType ("ix.stride" :: Name Int)
    paramStride = scalarParameter scalarType ("ix.stride" :: Name Int)

    -- move an index one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    firstChunk = case dir of
                   L -> lift 0
                   R -> steps

    -- comparison is emitted afresh at every use site
    isFirstChunk = A.eq singleType chunk firstChunk

    -- loop state: (read index, write index, running value)
    body (A.untrip -> (i, j, acc)) = do
      x    <- app1 delayedLinearIndex i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      writeArray arrOut j acc'
      A.trip <$> step i <*> step j <*> pure acc'
-- | Parallel scan with separate total, kernel @scanP2@.
--
-- Scans the @tmp@ array in place between the gang bounds @start@ and @end@:
-- starting from the first element in scan order (@start@ for 'L', @end - 1@
-- for 'R'), each subsequent element is replaced by its combination with the
-- running value. The value left after the loop is written to index @0@ of
-- the @sum@ array.
--
mkScan'P2
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P2 dir uid aenv combine =
  makeOpenAcc uid "scanP2" (paramGang ++ paramSum ++ paramTmp ++ paramEnv) $ do

    i0 <- case dir of
            L -> return start
            R -> step end

    v0 <- readArray arrTmp i0
    i1 <- step i0

    r  <- while (inRange . A.fst) body (A.pair i1 v0)

    writeArray arrSum (lift 0 :: IR Int) (A.snd r)

    return_
  where
    (start, end, paramGang) = gangParam
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    (arrSum, paramSum)      = mutableArray ("sum" :: Name (Scalar e))
    paramEnv                = envParam aenv

    inRange i = case dir of
                  L -> A.lt  singleType i end
                  R -> A.gte singleType i start

    -- move an index one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- loop state: (index, running value)
    body (A.unpair -> (i, acc)) = do
      x    <- readArray arrTmp i
      i'   <- step i
      acc' <- case dir of
                L -> app2 combine acc x
                R -> app2 combine x acc
      writeArray arrTmp i acc'
      return (A.pair i' acc')
-- | Parallel scan with separate total, kernel @scanP3@.
--
-- Reads a carry value from @tmp@ and combines it into a range of @out@. The
-- updated chunk is @chunk + 1@ for 'L' and @chunk@ for 'R'. The carry is read
-- from @tmp@ at @chunk@ for 'L' and @chunk + 1@ for 'R'. With
-- @b = target * stride@ and @d = min (b + stride) n@, where @n@ is the length
-- of @out@, the range passed to 'imapFromTo' has both bounds moved one step
-- in the scan direction.
--
mkScan'P3
    :: forall aenv e. Elt e
    => Direction
    -> UID
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P3 dir uid aenv combine =
  makeOpenAcc uid "scanP3" (paramGang ++ paramStride : paramOut ++ paramTmp ++ paramEnv) $ do

    -- chunk whose output gets the carry applied
    target <- case dir of
                L -> step chunk
                R -> pure chunk

    lo    <- A.mul numType target stride
    upper <- A.add numType lo stride
    hi    <- A.min singleType upper (indexHead (irArrayShape arrOut))

    inf   <- step lo
    sup   <- step hi

    -- chunk whose total is carried in
    source <- case dir of
                L -> pure chunk
                R -> back chunk
    carry  <- readArray arrTmp source

    imapFromTo inf sup $ \i -> do
      x <- readArray arrOut i
      y <- case dir of
             L -> app2 combine carry x
             R -> app2 combine x carry
      writeArray arrOut i y

    return_
  where
    (chunk, _, paramGang) = gangParam
    (arrOut, paramOut)    = mutableArray ("out" :: Name (Vector e))
    (arrTmp, paramTmp)    = mutableArray ("tmp" :: Name (Vector e))
    paramEnv              = envParam aenv

    stride      = local           scalarType ("ix.stride" :: Name Int)
    paramStride = scalarParameter scalarType ("ix.stride" :: Name Int)

    -- one position in the scan direction
    step i = case dir of
               L -> A.add numType i (lift 1)
               R -> A.sub numType i (lift 1)

    -- one position against the scan direction
    back i = case dir of
               L -> A.sub numType i (lift 1)
               R -> A.add numType i (lift 1)