{-# LANGUAGE DuplicateRecordFields #-}

module Callgraph
  ( Callgraph.create,
    Callgraph.Graph,
    Callgraph.Vertex,
    Callgraph.vertices,
    Callgraph.neighbors,
    Callgraph.filter,
    Callgraph.filterWithKey,
    Callgraph.recursive,
    Callgraph.leaf,
    Callgraph.callers,
    Callgraph.callees,
    Callgraph.mostCalled,
    Callgraph.mostConnected,
    Callgraph.reachable,
    Callgraph.order,
    Callgraph.size,
  )
where

import Binja.AnalysisContext
import Binja.Types
import qualified Data.Map as Map
import qualified Data.Set as Set

-- | A callgraph vertex, identified by its 'Binja.Types.Symbol'.
type Vertex = Binja.Types.Symbol

-- | Adjacency map from each vertex to the set of vertices it has outgoing
-- edges to. 'callees' reads a vertex's set directly; 'callers' scans all
-- sets for the vertex.
type Graph = Map.Map Vertex (Set.Set Vertex)

-- | Derive a callgraph from an 'AnalysisContext'.
--
-- Every function of the context becomes a vertex whose adjacency set is the
-- result of 'Binja.AnalysisContext.callers' for that function. Every symbol
-- that only appears inside some adjacency set is added as a vertex with an
-- empty adjacency set, so every referenced symbol is a key of the result.
--
-- If several functions share the same symbol, their adjacency sets are
-- merged rather than all but one being silently dropped.
--
-- __Note__: Not all runtime evaluated call destinations will be recovered
-- via 'Binja.AnalysisContext.callers'.
create :: AnalysisContext -> Graph
create context = Map.union initialGraph danglingVertices
  where
    -- One entry per function symbol known to the analysis context.
    initialGraph :: Graph
    initialGraph =
      Map.fromListWith
        Set.union
        [ (symbol f, Binja.AnalysisContext.callers context f)
          | f <- functions context
        ]

    -- Every symbol referenced as a neighbour, with no outgoing edges.
    -- 'Map.union' is left-biased, so real entries from initialGraph win.
    danglingVertices :: Graph
    danglingVertices =
      Map.fromSet (const Set.empty) (Set.unions (Map.elems initialGraph))

-- | All vertices of the graph, in ascending order.
vertices :: Graph -> [Vertex]
vertices graph' = Map.keys graph'

-- | The adjacency set of a vertex, or 'Nothing' if the vertex is not a key
-- of the graph.
neighbors :: Graph -> Vertex -> Maybe (Set.Set Vertex)
neighbors = flip Map.lookup

-- | Keep only the vertices whose adjacency set satisfies the predicate.
filter :: (Set.Set Vertex -> Bool) -> Graph -> Graph
filter keep graph' = Map.filter keep graph'

-- | Keep only the (vertex, adjacency set) entries that satisfy the predicate.
filterWithKey :: (Vertex -> Set.Set Vertex -> Bool) -> Graph -> Graph
filterWithKey keep graph' = Map.filterWithKey keep graph'

-- | Vertices that are directly recursive, i.e. that appear in their own
-- adjacency set. Mutual recursion through other vertices is not detected.
-- Results are in ascending vertex order.
recursive :: Graph -> [Vertex]
recursive graph' =
  [v | (v, children) <- Map.toList graph', Set.member v children]

-- | Vertices with an empty adjacency set (no outgoing edges), in ascending
-- vertex order.
leaf :: Graph -> [Vertex]
leaf graph' = [v | (v, children) <- Map.toList graph', Set.null children]

-- | Vertices whose adjacency set contains the source vertex, in ascending
-- vertex order. This scans every entry of the graph.
callers :: Graph -> Vertex -> [Vertex]
callers graph' source = Map.foldrWithKey collect [] graph'
  where
    collect :: Vertex -> Set.Set Vertex -> [Vertex] -> [Vertex]
    collect v children acc
      | source `Set.member` children = v : acc
      | otherwise = acc

-- | The adjacency set of the source vertex as an ascending list; empty if
-- the vertex is not a key of the graph.
callees :: Graph -> Vertex -> [Vertex]
callees graph' source =
  Set.toList (Map.findWithDefault Set.empty source graph')

-- | Find the vertex with the most callers, i.e. the vertex contained in the
--   largest number of adjacency sets.
--   If several vertices share that maximum, the first one in ascending
--   vertex order (the order of 'vertices') is returned.
--   Returns 'Nothing' for an empty graph.
mostCalled :: Graph -> Maybe Vertex
mostCalled graph' = firstMaximum (Map.mapWithKey callerCount graph')
  where
    -- How many adjacency sets mention each vertex. Computed in a single pass
    -- over all edges instead of rescanning the graph for every vertex.
    inDegree :: Map.Map Vertex Int
    inDegree =
      Map.unionsWith (+) (map (Map.fromSet (const 1)) (Map.elems graph'))

    callerCount :: Vertex -> Set.Set Vertex -> Int
    callerCount v _ = Map.findWithDefault 0 v inDegree

    -- Left fold in ascending key order. Only a strictly larger count replaces
    -- the current best, so the first maximum wins ties.
    firstMaximum :: Map.Map Vertex Int -> Maybe Vertex
    firstMaximum = fmap fst . Map.foldlWithKey' pick Nothing

    pick :: Maybe (Vertex, Int) -> Vertex -> Int -> Maybe (Vertex, Int)
    pick best v n = case best of
      Just (_, bestN) | n <= bestN -> best
      _ -> Just (v, n)

-- | Find the vertex with the largest number of callers plus callees.
--   A vertex in its own adjacency set counts once as a caller and once as
--   a callee.
--   If several vertices share that maximum, the first one in ascending
--   vertex order (the order of 'vertices') is returned.
--   Returns 'Nothing' for an empty graph.
mostConnected :: Graph -> Maybe Vertex
mostConnected graph' = firstMaximum (Map.mapWithKey degree graph')
  where
    -- How many adjacency sets mention each vertex (its caller count),
    -- computed in a single pass over all edges.
    inDegree :: Map.Map Vertex Int
    inDegree =
      Map.unionsWith (+) (map (Map.fromSet (const 1)) (Map.elems graph'))

    -- Callers (in-degree) plus callees (size of the vertex's own set).
    degree :: Vertex -> Set.Set Vertex -> Int
    degree v children = Map.findWithDefault 0 v inDegree + Set.size children

    -- Left fold in ascending key order. Only a strictly larger degree
    -- replaces the current best, so the first maximum wins ties.
    firstMaximum :: Map.Map Vertex Int -> Maybe Vertex
    firstMaximum = fmap fst . Map.foldlWithKey' pick Nothing

    pick :: Maybe (Vertex, Int) -> Vertex -> Int -> Maybe (Vertex, Int)
    pick best v n = case best of
      Just (_, bestN) | n <= bestN -> best
      _ -> Just (v, n)

-- | Is the destination vertex reachable from the source vertex by following
-- adjacency sets?
--
-- A vertex is always reachable from itself, even if it is not a key of the
-- graph. Vertices that are not keys of the graph have no outgoing edges.
reachable :: Graph -> Vertex -> Vertex -> Bool
reachable graph' source destination = search Set.empty [source]
  where
    -- Depth-first search with an explicit stack. A single visited set is
    -- shared by every branch, so each vertex is expanded at most once even
    -- when many paths converge on it.
    search :: Set.Set Vertex -> [Vertex] -> Bool
    search _ [] = False
    search visited (v : pending)
      | v == destination = True
      | Set.member v visited = search visited pending
      | otherwise =
          search
            (Set.insert v visited)
            (Set.toList (Map.findWithDefault Set.empty v graph') ++ pending)

-- | Number of vertices (nodes) in the graph.
order :: Graph -> Int
order graph' = Map.size graph'

-- | Number of edges: the total size of all adjacency sets.
size :: Graph -> Int
size graph' = sum [Set.size children | children <- Map.elems graph']