-
Notifications
You must be signed in to change notification settings - Fork 0
/
checker.scala
209 lines (175 loc) · 6.98 KB
/
checker.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import scala.io._
import cs162.assign4.syntax._
import Aliases._
import scala.io.Source.fromFile
//——————————————————————————————————————————————————————————————————————————————
// Main entry point
object Checker {
  // A type environment maps program variables to their types.
  type TypeEnv = Map[Var, Type]
  val TypeEnv = Map

  // Thrown by the checker whenever an expression cannot be given a type.
  object Illtyped extends Exception

  // Entry point: reads the program from the file named by args(0), parses
  // it, and type-checks its top-level expression. A parse error is printed;
  // for an ill-typed program the `Illtyped` exception propagates out of main.
  def main(args: Array[String]): Unit = {
    val filename = args(0)
    val source = fromFile(filename)
    // Release the file handle even if reading fails.
    val input = try source.mkString finally source.close()
    Parsers.program.run(input, filename) match {
      case Left(e) => // parse error
        println(e)
      case Right(program) =>
        Checker(program.typedefs).getType(program.e, TypeEnv())
        println("This program is well-typed")
    }
  }
}
case class Checker(typeDefs: Set[TypeDef]) {
  import Checker.{ TypeEnv, Illtyped }

  // Gets a listing of the constructor names associated with a given
  // type definition. For example, consider the following type
  // definition:
  //
  // type Either['A, 'B] = Left 'A | Right 'B
  //
  // Some example calls to `constructors`, along with return values:
  //
  // constructors("Either") = Set("Left", "Right")
  // constructors("Foo") = a thrown Illtyped exception
  //
  def constructors(name: NameLabel): Set[ConsLabel] =
    typeDefs
      .find(_.name == name)
      .map(_.constructors.keySet)
      .getOrElse(throw Illtyped)

  // Takes the following parameters:
  // -The name of a user-defined type
  // -The name of a user-defined constructor in that user-defined type
  // -The types which we wish to apply to the constructor
  // Returns the type that is held within the constructor, with the type
  // definition's type variables replaced by the supplied types.
  //
  // For example, consider the following type definition:
  //
  // type Either['A, 'B] = Left 'A | Right 'B
  //
  // Some example calls to `constructorType`, along with return values:
  //
  // constructorType("Either", "Left", Seq(NumT, BoolT)) = NumT
  // constructorType("Either", "Right", Seq(NumT, BoolT)) = BoolT
  // constructorType("Either", "Left", Seq(NumT)) = a thrown Illtyped exception
  // constructorType("Either", "Right", Seq(BoolT)) = a thrown Illtyped exception
  // constructorType("Either", "Foo", Seq(UnitT)) = a thrown Illtyped exception
  // constructorType("Bar", "Left", Seq(UnitT)) = a thrown Illtyped exception
  //
  def constructorType(name: NameLabel, constructor: ConsLabel, types: Seq[Type]): Type =
    (for {
      td <- typeDefs.find(_.name == name)
      rawType <- td.constructors.get(constructor)
      if types.size == td.tvars.size
    } yield replace(rawType, td.tvars.zip(types).toMap)).getOrElse(throw Illtyped)

  // Recursively replaces the type variables in `t` with their images in
  // `tv2t`. A type variable not present in `tv2t` is left as is. A `TFunT`
  // binds its own type variables: those shadow any mapping for the same
  // variables in `tv2t` while its body is processed, and the `TFunT`
  // wrapper itself is preserved in the result.
  def replace(t: Type, tv2t: Map[TVar, Type]): Type =
    t match {
      case NumT | BoolT | UnitT => t
      case FunT(params, ret) =>
        FunT(params.map(replace(_, tv2t)), replace(ret, tv2t))
      case RcdT(fields) =>
        RcdT(fields.map(fe => (fe._1, replace(fe._2, tv2t))))
      case TypT(name, typs) =>
        TypT(name, typs.map(replace(_, tv2t)))
      case tv: TVar =>
        tv2t.getOrElse(tv, t)
      case TFunT(tvars, FunT(params, ret)) =>
        // The variables bound here are not free in the body, so the outer
        // substitution must not touch them.
        val inner = tv2t -- tvars
        TFunT(tvars, FunT(params.map(replace(_, inner)), replace(ret, inner)))
    }

  // Computes the type of `e` under the type environment `env`.
  // Throws `Illtyped` if `e` cannot be typed.
  def getType(e: Exp, env: TypeEnv): Type =
    e match {
      case x: Var => env.getOrElse(x, throw Illtyped)
      case _: Num => NumT
      case _: Bool => BoolT
      case _: Unit => UnitT
      case Plus | Minus | Times | Divide => FunT(Seq(NumT, NumT), NumT)
      case LT | EQ => FunT(Seq(NumT, NumT), BoolT)
      case And | Or => FunT(Seq(BoolT, BoolT), BoolT)
      case Not => FunT(Seq(BoolT), BoolT)

      case Fun(params, body) =>
        FunT(params.map(_._2), getType(body, env ++ params.toMap))

      // The argument types must match the parameter types exactly (which
      // also enforces the arity); the call then has the return type.
      case Call(fun, args) =>
        getType(fun, env) match {
          case FunT(params, ret) if params == args.map(getType(_, env)) => ret
          case _ => throw Illtyped
        }

      // The condition must be boolean and both branches must agree.
      case If(cond, thenBranch, elseBranch) =>
        if (getType(cond, env) != BoolT) throw Illtyped
        val branchType = getType(thenBranch, env)
        if (getType(elseBranch, env) == branchType) branchType
        else throw Illtyped

      case Let(x, e1, e2) =>
        getType(e2, env + (x -> getType(e1, env)))

      // `x` is bound to its annotation `t1` while checking both `e1` and
      // `e2`. `e1` must have exactly type `t1`; the whole expression has the
      // type of `e2`.
      case Rec(x, t1, e1, e2) =>
        val recEnv = env + (x -> t1)
        if (getType(e1, recEnv) == t1) getType(e2, recEnv)
        else throw Illtyped

      case Record(fields) =>
        RcdT(fields.map(fe => (fe._1, getType(fe._2, env))))

      case Access(rcd, field) =>
        getType(rcd, env) match {
          case RcdT(fields) => fields.getOrElse(field, throw Illtyped)
          case _ => throw Illtyped
        }

      // The payload must have the constructor's (instantiated) type.
      case Construct(name, constructor, typs, arg) =>
        if (constructorType(name, constructor, typs) == getType(arg, env)) TypT(name, typs)
        else throw Illtyped

      // The scrutinee must be a user-defined type. The cases must name each
      // of its constructors exactly once, and every case body must have the
      // same type; that common body type is the type of the match.
      case Match(scrutinee, cases) =>
        if (cases.isEmpty) throw Illtyped
        getType(scrutinee, env) match {
          case TypT(name, typs) =>
            val labels = cases.map(c => c._1)
            if (labels.size != constructors(name).size) throw Illtyped
            if (labels.distinct.size != labels.size) throw Illtyped
            // Each case binds its variable to the constructor's payload type;
            // constructorType rejects labels that are not in this type.
            val bodyTypes = cases.map(c =>
              getType(c._3, env + (c._2 -> constructorType(name, c._1, typs))))
            val distinctTypes = bodyTypes.distinct
            if (distinctTypes.size != 1) throw Illtyped
            distinctTypes.head
          case _ => throw Illtyped
        }

      // Type abstraction: only functions may be made polymorphic.
      case TAbs(tvars, fun) =>
        getType(fun, env) match {
          case ft: FunT => TFunT(tvars, ft)
          case _ => throw Illtyped
        }

      // Type application: instantiate the bound type variables in the body
      // of the polymorphic function type. The number of type arguments must
      // equal the number of bound variables.
      case TApp(poly, typs) =>
        getType(poly, env) match {
          case TFunT(tvars, funt) if tvars.size == typs.size =>
            replace(funt, tvars.zip(typs).toMap)
          case _ => throw Illtyped
        }
    }
}