summaryrefslogtreecommitdiffstats
path: root/src/back/generator.py
blob: 84ab63917db327ed0071381fcdf1fdc62c13aeaa (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from back.tac import *
from front.scope import Scope

class Register(object):
    """Allocator of pseudo-register numbers, grouped by function.

    Implemented as a Borg: every ``Register()`` instance shares a single
    attribute dictionary, so all instances see the same counter, the same
    current function and the same per-function register lists.
    """
    __shared_state = {}
    # Function that new() attributes registers to; None until set_function().
    __current_function = None
    # function name -> list of register numbers allocated while it was current
    __function_registers = {}
    # Last register number handed out; numbering starts at 1.
    count = 0

    def __init__(self):
        # Borg idiom: alias this instance's attributes to the shared dict.
        self.__dict__ = self.__shared_state

    def new(self):
        """Return a fresh register number and record it for the current function.

        Raises KeyError when no function has been selected with
        set_function() yet; the counter has already advanced by then.
        """
        number = self.count + 1
        self.count = number
        self.__function_registers[self.__current_function].append(number)
        return number

    def set_function(self, name):
        """Select *name* as the current function and reset its register list."""
        self.__current_function = name
        self.__function_registers[name] = []

    def get_registers(self, name):
        """Return the list of registers recorded for *name*.

        The live list is returned, not a copy. Raises KeyError for a name
        never passed to set_function().
        """
        registers = self.__function_registers[name]
        return registers

class Generator(object):
    """Translate a TACList into textual assembly and print it.

    Code is produced in several passes (see generate()):
      1. TAC instructions are lowered to assembly strings. BEZ/JMP/CALL/RETURN
         TACs and Labels stay in the output list as placeholders.
      2. Placeholders are replaced by ops carrying relative instruction offsets.
      3. Labels are dropped (they occupy no instruction slot).
      4. The pseudo registers bp/sp/rv are renamed to concrete registers
         numbered just above the highest register allocated by Register.
    """

    def __init__(self):
        self.tac_list = TACList()
        # Output under construction: assembly strings plus TAC/Label
        # placeholders that later passes replace or remove.
        self.__op_list = []

    def emit(self, op, tac = None):
        """Append *op* to the output, or replace placeholder *tac* with it.

        When *tac* is given, the first output element comparing equal to
        *tac* is overwritten in place by *op*.
        """
        if tac:
            self.__op_list[self.__op_list.index(tac)] = op
        else:
            self.__op_list.append(op)

    # this code will appear before every function:
    # sets bp, pushes one slot per variable, then saves used registers
    def generate_prologue(self, name):
        """Emit the entry sequence for function *name*.

        Selects *name* in Scope, emits its label, saves the old bp, points bp
        at the current sp, pushes one ``r0`` per entry of
        ``Scope().get_variables()[1]`` (presumably the local-variable slots;
        TODO confirm against Scope) and finally saves every register
        Register recorded for *name*.
        """
        Scope().set_function(name)
        self.emit(Label(name))
        self.emit("PUSH bp")
        self.emit("ADD bp, r0, sp")
        for _ in Scope().get_variables()[1]:
            self.emit("PUSH r0")
        for register in Register().get_registers(name):
            self.emit("PUSH r%d" % register)

    # this code will appear after every function
    # restores registers, sp, bp
    def generate_epilogue(self, name):
        """Emit the exit sequence for function *name*.

        The prologue pushes the variable slots first and the saved registers
        after them, so the saved registers are on top of the stack. They are
        popped first, in reverse push order, and only then are the variable
        slots discarded into r0. sp is then reset from bp and the caller's bp
        restored. Relies on Scope still having *name* selected (done by
        generate_prologue).
        """
        # reversed() instead of reversing the shared Register list in place
        for register in reversed(Register().get_registers(name)):
            self.emit("POP r%d" % register)
        for _ in Scope().get_variables()[1]:
            self.emit("POP r0")
        self.emit("ADD sp, r0, bp")
        self.emit("POP bp")
        self.emit("RET")

    def _emit_tac(self, tac, print_register):
        """Pass 1 helper: lower one TAC instruction to assembly.

        BEZ/JMP/CALL/RETURN need label offsets that are not known yet, so the
        TAC object itself is emitted as a placeholder for pass 2.
        *print_register* is the scratch register used by PRINT.

        Raises Exception for an unknown operator.
        """
        opcode = tac.op
        if opcode in (Op.ADD, Op.SUB, Op.MUL, Op.DIV, Op.MOD, Op.AND, Op.OR):
            # result overwrites arg1: "OP dst, src1, src2" with dst = src1 = arg1
            self.emit("%s r%d, r%d, r%d" % (opcode, tac.arg1, tac.arg1, tac.arg2))
        elif opcode == Op.NOT:
            self.emit("CMP r%d, r0" % tac.arg1)
            self.emit("EQ  r%d" % tac.arg1)
        elif opcode == Op.MINUS:
            self.emit("SUB r%d, r0, r%d" % (tac.arg1, tac.arg1))
        elif opcode in (Op.STORE, Op.LOAD):
            # variable slots are addressed relative to bp, offset scaled by 4
            offset = Scope().get_variable_offset(tac.arg1)
            mnemonic = {Op.STORE: "SW", Op.LOAD: "LW"}[opcode]
            self.emit("%s bp, r%d, %d" % (mnemonic, tac.arg2, offset * 4))
        elif opcode == Op.MOV:
            self.emit("MOV r%d, %d" % (tac.arg1, tac.arg2))
        elif opcode == Op.CMP:
            self.emit("CMP r%d, r%d" % (tac.arg1, tac.arg2))
        elif opcode in (Op.EQ, Op.NE, Op.LT, Op.LE, Op.GE, Op.GT):
            self.emit("%s r%d" % (opcode, tac.arg1))
        elif opcode in (Op.BEZ, Op.JMP):
            self.emit(tac)
        elif opcode in (Op.PUSH, Op.POP):
            self.emit("%s r%d" % (opcode, tac.arg1))
        elif opcode == Op.CALL:
            self.emit(tac)
            # copy the callee's return value into arg2
            self.emit("ADD r%d, r0, rv" % tac.arg2)
        elif opcode == Op.RETURN:
            self.emit("ADD rv, r0, r%d" % tac.arg1)
            self.emit(tac)
        elif opcode == Op.PRINT:
            # push the value and the constant 1, trap with SYS, drop both
            self.emit("PUSH r%d" % tac.arg1)
            self.emit("MOV r%d, %d" % (print_register, 1))
            self.emit("PUSH r%d" % print_register)
            self.emit("SYS")
            self.emit("POP r0")
            self.emit("POP r0")
        elif opcode == Op.SYS:
            self.emit("SYS")
        else:
            raise Exception("%s is not a valid TAC operator" % tac.op)

    def _resolve_jumps(self):
        """Pass 2: replace BEZ/JMP/CALL/RETURN placeholders with offset ops.

        Offsets are (label position - placeholder position) counted in real
        instructions; labels take no slot. Raises Exception for an unknown
        label.
        """
        # One scan gives each entry's instruction index and each label's
        # target index. The first label with a given name wins, as in
        # get_label_offset().
        positions = []
        label_positions = {}
        index = 0
        for op in self.__op_list:
            positions.append(index)
            if isinstance(op, Label):
                label_positions.setdefault(op.name, index)
            else:
                index += 1

        def offset_to(name, here):
            if name not in label_positions:
                raise Exception("label %s does not exist" % name)
            return label_positions[name] - here

        for i, op in enumerate(self.__op_list):
            if not isinstance(op, TAC):
                continue
            here = positions[i]
            if op.op == Op.BEZ:
                self.__op_list[i] = "BEZ r%d, %d" % (op.arg1, offset_to(op.arg2, here))
            elif op.op == Op.JMP:
                self.__op_list[i] = "JMP %d" % offset_to(op.arg1, here)
            elif op.op == Op.CALL:
                self.__op_list[i] = "CALL %d" % offset_to(op.arg1, here)
            elif op.op == Op.RETURN:
                # a return jumps to label "L<arg2>"
                self.__op_list[i] = "JMP %d" % offset_to("L%d" % op.arg2, here)

    def generate(self):
        """Run all passes over self.tac_list and print the program to stdout.

        Afterwards self's op list holds the final lines, starting with a
        ``.REGS <count>`` directive. Raises Exception for unknown TAC
        operators, invalid TACList elements or missing labels.
        """
        tacs = list(self.tac_list)

        # Allocate PRINT's scratch register once, before any prologue is
        # emitted. Register().new() records it for the function Register
        # currently has selected. Allocating it now means that function's
        # prologue saves it and its epilogue restores it. Allocating it
        # while lowering each PRINT could add it to a function whose prologue
        # is already emitted, leaving its epilogue popping a register that
        # was never pushed.
        print_register = None
        if any(isinstance(tac, TAC) and tac.op == Op.PRINT for tac in tacs):
            print_register = Register().new()

        # pass 1 - generate ops, but not labels yet
        for tac in tacs:
            if isinstance(tac, TAC):
                self._emit_tac(tac, print_register)
            elif isinstance(tac, FunctionPrologue):
                self.generate_prologue(tac.name)
            elif isinstance(tac, FunctionEpilogue):
                self.generate_epilogue(tac.name)
            elif isinstance(tac, Label):
                self.emit(Label(tac.name))
            else:
                raise Exception("%s is not a valid TACList element" % repr(tac))

        # pass 2 - generate label ops (ignores still present labels)
        self._resolve_jumps()

        # pass 3 - remove labels
        ops = [op for op in self.__op_list if not isinstance(op, Label)]

        # pass 4 - replace the pseudo registers with their register numbers
        regs = Register().count
        bp = "r%d" % (regs + 1)  # base pointer
        sp = "r%d" % (regs + 2)  # stack pointer
        rv = "r%d" % (regs + 3)  # return value
        ops = [op.replace("bp", bp).replace("sp", sp).replace("rv", rv)
               for op in ops]
        self.__op_list = [".REGS %d" % regs] + ops

        # finally, print the generated code
        print("\n".join(self.__op_list))

    # find the line number to which the given label leads
    def get_label_offset(self, name, tac):
        """Return the relative offset from placeholder *tac* to label *name*.

        Both positions count only non-Label entries of the op list. *tac* is
        located by equality (first match), *name* by the first Label with
        that name. Raises Exception if the label is missing.
        """
        tac_index = 0
        for op in self.__op_list:
            if isinstance(op, Label):
                continue
            if op == tac:
                break
            tac_index += 1

        name_index = 0
        for op in self.__op_list:
            if isinstance(op, Label):
                if op.name == name:
                    return name_index - tac_index
                continue
            name_index += 1

        raise Exception("label %s does not exist" % name)