summaryrefslogtreecommitdiffstats
path: root/src/back/generator.py
blob: 56a143c8f779ccd9297e46309e6b82375d1ecdfc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from back.tac import *
from front.scope import Scope

class Register(object):
    """Allocator of numbered registers whose state is shared by every instance.

    ``__init__`` rebinds each instance's ``__dict__`` to one class-level
    dictionary (the Borg idiom), so any ``Register()`` observes the same
    counter and the same current function.
    """

    # Attribute storage shared by all instances.
    __shared_state = {}
    # Function that newly allocated registers are attributed to.
    __current_function = None
    # Function name -> list of register numbers allocated while it was current.
    __function_registers = {}
    # Highest register number handed out so far (0 before any allocation).
    count = 0

    def __init__(self):
        # Share one attribute dict across all instances.
        self.__dict__ = self.__shared_state

    def new(self):
        """Allocate the next register number, attribute it to the current function, and return it.

        Raises KeyError when no register list exists for the current function
        (e.g. set_function() has not been called yet); the counter has already
        been advanced at that point.
        """
        number = self.count + 1
        self.count = number
        owner = self.__current_function
        self.__function_registers[owner].append(number)
        return number

    def set_function(self, name):
        """Make ``name`` the current function and give it a fresh, empty register list."""
        self.__current_function = name
        self.__function_registers[name] = []

    def get_registers(self, name):
        """Return the live (mutable) list of registers allocated while ``name`` was current.

        Raises KeyError if set_function(name) was never called.
        """
        table = self.__function_registers
        return table[name]

class Generator(object):
    """Translate a TACList into assembly text for the target machine.

    Generation works on an internal op list whose entries are assembly
    strings, ``Label`` markers, or (until pass 2 of generate()) branch/call
    TAC objects whose label operands are still unresolved.
    """

    def __init__(self):
        # Three-address code to translate; presumably filled in by the front
        # end before generate() is called.
        self.tac_list = TACList()
        # Output under construction (see class docstring for entry kinds).
        self.__op_list = []

    def emit(self, op, tac=None):
        """Append ``op`` to the op list, or replace the entry equal to ``tac``.

        When ``tac`` is given, the first entry comparing equal to it (per
        ``list.index``) is overwritten with ``op``; raises ValueError if no
        such entry exists.
        """
        if tac is not None:
            self.__op_list[self.__op_list.index(tac)] = op
        else:
            self.__op_list.append(op)

    def generate_prologue(self, name):
        """Emit the entry sequence of function ``name``.

        Makes ``name`` the current scope, emits its label, saves bp and sets
        bp = sp, pushes r0 once per entry of ``Scope().get_variables()[1]``
        (presumably reserving the function's variable slots), and finally
        saves every register allocated for ``name``.
        """
        Scope().set_function(name)
        self.emit(Label(name))
        self.emit("PUSH bp")
        self.emit("ADD bp, r0, sp")
        for _ in Scope().get_variables()[1]:
            self.emit("PUSH r0")
        for register in Register().get_registers(name):
            self.emit("PUSH r%d" % register)

    def generate_epilogue(self, name):
        """Emit the exit sequence of function ``name``, undoing generate_prologue.

        Restores the saved registers, drops the variable slots, restores sp
        and bp, then returns with RET or, for ``main``, pushes rv and r0 and
        issues SYS.
        """
        # The prologue pushes the variable slots first and the saved
        # registers last, so the registers are on top of the stack and must be
        # popped first, in reverse order.  Popping the slots first would
        # throw saved registers away and reload registers from the wrong slots.
        for register in reversed(Register().get_registers(name)):
            self.emit("POP r%d" % register)
        for _ in Scope().get_variables()[1]:
            self.emit("POP r0")
        self.emit("ADD sp, r0, bp")
        self.emit("POP bp")
        if name == "main":
            self.emit("PUSH rv")
            self.emit("PUSH r0")
            self.emit("SYS")
        else:
            self.emit("RET")

    def generate(self):
        """Translate ``self.tac_list`` to assembly and print it to stdout.

        Pass 1 lowers every TAC, prologue, epilogue and label, leaving
        BEZ/JMP/CALL/RETURN TACs as placeholders.  Pass 2 replaces those
        placeholders with instructions carrying relative label offsets.
        Pass 3 drops the Label markers.  Pass 4 maps the pseudo registers
        bp/sp/rv onto the three registers above the highest allocated one
        and prepends a ``.REGS`` directive.

        Raises Exception for an unknown TACList element or an undefined label.
        """
        # pass 1 - generate non-label ops
        for tac in self.tac_list:
            if isinstance(tac, TAC):
                if tac.op in [Op.ADD, Op.SUB, Op.MUL, Op.DIV, Op.MOD, Op.AND, Op.OR]:
                    # arg1 is both the destination and the left operand.
                    self.emit("%s r%d, r%d, r%d" % (tac.op, tac.arg1, tac.arg1, tac.arg2))
                elif tac.op == Op.NOT:
                    # Compare against r0 and set the register from the EQ
                    # result (presumably r0 always reads as zero).
                    self.emit("CMP r%d, r0" % tac.arg1)
                    self.emit("EQ  r%d" % tac.arg1)
                elif tac.op == Op.MINUS:
                    self.emit("SUB r%d, r0, r%d" % (tac.arg1, tac.arg1))
                elif tac.op in [Op.STORE, Op.LOAD]:
                    # Variables are addressed relative to bp, 4 bytes per slot.
                    offset = Scope().get_variable_offset(tac.arg1)
                    self.emit("%s bp, r%d, %d" % (tac.op, tac.arg2, offset * 4))
                elif tac.op == Op.MOV:
                    self.emit("MOV r%d, %d" % (tac.arg1, tac.arg2))
                elif tac.op == Op.CMP:
                    self.emit("CMP r%d, r%d" % (tac.arg1, tac.arg2))
                elif tac.op in [Op.EQ, Op.NE, Op.LT, Op.LE, Op.GE, Op.GT]:
                    self.emit("%s r%d" % (tac.op, tac.arg1))
                elif tac.op in [Op.BEZ, Op.JMP]:
                    # Placeholder; resolved in pass 2 once all labels exist.
                    self.emit(tac)
                elif tac.op in [Op.PUSH, Op.POP]:
                    self.emit("%s r%d" % (tac.op, tac.arg1))
                elif tac.op == Op.CALL:
                    # Placeholder for the call, then copy rv into the result register.
                    self.emit(tac)
                    self.emit("ADD r%d, r0, rv" % tac.arg2)
                elif tac.op == Op.RETURN:
                    # Move the result into rv; the placeholder becomes a jump in pass 2.
                    self.emit("ADD rv, r0, r%d" % tac.arg1)
                    self.emit(tac)
            elif isinstance(tac, FunctionPrologue):
                self.generate_prologue(tac.name)
            elif isinstance(tac, FunctionEpilogue):
                self.generate_epilogue(tac.name)
            elif isinstance(tac, Label):
                self.emit(Label(tac.name))
            else:
                raise Exception("%s is not a valid TACList element" % repr(tac))

        # pass 2 - generate label ops
        # Entries are overwritten in place by position, so the list length
        # (and therefore every relative offset) stays fixed during the pass.
        for index, op in enumerate(self.__op_list):
            if not isinstance(op, TAC):
                continue
            if op.op == Op.BEZ:
                offset = self.get_label_offset(op.arg2, op)
                self.__op_list[index] = "BEZ r%d, %d" % (op.arg1, offset)
            elif op.op == Op.JMP:
                offset = self.get_label_offset(op.arg1, op)
                self.__op_list[index] = "JMP %d" % offset
            elif op.op == Op.CALL:
                offset = self.get_label_offset(op.arg1, op)
                self.__op_list[index] = "CALL %d" % offset
            elif op.op == Op.RETURN:
                # A return is a jump to the label "L<arg2>".
                offset = self.get_label_offset("L%d" % op.arg2, op)
                self.__op_list[index] = "JMP %d" % offset

        # pass 3 - remove labels (they occupy no space in the output)
        self.__op_list = [op for op in self.__op_list if not isinstance(op, Label)]

        # pass 4 - replace pseudo registers and insert number of registers
        regs = Register().count
        pseudo_regs = (
            ("bp", "r%d" % (regs + 1)),
            ("sp", "r%d" % (regs + 2)),
            ("rv", "r%d" % (regs + 3)),
        )

        def replace_pseudo_regs(op):
            # Plain substring replacement: assumes no mnemonic contains the
            # lower-case sequences "bp", "sp" or "rv".
            for pseudo, real in pseudo_regs:
                op = op.replace(pseudo, real)
            return op

        self.__op_list = [".REGS %d" % regs] + [replace_pseudo_regs(op) for op in self.__op_list]

        # finally, print the generated code
        print("\n".join(self.__op_list))

    def get_label_offset(self, name, tac):
        """Return the distance, in instructions, from ``tac`` to label ``name``.

        Label entries are skipped when counting because they emit no
        instruction.  The result is negative for backward branches.

        Raises Exception if no label called ``name`` is in the op list.
        """
        # Position of the branch among non-label entries.
        tac_index = 0
        for op in self.__op_list:
            if isinstance(op, Label):
                continue
            if op == tac:
                break
            tac_index += 1

        # Position of the first label with the requested name.
        name_index = 0
        for op in self.__op_list:
            if isinstance(op, Label):
                if op.name == name:
                    return name_index - tac_index
                continue
            name_index += 1

        raise Exception("label %s does not exist" % name)