-
Notifications
You must be signed in to change notification settings - Fork 0
/
semant.ml
170 lines (135 loc) · 6.18 KB
/
semant.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
(* Semantic checking for the TPL compiler *)
open Ast
(* Immutable maps keyed by strings; used for the function table and the
   per-function symbol table in [check]. *)
module StringMap = Map.Make(String)
(* Semantic checking of a program.

   [check (globals, functions)] returns [()] if the program is well formed
   and raises [Failure msg] describing the first problem found otherwise.
   Checks run in this order: global variables; then the function table
   (built-in name clashes, duplicate functions, presence of "main"); then
   the formals, locals and body of each user-defined function. *)
let check (globals, functions) =
  (* Raise [Failure (exceptf n)] for the first name [n] occurring more than
     once in [list]; sorting makes any duplicates adjacent. *)
  let report_duplicate exceptf list =
    let rec helper = function
      | n1 :: n2 :: _ when n1 = n2 -> raise (Failure (exceptf n1))
      | _ :: t -> helper t
      | [] -> ()
    in
    helper (List.sort String.compare list)
  in
  (* Raise [Failure (exceptf n)] if the binding [(ty, n)] has type Void. *)
  let check_not_void exceptf = function
    | (Void, n) -> raise (Failure (exceptf n))
    | _ -> ()
  in
  (* Return [lvaluet] if a value of type [rvaluet] may be assigned to it,
     otherwise raise [err]. Types are compared structurally with [=]:
     physical equality ([==]) only happens to work for constant
     constructors and would silently break for any non-constant type. *)
  let check_assign lvaluet rvaluet err =
    if lvaluet = rvaluet then lvaluet else raise err
  in
  (**** Checking Global Variables ****)
  List.iter (check_not_void (fun n -> "illegal void global " ^ n)) globals;
  report_duplicate (fun n -> "duplicate global " ^ n) (List.map snd globals);
  (**** Checking Functions ****)
  (* Built-in functions as (name, argument type). Each takes a single
     argument and returns Void. Add new built-ins here: both the name-clash
     check and the declaration table below are derived from this list. *)
  let built_in_specs =
    [ ("print", Int); ("printb", Bool); ("printfloat", Float) ]
  in
  let user_fnames = List.map (fun fd -> fd.fname) functions in
  (* User code may not redefine a built-in function. *)
  let check_not_predefined x =
    if List.mem x user_fnames then
      raise (Failure ("function " ^ x ^ " may not be defined"))
  in
  List.iter (fun (name, _) -> check_not_predefined name) built_in_specs;
  report_duplicate (fun n -> "duplicate function " ^ n) user_fnames;
  (* Declarations of the built-in functions, keyed by name. *)
  let built_in_decls =
    List.fold_left
      (fun m (name, arg_typ) ->
        StringMap.add name
          { typ = Void; fname = name; formals = [ (arg_typ, "x") ];
            locals = []; body = [] }
          m)
      StringMap.empty built_in_specs
  in
  (* Every callable function: built-ins plus user definitions. *)
  let function_decls =
    List.fold_left (fun m fd -> StringMap.add fd.fname fd m)
      built_in_decls functions
  in
  (* Look up a function declaration by name or raise [Failure]. *)
  let function_decl s =
    try StringMap.find s function_decls
    with Not_found -> raise (Failure ("unrecognized function " ^ s))
  in
  (* Every program must define "main". *)
  ignore (function_decl "main");
  (* Check one user-defined function: its formals and locals, then every
     statement of its body. *)
  let check_function func =
    List.iter
      (check_not_void (fun n -> "illegal void formal " ^ n ^
                                " in " ^ func.fname))
      func.formals;
    report_duplicate (fun n -> "duplicate formal " ^ n ^ " in " ^ func.fname)
      (List.map snd func.formals);
    List.iter
      (check_not_void (fun n -> "illegal void local " ^ n ^
                                " in " ^ func.fname))
      func.locals;
    report_duplicate (fun n -> "duplicate local " ^ n ^ " in " ^ func.fname)
      (List.map snd func.locals);
    (* Type of each variable in scope (global, formal, or local). Later
       bindings replace earlier ones, so locals shadow formals and formals
       shadow globals. *)
    let symbols =
      List.fold_left (fun m (t, n) -> StringMap.add n t m)
        StringMap.empty (globals @ func.formals @ func.locals)
    in
    let type_of_identifier s =
      try StringMap.find s symbols
      with Not_found -> raise (Failure ("undeclared identifier " ^ s))
    in
    (* Return the type of an expression or raise [Failure]. *)
    let rec expr = function
      | Literal _ -> Int
      | Floatliteral _ -> Float
      | Strliteral _ -> String
      | BoolLit _ -> Bool
      | Id s -> type_of_identifier s
      | Binop (e1, op, e2) as e ->
          let t1 = expr e1 and t2 = expr e2 in
          (* NOTE(review): relational operators accept only Int operands, so
             Float comparisons are rejected; confirm this is intended. *)
          (match op with
           | (Add | Sub | Mult | Div) when t1 = Int && t2 = Int -> Int
           | (Add | Sub | Mult | Div) when t1 = Float && t2 = Float -> Float
           | (Equal | Neq) when t1 = t2 -> Bool
           | (Less | Leq | Greater | Geq) when t1 = Int && t2 = Int -> Bool
           | (And | Or) when t1 = Bool && t2 = Bool -> Bool
           | _ ->
               raise (Failure ("illegal binary operator " ^
                               string_of_typ t1 ^ " " ^ string_of_op op ^ " " ^
                               string_of_typ t2 ^ " in " ^ string_of_expr e)))
      | Unop (op, e) as ex ->
          let t = expr e in
          (match op with
           | Neg when t = Int -> Int
           | Not when t = Bool -> Bool
           | _ ->
               raise (Failure ("illegal unary operator " ^ string_of_uop op ^
                               string_of_typ t ^ " in " ^ string_of_expr ex)))
      | Noexpr -> Void
      | Assign (var, e) as ex ->
          let lt = type_of_identifier var and rt = expr e in
          check_assign lt rt
            (Failure ("illegal assignment " ^ string_of_typ lt ^
                      " = " ^ string_of_typ rt ^ " in " ^
                      string_of_expr ex))
      | Call (fname, actuals) as call ->
          let fd = function_decl fname in
          if List.length actuals <> List.length fd.formals then
            raise (Failure ("expecting " ^ string_of_int
                              (List.length fd.formals) ^
                            " arguments in " ^ string_of_expr call))
          else
            (* Each actual must have exactly its formal's type. *)
            List.iter2
              (fun (ft, _) e ->
                let et = expr e in
                ignore (check_assign ft et
                          (Failure ("illegal actual argument found " ^
                                    string_of_typ et ^ " expected " ^
                                    string_of_typ ft ^ " in " ^
                                    string_of_expr e))))
              fd.formals actuals;
          fd.typ
    in
    (* Raise [Failure] unless [e] has type Bool. *)
    let check_bool_expr e =
      if expr e <> Bool then
        raise (Failure ("expected Boolean expression in " ^ string_of_expr e))
    in
    (* Verify a statement or raise [Failure]. *)
    let rec stmt = function
      | Block sl ->
          (* Nested blocks are flattened into the enclosing list so that
             "nothing may follow a return" is enforced across them too. *)
          let rec check_block = function
            | [ (Return _ as s) ] -> stmt s
            | Return _ :: _ -> raise (Failure "nothing may follow a return")
            | Block sl :: ss -> check_block (sl @ ss)
            | s :: ss -> stmt s; check_block ss
            | [] -> ()
          in
          check_block sl
      | Expr e -> ignore (expr e)
      | Return e ->
          let t = expr e in
          if t <> func.typ then
            raise (Failure ("return gives " ^ string_of_typ t ^ " expected " ^
                            string_of_typ func.typ ^ " in " ^
                            string_of_expr e))
      | If (p, b1, b2) -> check_bool_expr p; stmt b1; stmt b2
      | For (e1, e2, e3, st) ->
          ignore (expr e1); check_bool_expr e2; ignore (expr e3); stmt st
      | While (p, s) -> check_bool_expr p; stmt s
    in
    stmt (Block func.body)
  in
  List.iter check_function functions