{-# LANGUAGE DuplicateRecordFields #-}

module Binja.BasicBlock
  ( Binja.BasicBlock.fromMlilFunction,
    Binja.BasicBlock.fromMlilSSAFunction,
    Binja.BasicBlock.outgoingEdges,
    Binja.BasicBlock.incomingEdges,
    Binja.BasicBlock.fromBlockPtr,
    Binja.BasicBlock.fromBlockEdge,
  )
where

import Binja.FFI
import Binja.Types (BNBasicBlockEdge (..), BNBasicBlockPtr, BNMlilFunctionPtr, BNMlilSSAFunctionPtr, BasicBlockEdge (..), BasicBlockMlilSSA (..), Ptr, alloca, castPtr, nullPtr, peek, peekArray)
import Binja.Utils (toBool)

-- | List the basic blocks of a Medium Level IL function.
--
-- The block array returned by Binary Ninja is copied into a Haskell list and
-- then released with 'c_BNFreeBasicBlockList'. An empty block list yields
-- @[]@; any array that was returned alongside a zero count is still freed so
-- it does not leak.
--
-- Calls 'error' only for the inconsistent case of a null array paired with a
-- non-zero count.
--
-- NOTE(review): the returned pointers are copied out before the list is freed.
-- If 'c_BNFreeBasicBlockList' also drops the per-block references, the pointers
-- remain valid only while something else (e.g. the owning function) keeps them
-- alive. Confirm this, or take a new reference per block before freeing.
fromMlilFunction :: BNMlilFunctionPtr -> IO [BNBasicBlockPtr]
fromMlilFunction func =
  alloca $ \countPtr -> do
    arrPtr <- c_BNGetMediumLevelILBasicBlockList func countPtr
    count' <- peek countPtr
    if count' == 0
      then do
        -- Nothing to copy, but a non-null array still has to be released.
        if arrPtr /= nullPtr
          then c_BNFreeBasicBlockList arrPtr count'
          else pure ()
        pure []
      else
        if arrPtr == nullPtr
          then error "fromMlilFunction: null basic block list with non-zero count"
          else do
            refs <- peekArray (fromIntegral count') arrPtr
            c_BNFreeBasicBlockList arrPtr count'
            pure refs

-- | List the basic blocks of a Medium Level IL SSA function.
--
-- The block array returned by Binary Ninja is copied into a Haskell list and
-- then released with 'c_BNFreeBasicBlockList'. An empty block list yields
-- @[]@; any array that was returned alongside a zero count is still freed so
-- it does not leak.
--
-- Calls 'error' only for the inconsistent case of a null array paired with a
-- non-zero count.
--
-- NOTE(review): the returned pointers are copied out before the list is freed.
-- If 'c_BNFreeBasicBlockList' also drops the per-block references, the pointers
-- remain valid only while something else (e.g. the owning function) keeps them
-- alive. Confirm this, or take a new reference per block before freeing.
fromMlilSSAFunction :: BNMlilSSAFunctionPtr -> IO [BNBasicBlockPtr]
fromMlilSSAFunction func =
  alloca $ \countPtr -> do
    arrPtr <- c_BNGetMediumLevelILSSABasicBlockList func countPtr
    count' <- peek countPtr
    if count' == 0
      then do
        -- Nothing to copy, but a non-null array still has to be released.
        if arrPtr /= nullPtr
          then c_BNFreeBasicBlockList arrPtr count'
          else pure ()
        pure []
      else
        if arrPtr == nullPtr
          then error "fromMlilSSAFunction: null basic block list with non-zero count"
          else do
            refs <- peekArray (fromIntegral count') arrPtr
            c_BNFreeBasicBlockList arrPtr count'
            pure refs

-- | Read the edges leaving a basic block.
--
-- Binary Ninja fills in the edge count through an out-parameter. The edge
-- array is copied into a Haskell list and then released with
-- 'c_BNFreeBasicBlockEdgeList'.
--
-- NOTE(review): each copied edge carries a raw target pointer. If the free call
-- also drops those references, the targets depend on their owner to stay alive
-- — confirm against the Binary Ninja C API.
outgoingEdges :: BNBasicBlockPtr -> IO [BNBasicBlockEdge]
outgoingEdges blockPtr = alloca readEdges
  where
    readEdges sizeOut = do
      rawEdges <- c_BNGetBasicBlockOutgoingEdges blockPtr sizeOut
      edgeCount <- peek sizeOut
      copied <-
        peekArray
          (fromIntegral edgeCount)
          (castPtr rawEdges :: Ptr BNBasicBlockEdge)
      c_BNFreeBasicBlockEdgeList rawEdges edgeCount
      pure copied

-- | Read the edges entering a basic block.
--
-- Binary Ninja fills in the edge count through an out-parameter. The edge
-- array is copied into a Haskell list and then released with
-- 'c_BNFreeBasicBlockEdgeList'.
--
-- NOTE(review): each copied edge carries a raw target pointer. If the free call
-- also drops those references, the targets depend on their owner to stay alive
-- — confirm against the Binary Ninja C API.
incomingEdges :: BNBasicBlockPtr -> IO [BNBasicBlockEdge]
incomingEdges blockPtr = alloca readEdges
  where
    readEdges sizeOut = do
      rawEdges <- c_BNGetBasicBlockIncomingEdges blockPtr sizeOut
      edgeCount <- peek sizeOut
      copied <-
        peekArray
          (fromIntegral edgeCount)
          (castPtr rawEdges :: Ptr BNBasicBlockEdge)
      c_BNFreeBasicBlockEdgeList rawEdges edgeCount
      pure copied

-- | Build a 'BasicBlockMlilSSA' record by querying a raw basic block handle.
--
-- The handle itself is stored unchanged. The two boolean flags are converted
-- from the FFI representation with 'toBool'.
--
-- @end@ is stored as the raw end value minus one. Presumably the FFI end is
-- exclusive and this makes it the last instruction index — confirm. An end
-- value of 0 would underflow if the field type is unsigned.
fromBlockPtr :: BNBasicBlockPtr -> IO BasicBlockMlilSSA
fromBlockPtr blockPtr = do
  firstIdx <- c_BNGetBasicBlockStart blockPtr
  pastIdx <- c_BNGetBasicBlockEnd blockPtr
  exits <- toBool <$> c_BNBasicBlockCanExit blockPtr
  invalid <- toBool <$> c_BNBasicBlockHasInvalidInstructions blockPtr
  pure
    BasicBlockMlilSSA
      { handle = blockPtr,
        start = fromIntegral firstIdx,
        end = fromIntegral pastIdx - 1,
        canExit = exits,
        hasInvalidInstructions = invalid
      }

-- | Lift a raw FFI edge into a 'BasicBlockEdge'.
--
-- The edge's target pointer is expanded into a full record via 'fromBlockPtr'.
-- The branch type is copied as-is, and the two flags are converted with
-- 'toBool'.
fromBlockEdge :: BNBasicBlockEdge -> IO BasicBlockEdge
fromBlockEdge
  BNBasicBlockEdge
    { ty = branchType,
      target = rawTarget,
      backEdge = isBack,
      fallThrough = isFall
    } =
    mkEdge <$> fromBlockPtr rawTarget
    where
      -- Assemble the lifted edge once the target block has been resolved.
      mkEdge block =
        BasicBlockEdge
          { ty = branchType,
            target = block,
            backEdge = toBool isBack,
            fallThrough = toBool isFall
          }