module Ermine.Unification.Type
( unifyType
, zonkKindsAndTypes
, zonkKindsAndTypesWith
) where
import Control.Applicative
import Control.Comonad
import Control.Lens
import Control.Monad
import Control.Monad.ST
import Control.Monad.ST.Class
import Control.Monad.Writer.Strict
import Data.Bitraversable
import Data.Set as Set
import Data.IntMap as IntMap
import Data.Set.Lens
import Data.STRef
import Data.Traversable
import Ermine.Diagnostic
import Ermine.Pretty
import Ermine.Pretty.Type
import Ermine.Inference.Kind (checkKind)
import Ermine.Syntax.Scope
import Ermine.Syntax.Type as Type
import Ermine.Syntax.Kind as Kind hiding (Var)
import Ermine.Unification.Kind
import Ermine.Unification.Meta
import Ermine.Unification.Sharing
-- | Zonk @t@ while performing an occurs check and adjusting depths.
--
-- Every kind meta met while zonking, and every type meta not matching @p@,
-- has its depth (via 'metaDepth') lowered to @depth1@ when it is currently
-- deeper. If a type meta satisfying @p@ is met, the type would be cyclic:
-- an \"infinite type detected\" diagnostic is thrown with 'throwM', with a
-- footnote showing the offending variable and the fully zonked type.
-- Otherwise the zonked type is returned.
typeOccurs
  :: (MonadWriter Any m, MonadMeta s m)
  => Depth -> TypeM s -> (MetaT s -> Bool) -> m (TypeM s)
typeOccurs depth1 t p = zonkKindsAndTypesWith t tweakDepth tweakType where
  -- Lower the meta's depth to depth1 if it is deeper.
  tweakDepth :: MonadST m => Meta (World m) f a -> m ()
  tweakDepth m = liftST $ forMOf_ metaDepth m $ \d -> do
    depth2 <- readSTRef d
    when (depth2 > depth1) $ writeSTRef d depth1

  tweakType m
    | p m = do
      zt <- sharing t $ zonk t
      let st = setOf typeVars zt
          sk = setOf kindVars zt
          nt = Set.size st
          nk = Set.size sk
          -- Type variables take the first names, kind variables the next
          -- nk names, and binders printed by prettyType the remainder, so
          -- no two distinct variables in the diagnostic share a name.
          mt = IntMap.fromList $ zipWith (\u n -> (u^.metaId, n)) (Set.toList st) names
          mk = IntMap.fromList $ zipWith (\u n -> (u^.metaId, n)) (Set.toList sk) (drop nt names)
          v = mt ^?! ix (st ^?! folded . filtered p . metaId)
          -- NOTE(review): printed at precedence 1; confirm a top-level
          -- precedence was not intended here.
          td = prettyType
                 (bimap (\u -> mk ^?! ix (u^.metaId))
                        (\u -> mt ^?! ix (u^.metaId)) zt)
                 (drop (nt + nk) names)
                 1
      r <- viewMeta rendering
      throwM $ die r "infinite type detected" & footnotes .~ [text "cyclic type:" <+> hang 4 (group (pretty v </> char '=' </> td))]
    | otherwise = tweakDepth m
-- | Fully zonk a type, substituting solved kind and type metas throughout.
--
-- This is 'zonkKindsAndTypesWith' with hooks that do nothing for every
-- kind meta and type meta visited.
zonkKindsAndTypes
  :: (MonadMeta s m, MonadWriter Any m)
  => TypeM s -> m (TypeM s)
zonkKindsAndTypes ty = zonkKindsAndTypesWith ty skip skip
  where
    -- A function binding (not a point-free one) so it generalises and can
    -- serve as both the kind hook and the type hook.
    skip _ = return ()
-- | Zonk a type, calling a hook on every meta variable visited.
--
-- @tweakKind@ runs on each kind meta and @tweakType@ on each type meta
-- before the meta's contents are read. For every visited type meta, the
-- kind it carries is zonked with 'zonkWith' (using @tweakKind@). An unbound
-- type meta is kept, with that zonked kind. An unbound kind meta is kept
-- unchanged.
--
-- A bound meta is replaced by its zonked contents, and those contents are
-- written back into the meta so later reads see the zonked form. In that
-- case @'Any' 'True'@ is told.
zonkKindsAndTypesWith
  :: (MonadMeta s m, MonadWriter Any m)
  => TypeM s -> (MetaK s -> m ()) -> (MetaT s -> m ()) -> m (TypeM s)
zonkKindsAndTypesWith fs0 tweakKind tweakType = zonkTy fs0 where
  zonkTy ty = bindType id id <$> bitraverse onKindMeta onTypeMeta ty

  onTypeMeta tv = do
    tweakType tv
    kind' <- zonkWith (tv^.metaValue) tweakKind
    contents <- readMeta tv
    case contents of
      Nothing -> return (return (tv & metaValue .~ kind'))
      Just bound -> do
        tell (Any True)
        solved <- zonkTy bound
        writeMeta tv solved
        return solved

  onKindMeta kv = do
    tweakKind kv
    contents <- readMeta kv
    case contents of
      Nothing -> return (return kv)
      Just bound -> do
        tell (Any True)
        solved <- zonkWith bound tweakKind
        writeMeta kv solved
        return solved
-- | Structurally unify two types, binding unification variables as needed.
--
-- Both sides are first 'semiprune'd.
--
-- * Two metas: their kinds are unified and one meta is bound to the other.
-- * A meta against any other type: the type's kind is checked against the
--   meta's kind with 'checkKind'.
-- * @Forall@ types: both sides are opened with shared fresh kind metas and
--   fresh type metas, and the bodies are unified. If that changed anything,
--   the fresh metas are validated with 'checkDistinct' and 'checkSkolems'.
--
-- Existential types are rejected. Any other shape mismatch fails with
-- @"type mismatch"@. Whether anything changed is recorded through the
-- 'MonadWriter' 'Any' output.
unifyType
  :: (MonadWriter Any m, MonadMeta s m)
  => TypeM s -> TypeM s -> m (TypeM s)
unifyType t1 t2 = do
  t1' <- semiprune t1
  t2' <- semiprune t2
  go t1' t2'
  where
    go x@(Var tv1) (Var tv2) | tv1 == tv2 = return x
    -- Two distinct metas: compare the counters in their last field (a
    -- union-by-rank rank). The meta with the smaller counter is bound to the
    -- other. On a tie, x is bound to y and y's counter is incremented.
    go x@(Var (Meta k _ i r d u)) y@(Var (Meta l _ j s e v)) = do
      () <$ unifyKind k l
      m <- liftST $ readSTRef u
      n <- liftST $ readSTRef v
      case compare m n of
        LT -> unifyTV True i r d y $ return ()
        EQ -> unifyTV True i r d y $ writeSTRef v $! n + 1
        GT -> unifyTV False j s e x $ return ()
    go (Var (Meta k _ i r d _)) t = do
      checkKind (view metaValue <$> t) k
      unifyTV True i r d t $ return ()
    go t (Var (Meta k _ i r d _)) = do
      checkKind (view metaValue <$> t) k
      unifyTV False i r d t $ return ()
    go (App f x) (App g y) = App <$> unifyType f g <*> unifyType x y
    go (Loc l s) t = Loc l <$> unifyType s t
    go s (Loc _ t) = unifyType s t
    go Exists{} _ = fail "unifyType: existential"
    go _ Exists{} = fail "unifyType: existential"
    -- NOTE(review): the third Forall field is ignored on both sides; confirm
    -- it never needs to be unified here.
    go t@(Forall m xs _ a) t'@(Forall n ys _ b)
      -- Compare how many kind binders each side has, not the binder
      -- containers themselves. Their elements are passed to 'newMeta' as
      -- hints, so comparing them with (/=) would reject alpha-equivalent
      -- foralls whose binders merely carry different hints.
      | length m /= length n = fail "unifyType: forall: mismatched kind arity"
      | length xs /= length ys = fail "unifyType: forall: mismatched type arity"
      | otherwise = do
        ((sts, sks), Any modified) <- listen $ do
          sks <- for m $ newMeta False
          let nxs = instantiateVars sks <$> fmap extract xs
              nys = instantiateVars sks <$> fmap extract ys
          sts <- for (zip nxs nys) $ \(x, y) -> do
            k <- unifyKind x y
            newMeta k Nothing
          _ <- unifyType (instantiateKindVars sks (instantiateVars sts a))
                         (instantiateKindVars sks (instantiateVars sts b))
          return (sts, sks)
        if modified
          then do
            _ <- checkDistinct sks
            sts' <- checkDistinct sts
            fst <$> checkSkolems Nothing both sts' (t, t')
          else return t
    go t@(HardType x) (HardType y) | x == y = return t
    go _ _ = fail "type mismatch"
-- | Bind, or re-unify, the type meta with id @i@ against the type @t@.
--
-- @r@ is the meta's binding ref and @d@ its depth ref.
--
-- * If @r@ already holds a type, that type is unified with @t@. If the
--   unification changed anything, the result is written back to @r@ and
--   returned. Otherwise the old contents are returned and @Any True@ is
--   told.
-- * If @t@ is another meta, @Any interesting@ is told. Every kind meta in
--   @t@'s kind has its depth lowered to @d@'s value when deeper, and @bump@
--   runs. @r@ is then set to @t@, and @t@'s own depth is lowered the same
--   way. The other meta is returned with its kind zonked.
-- * Otherwise @Any interesting@ is told, an occurs check is run with
--   'typeOccurs' at @d@'s depth, and the zonked type is stored in @r@ and
--   returned.
unifyTV
  :: (MonadWriter Any m, MonadMeta s m)
  => Bool -> Int -> STRef s (Maybe (TypeM s)) -> STRef s Depth
  -> TypeM s -> ST s () -> m (TypeM s)
unifyTV interesting i r d t bump = do
  existing <- liftST (readSTRef r)
  case existing of
    Just old -> do
      (merged, Any changed) <- listen (unifyType old t)
      if changed
        then liftST (writeSTRef r (Just merged) >> return merged)
        else tell (Any True) >> return old
    Nothing -> case t of
      Var other@(Meta k _ _ _ e _) -> do
        tell (Any interesting)
        k' <- zonkWith k (\kv -> liftST (lowerToD (kv^.metaDepth)))
        liftST $ do
          bump
          writeSTRef r (Just t)
          lowerToD e
        return (Var (other & metaValue .~ k'))
      _ -> do
        tell (Any interesting)
        limit <- liftST (readSTRef d)
        zt <- typeOccurs limit t (\mv -> mv^.metaId == i)
        liftST (writeSTRef r (Just zt))
        return zt
  where
    -- Lower the depth held in the given ref to the one held in @d@.
    -- Leave it unchanged if it is not greater.
    lowerToD ref = do
      limit <- readSTRef d
      current <- readSTRef ref
      when (current > limit) $ writeSTRef ref limit