module Ermine.Inference.Kind
( inferKind
, inferAnnotKind
, checkKind
, checkDataTypeKinds
, checkDataTypeGroup
, checkDataTypeKind
, checkConstructorKind
) where
import Bound
import Bound.Var
import Control.Applicative
import Control.Comonad
import Control.Lens
import Control.Monad
import Control.Monad.State
import Data.Foldable hiding (concat)
import Data.Graph (stronglyConnComp, flattenSCC)
import Data.List (nub, elemIndex)
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Traversable (for)
import Data.Void
import Ermine.Diagnostic
import Ermine.Syntax.Constructor as Data
import Ermine.Syntax.Data as Data
import Ermine.Syntax.Global
import Ermine.Syntax.Kind as Kind
import Ermine.Syntax.Name
import Ermine.Syntax.Scope
import qualified Ermine.Syntax.Type as Type
import Ermine.Syntax.Type hiding (Var)
import Ermine.Unification.Data
import Ermine.Unification.Kind
import Ermine.Unification.Meta
import Ermine.Unification.Sharing
-- | Kind of an @n@-ary type constructor whose arguments and result all have
-- kind 'star': @star :-> ... :-> star@ with @n@ arrows.  'inferKind' uses it
-- for tuple type constructors.
--
-- A non-positive arity yields plain 'star', so the recursion always
-- terminates.
productKind :: Int -> Kind k
productKind n
  | n <= 0    = star
  | otherwise = star :-> productKind (n - 1)
-- | Split a kind into the argument and result of a function kind.
--
-- An arrow kind is taken apart directly.  A kind variable is unified with an
-- arrow between two freshly created metavariables, which are then returned.
-- Every other kind fails with @"not a fun kind"@.
matchFunKind :: MonadMeta s m => KindM s -> m (KindM s, KindM s)
matchFunKind kind = case kind of
  dom :-> cod -> return (dom, cod)
  Var kv      -> do
    dom <- freshKind
    cod <- freshKind
    (dom, cod) <$ unsharingT (unifyKindVar kv (dom :-> cod))
  Type _      -> notFun
  HardKind _  -> notFun
 where
  freshKind = Var <$> newMeta False Nothing
  notFun    = fail "not a fun kind"
-- | Instantiate a 'Schema': one fresh metavariable is created per entry of the
-- schema's first field, and the schema body's bound variables are replaced by
-- the metavariable at the corresponding position.
instantiateSchema :: MonadMeta s m => Schema (MetaK s) -> m (KindM s)
instantiateSchema (Schema hints body) = do
  metas <- for hints (fmap Var . newMeta False)
  return (instantiate (metas !!) body)
-- | Check that a type has the expected kind: infer the type's kind, then
-- unify it with the expected one.  The unification result is discarded.
checkKind :: MonadMeta s m => Type (MetaK s) (KindM s) -> KindM s -> m ()
checkKind ty expected = do
  actual <- inferKind ty
  void (unsharingT (unifyKind expected actual))
-- | Infer the kind of a type.
--
-- Fresh metavariables are created where a kind is not yet known, so the
-- result may still mention unsolved metavariables.
inferKind :: MonadMeta s m => Type (MetaK s) (KindM s) -> m (KindM s)
-- A located type: infer the inner type under 'localMeta' with 'rendering' set
-- to the location.
inferKind (Loc loc ty) = set rendering loc `localMeta` inferKind ty
-- A type variable carries its own kind.
inferKind (Type.Var k) = return k
-- The arrow constructor gets a fresh metavariable for each of its two
-- arguments.
inferKind (HardType Arrow) = do
  argMeta <- newMeta True Nothing
  resMeta <- newMeta True Nothing
  return (Type (Var argMeta) :-> Type (Var resMeta) :-> star)
-- A named constructor carries a closed schema, which is instantiated.
inferKind (HardType (Con _ schema)) = instantiateSchema (vacuous schema)
inferKind (HardType (Tuple arity)) = return (productKind arity)
inferKind (HardType ConcreteRho{}) = return rho
-- Application: the function's kind must be an arrow whose argument kind the
-- operand is checked against.
inferKind (App fun arg) = do
  funKind <- inferKind fun
  (dom, cod) <- matchFunKind funKind
  checkKind arg dom
  return cod
-- A conjunction: every member must be a constraint.
inferKind (And cs) = do
  for_ cs (\c -> checkKind c constraint)
  return constraint
-- Existential: open the binders with fresh metavariables and check the body
-- as a constraint.
inferKind (Exists n tks cs) = do
  sks <- traverse (newMeta False) n
  let btys   = fmap (instantiateVars sks . extract) tks
      opened = instantiateKindVars sks (instantiateVars btys cs)
  checkKind opened constraint
  return constraint
-- Universal: as above for the constraint part, and the body must have kind
-- 'star'.
inferKind (Forall n tks cs b) = do
  sks <- traverse (newMeta False) n
  let btys = fmap (instantiateVars sks . extract) tks
  checkKind (instantiateKindVars sks (instantiateVars btys cs)) constraint
  checkKind (instantiateKindVars sks (instantiateVars btys b)) star
  return star
-- | Infer the kind of an annotation.
--
-- A fresh metavariable is created for each entry of the annotation's first
-- and second fields.  The second batch is substituted into the annotation
-- body, the kind variables of the result are replaced by indexing into the
-- first batch, and the resulting type is passed to 'inferKind'.
inferAnnotKind :: MonadMeta s m => Annot (KindM s) -> m (KindM s)
inferAnnotKind (Annot kindHints typeHints body) = do
  kindMetas <- for kindHints (newMeta False)
  typeMetas <- for typeHints (\h -> pure <$> newMeta False h)
  inferKind (over kindVars (kindMetas !!) (instantiateVars typeMetas body))
-- | Substitute for the type variables of a data type: a variable present in
-- the map is replaced by its entry, and any other variable by the result of
-- the fallback function.
fixCons :: (Ord t) => Map t (Type k u) -> (t -> Type k u) -> DataType k t -> DataType k u
fixCons table fallback = boundBy resolve
 where
  resolve t = Map.findWithDefault (fallback t) t table
-- | Kind-check a list of data type declarations.
--
-- The declarations are split into strongly connected components, keyed by
-- name, with an edge for every type variable a declaration mentions.  The
-- components are checked one at a time, in the order 'stronglyConnComp'
-- returns them.  A map from the names already checked to their type
-- constructors is threaded through as state and substituted into each
-- component before it is passed to 'checkDataTypeGroup'.
checkDataTypeKinds :: [DataType () Text] -> M s [DataType Void Void]
checkDataTypeKinds dts =
  evalStateT (concat <$> traverse (checkComponent . flattenSCC) components) Map.empty
 where
  components = stronglyConnComp [ (dt, dt^.name, toListOf typeVars dt) | dt <- dts ]
  conMap group =
    Map.fromList [ (dt^.name, con (dt^.global) (dataTypeSchema dt)) | dt <- group ]
  checkComponent scc = StateT $ \known -> do
    checked <- checkDataTypeGroup (fixCons known pure <$> scc)
    return (checked, Map.union known (conMap checked))
-- | Kind-check one group of data type declarations together.
--
-- Every unit kind in the declarations is first replaced by a fresh
-- metavariable.  Each declaration is then checked with 'checkDataTypeKind'
-- against a map of the whole group.  References to group members are replaced
-- by constructors carrying the schemas that check produced, and each
-- declaration is finally closed over its remaining kind variables.
--
-- Fails with @"escaped skolem"@ if closing meets a kind variable that is
-- neither among the declaration's collected kind variables nor among the
-- metavariables returned for its kind parameters.
checkDataTypeGroup :: [DataType () Text] -> M s [DataType Void Void]
checkDataTypeGroup dts = do
  dts' <- traverse (kindVars (\_ -> newMeta False Nothing)) dts
  let m = Map.fromList $ map (\dt -> (dt^.name, dt)) dts'
  checked <- for dts' $ \dt -> (,) dt <$> checkDataTypeKind m dt
  let sm = Map.fromList $ (\(dt, (_, sch)) -> (dt^.name, con (dt^.global) sch)) <$> checked
      -- 'checkDataTypeKind' fails on a name outside the group, so every
      -- remaining name should be a key of @sm@ and the fallback unreachable.
      sdts = checked <&> \(dt, (sks, _)) ->
        (sks, fixCons sm (\_ -> error "checkDataTypeGroup: IMPOSSIBLE") dt)
  traverse closeKinds sdts
 where
  closeKinds (sks, dt) = do
    dt' <- zonkDataType dt
    let fvs = nub . toListOf kindVars $ dt'
        offset = length sks
        -- The closed declaration binds @dt^.kparams ++ fvs@.  Indices below
        -- @offset@ therefore name the declared kind parameters (one entry of
        -- @sks@ each), and indices from @offset@ upwards name the newly bound
        -- variables.  A member of @sks@ must map to its own position, not be
        -- shifted into the @fvs@ range.
        --
        -- NOTE(review): if 'kindVars' on @dt'@ also yields members of @sks@,
        -- the first guard wins and they are bound as new parameters; confirm
        -- that this guard order is intended.
        f (F m) | Just i <- m `elemIndex` fvs = pure . B $ i + offset
                | Just i <- m `elemIndex` sks = pure . B $ i
                | otherwise = fail "escaped skolem"
        f (B i) = pure $ B i
        reabstr = fmap toScope . traverse f . fromScope
    ts' <- (traverse . traverse) reabstr $ dt'^.tparams
    cs' <- (traverse . kindVars) f $ dt'^.constrs
    return $ DataType (dt^.global) (dt^.kparams ++ (Nothing <$ fvs)) ts' cs'
-- | Kind-check a single data type declaration against the other members of
-- its group.
--
-- One fresh metavariable is created per kind parameter of the declaration.
-- Each type name in the declaration is replaced by a kind: the declaration's
-- own instantiated kind for its own name, or a freshly instantiated schema of
-- the named group member.  A name that is neither fails with
-- @"unknown reference"@.  Every constructor is then checked with
-- 'checkConstructorKind'.
--
-- Returns the kind-parameter metavariables together with the generalized kind
-- of the declaration.
checkDataTypeKind :: Map Text (DataType (MetaK s) a) -> DataType (MetaK s) Text
  -> M s ([MetaK s], Schema v)
checkDataTypeKind m dt = do
  sks <- traverse (newMeta False) (dt^.kparams)
  let Schema _ body = dataTypeSchema dt
      selfKind      = instantiateVars sks body
      kindOf t
        | t == dt^.name = pure selfKind
        | otherwise     =
            maybe (fail "unknown reference") (refresh . dataTypeSchema) (Map.lookup t m)
  dt' <- for dt kindOf
  let btys = fmap (instantiateVars sks . extract) (dt^.tparams)
  for_ (dt'^.constrs) $ \c ->
    checkConstructorKind (bimap (unvar (sks !!) id) (unvar (btys !!) id) c)
  -- NOTE(review): presumably verifies the parameter metavariables are still
  -- pairwise distinct; the result is discarded.
  _ <- checkDistinct sks
  sch <- generalize selfKind
  return (sks, sch)
 where
  -- Instantiate another declaration's schema with fresh metavariables.
  refresh (Schema ks s) = do
    vs <- traverse (newMeta False) ks
    return (instantiateVars vs s)
-- | Kind-check one data constructor: after the constructor's own binders are
-- opened with fresh metavariables, every field must have kind 'star'.
-- 'checkDistinct' is run on those metavariables afterwards and its result is
-- discarded.
checkConstructorKind :: Constructor (MetaK s) (KindM s) -> M s ()
checkConstructorKind (Constructor _ ks ts fs) = do
  sks <- traverse (newMeta False) ks
  let btys = fmap (instantiateVars sks . extract) ts
  for_ fs (\fld -> checkKind (instantiateKindVars sks (instantiateVars btys fld)) star)
  void (checkDistinct sks)