927 lines
28 KiB
Python
927 lines
28 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Simple C to DSA Assembly Compiler
|
|
Supports a subset of C including:
|
|
- int variables and functions
|
|
- Arithmetic operations (+, -, *, /)
|
|
- Comparisons (==, !=, <, >, <=, >=)
|
|
- If/else statements
|
|
- While loops
|
|
- Function calls
|
|
- Return statements
|
|
"""
|
|
|
|
import re
|
|
import sys
|
|
from typing import List, Dict, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from pprint import pprint
|
|
import json
|
|
|
|
|
|
class TokenType(Enum):
    """Every category of token the lexer can produce.

    For keywords, operators and delimiters the member value is the exact
    source spelling; for identifiers, numbers and end-of-input it is a
    descriptive tag, because the spelling varies (or is empty).
    """

    # --- reserved words -------------------------------------------------
    INT = "int"
    IF = "if"
    ELSE = "else"
    WHILE = "while"
    RETURN = "return"

    # --- lexemes whose text varies --------------------------------------
    IDENTIFIER = "IDENTIFIER"
    NUMBER = "NUMBER"

    # --- arithmetic, assignment and comparison operators ----------------
    PLUS = "+"
    MINUS = "-"
    STAR = "*"
    SLASH = "/"
    ASSIGN = "="
    EQ = "=="
    NE = "!="
    LT = "<"
    GT = ">"
    LE = "<="
    GE = ">="

    # --- punctuation ----------------------------------------------------
    LPAREN = "("
    RPAREN = ")"
    LBRACE = "{"
    RBRACE = "}"
    SEMICOLON = ";"
    COMMA = ","

    # --- sentinel appended after the last real token --------------------
    EOF = "EOF"
|
|
|
|
|
|
@dataclass
class Token:
    """One lexeme together with where it started in the source text."""

    # Category of the lexeme (a TokenType member).
    type: "TokenType"
    # Exact source text of the lexeme; empty for the EOF token.
    value: str
    # 1-based line on which the lexeme starts.
    line: int
    # 1-based column at which the lexeme starts.
    col: int
|
|
|
|
|
|
class Lexer:
    """Split C source text into a list of Token objects.

    The lexer keeps a character offset (``pos``) plus 1-based ``line`` and
    ``col`` counters so that every token and every error message carries the
    position at which it starts.  Only ``//`` line comments are recognised;
    ``/* ... */`` is not treated as a comment.
    """

    # Reserved words; any other identifier-shaped lexeme is an IDENTIFIER.
    _KEYWORDS = {
        "int": TokenType.INT,
        "if": TokenType.IF,
        "else": TokenType.ELSE,
        "while": TokenType.WHILE,
        "return": TokenType.RETURN,
    }

    # Checked before the single-character table so that "<=" is never split
    # into "<" followed by "=".
    _TWO_CHAR_TOKENS = {
        "==": TokenType.EQ,
        "!=": TokenType.NE,
        "<=": TokenType.LE,
        ">=": TokenType.GE,
    }

    _ONE_CHAR_TOKENS = {
        "+": TokenType.PLUS,
        "-": TokenType.MINUS,
        "*": TokenType.STAR,
        "/": TokenType.SLASH,
        "=": TokenType.ASSIGN,
        "<": TokenType.LT,
        ">": TokenType.GT,
        "(": TokenType.LPAREN,
        ")": TokenType.RPAREN,
        "{": TokenType.LBRACE,
        "}": TokenType.RBRACE,
        ";": TokenType.SEMICOLON,
        ",": TokenType.COMMA,
    }

    def __init__(self, source: str):
        """Prepare to tokenize *source* starting at line 1, column 1."""
        self.source = source
        self.pos = 0
        self.line = 1
        self.col = 1
        self.tokens = []

    def error(self, msg: str):
        """Raise SyntaxError for *msg*, tagged with the current position."""
        raise SyntaxError(f"Lexer error at line {self.line}, col {self.col}: {msg}")

    def peek(self, offset: int = 0) -> Optional[str]:
        """Return the character *offset* places ahead, or None past the end."""
        pos = self.pos + offset
        return self.source[pos] if pos < len(self.source) else None

    def advance(self) -> Optional[str]:
        """Consume and return one character, updating line/column counters.

        Returns None (and consumes nothing) when already at end of input.
        """
        if self.pos >= len(self.source):
            return None
        char = self.source[self.pos]
        self.pos += 1
        if char == "\n":
            self.line += 1
            self.col = 1
        else:
            self.col += 1
        return char

    def skip_whitespace(self):
        """Consume a run of spaces, tabs, carriage returns and newlines."""
        while self.peek() and self.peek() in " \t\n\r":
            self.advance()

    def skip_comment(self):
        """Consume one ``//`` comment, including its terminating newline."""
        if self.peek() == "/" and self.peek(1) == "/":
            while self.peek() and self.peek() != "\n":
                self.advance()
            self.advance()  # skip newline (no-op at end of input)

    def _skip_trivia(self):
        """Skip any mix of whitespace and comments before the next token.

        A comment may be followed by indentation or by another comment, so a
        single whitespace pass plus a single comment pass is not enough; keep
        alternating until a full round consumes nothing.
        """
        while True:
            before = self.pos
            self.skip_whitespace()
            self.skip_comment()
            if self.pos == before:
                return

    @staticmethod
    def _is_ascii_digit(char: Optional[str]) -> bool:
        """Return True only for '0'..'9'.

        ``str.isdigit`` also accepts characters such as superscript two,
        which ``int()`` cannot convert, so it must not be used here.
        """
        return char is not None and "0" <= char <= "9"

    def read_number(self) -> str:
        """Consume and return a maximal run of ASCII decimal digits."""
        start = self.pos
        while self._is_ascii_digit(self.peek()):
            self.advance()
        return self.source[start:self.pos]

    def read_identifier(self) -> str:
        """Consume and return a maximal run of alphanumerics/underscores."""
        start = self.pos
        while self.peek() and (self.peek().isalnum() or self.peek() == "_"):
            self.advance()
        return self.source[start:self.pos]

    def tokenize(self) -> List[Token]:
        """Tokenize the remaining source and return the token list.

        The returned list always ends with an EOF token.

        Raises:
            SyntaxError: on a character that cannot start any token.
        """
        while True:
            self._skip_trivia()
            if self.pos >= len(self.source):
                break

            # Remember where the token starts; advance() moves the counters.
            line, col = self.line, self.col
            char = self.peek()

            if self._is_ascii_digit(char):
                self.tokens.append(Token(TokenType.NUMBER, self.read_number(), line, col))
                continue

            if char.isalpha() or char == "_":
                ident = self.read_identifier()
                token_type = self._KEYWORDS.get(ident, TokenType.IDENTIFIER)
                self.tokens.append(Token(token_type, ident, line, col))
                continue

            pair = self.source[self.pos:self.pos + 2]
            if pair in self._TWO_CHAR_TOKENS:
                self.advance()
                self.advance()
                self.tokens.append(Token(self._TWO_CHAR_TOKENS[pair], pair, line, col))
                continue

            if char in self._ONE_CHAR_TOKENS:
                self.advance()
                self.tokens.append(Token(self._ONE_CHAR_TOKENS[char], char, line, col))
                continue

            self.error(f"Unexpected character: {char}")

        self.tokens.append(Token(TokenType.EOF, "", self.line, self.col))
        return self.tokens
|
|
|
|
|
|
# AST Node classes
|
|
@dataclass
class ASTNode:
    """Common root of every node in the abstract syntax tree; has no fields."""
|
|
|
|
|
|
@dataclass
class Program(ASTNode):
    """Root of the tree: the whole translation unit."""

    # Top-level declarations in source order.
    declarations: List["Declaration"]
|
|
|
|
|
|
@dataclass
class Declaration(ASTNode):
    """Marker base class for top-level declarations; has no fields."""
|
|
|
|
|
|
@dataclass
class FunctionDecl(Declaration):
    """A function definition: name, parameter names and body."""

    # Function name as written in the source.
    name: str
    # Parameter names in declaration order (the only type is int).
    params: List[str]
    # The braced block forming the function body.
    body: "CompoundStmt"
|
|
|
|
|
|
@dataclass
class VarDecl(Declaration):
    """A variable declaration with an optional initializer."""

    # Variable name as written in the source.
    name: str
    # Initializer expression, or None when the declaration has none.
    init: Optional["Expression"] = None
|
|
|
|
|
|
@dataclass
class Statement(ASTNode):
    """Marker base class for statements; has no fields."""
|
|
|
|
|
|
@dataclass
class CompoundStmt(Statement):
    """A braced block of statements."""

    # Statements inside the braces, in source order.
    statements: List[Statement]
|
|
|
|
|
|
@dataclass
class ExprStmt(Statement):
    """An expression evaluated for its side effects, or an empty statement."""

    # The expression, or None for a bare ";".
    expr: Optional["Expression"]
|
|
|
|
|
|
@dataclass
class IfStmt(Statement):
    """An ``if`` statement with an optional ``else`` branch."""

    # Expression that selects the branch.
    condition: "Expression"
    # Statement run when the condition holds.
    then_stmt: Statement
    # Statement run otherwise, or None when there is no ``else``.
    else_stmt: Optional[Statement] = None
|
|
|
|
|
|
@dataclass
class WhileStmt(Statement):
    """A ``while`` loop."""

    # Expression tested before each iteration.
    condition: "Expression"
    # Statement repeated while the condition holds.
    body: Statement
|
|
|
|
|
|
@dataclass
class ReturnStmt(Statement):
    """A ``return`` statement."""

    # Value being returned, or None for a bare ``return;``.
    expr: Optional["Expression"]
|
|
|
|
|
|
@dataclass
class Expression(ASTNode):
    """Marker base class for expressions; has no fields."""
|
|
|
|
|
|
@dataclass
class BinaryOp(Expression):
    """A two-operand arithmetic or comparison expression."""

    # Operator spelling exactly as it appeared in the source, e.g. "+" or "<=".
    op: str
    # Left-hand operand.
    left: Expression
    # Right-hand operand.
    right: Expression
|
|
|
|
|
|
@dataclass
class UnaryOp(Expression):
    """A prefix operator applied to a single operand."""

    # Operator spelling as it appeared in the source.
    op: str
    # Expression the operator applies to.
    operand: Expression
|
|
|
|
|
|
@dataclass
class AssignExpr(Expression):
    """Assignment of a value to a named variable."""

    # Name of the variable being assigned.
    name: str
    # Expression whose result is stored.
    value: Expression
|
|
|
|
|
|
@dataclass
class VarExpr(Expression):
    """A read of a named variable."""

    # Name of the variable being read.
    name: str
|
|
|
|
|
|
@dataclass
class NumberExpr(Expression):
    """An integer literal."""

    # Numeric value of the literal.
    value: int
|
|
|
|
|
|
@dataclass
class CallExpr(Expression):
    """A call to a named function."""

    # Name of the function being called.
    name: str
    # Argument expressions in source order.
    args: List[Expression]
|
|
|
|
|
|
class Parser:
    """Recursive-descent parser that builds a Program from a token list.

    Expression precedence, loosest first: assignment (right-associative),
    comparison, additive, multiplicative, unary ``+``/``-``, primary.
    The token list is expected to end with an EOF token; reads past the end
    keep returning that last token.
    """

    # Operator token groups, one per left-associative precedence level.
    _COMPARISON_OPS = (
        TokenType.EQ,
        TokenType.NE,
        TokenType.LT,
        TokenType.GT,
        TokenType.LE,
        TokenType.GE,
    )
    _ADDITIVE_OPS = (TokenType.PLUS, TokenType.MINUS)
    _MULTIPLICATIVE_OPS = (TokenType.STAR, TokenType.SLASH)

    def __init__(self, tokens: List[Token]):
        """Start parsing at the first token of *tokens*."""
        self.tokens = tokens
        self.pos = 0

    # ------------------------------------------------------------------
    # Token-stream helpers
    # ------------------------------------------------------------------

    def error(self, msg: str):
        """Raise SyntaxError for *msg* at the current token's position."""
        where = self.current()
        raise SyntaxError(f"Parser error at line {where.line}, col {where.col}: {msg}")

    def current(self) -> Token:
        """Return the token under the cursor (the last token if past the end)."""
        if self.pos < len(self.tokens):
            return self.tokens[self.pos]
        return self.tokens[-1]

    def peek(self, offset: int = 0) -> Token:
        """Return the token *offset* places ahead, clamped to the last token."""
        index = self.pos + offset
        if index < len(self.tokens):
            return self.tokens[index]
        return self.tokens[-1]

    def advance(self) -> Token:
        """Return the current token and step forward, never past the last one."""
        consumed = self.current()
        last_index = len(self.tokens) - 1
        if self.pos < last_index:
            self.pos += 1
        return consumed

    def expect(self, token_type: TokenType) -> Token:
        """Consume and return the current token, which must be *token_type*.

        Raises:
            SyntaxError: if the current token has a different type.
        """
        found = self.current()
        if found.type != token_type:
            self.error(f"Expected {token_type.value}, got {found.type.value}")
        return self.advance()

    # ------------------------------------------------------------------
    # Declarations
    # ------------------------------------------------------------------

    def parse(self) -> Program:
        """Parse top-level declarations until EOF and return the Program."""
        items = []
        while self.current().type != TokenType.EOF:
            items.append(self.parse_declaration())
        return Program(items)

    def _parse_parameter(self) -> str:
        """Parse ``int <name>`` and return the parameter name."""
        self.expect(TokenType.INT)
        return self.expect(TokenType.IDENTIFIER).value

    def _parse_optional_initializer(self) -> Optional[Expression]:
        """Parse ``= <expr>`` if present; return the expression or None."""
        if self.current().type != TokenType.ASSIGN:
            return None
        self.advance()
        return self.parse_expression()

    def parse_declaration(self) -> Declaration:
        """Parse one top-level ``int`` function or variable declaration."""
        self.expect(TokenType.INT)
        name = self.expect(TokenType.IDENTIFIER).value

        # Without "(" after the name this is a global variable.
        if self.current().type != TokenType.LPAREN:
            initializer = self._parse_optional_initializer()
            self.expect(TokenType.SEMICOLON)
            return VarDecl(name, initializer)

        self.advance()  # consume "("
        params = []
        if self.current().type != TokenType.RPAREN:
            params.append(self._parse_parameter())
            while self.current().type == TokenType.COMMA:
                self.advance()
                params.append(self._parse_parameter())
        self.expect(TokenType.RPAREN)

        body = self.parse_compound_stmt()
        return FunctionDecl(name, params, body)

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------

    def parse_compound_stmt(self) -> CompoundStmt:
        """Parse ``{ statement* }``."""
        self.expect(TokenType.LBRACE)
        collected = []
        while self.current().type != TokenType.RBRACE:
            collected.append(self.parse_statement())
        self.expect(TokenType.RBRACE)
        return CompoundStmt(collected)

    def _parse_local_declaration(self) -> Statement:
        """Parse ``int <name> [= <expr>];`` inside a function body.

        The declaration is lowered to an assignment statement; without an
        initializer it becomes an empty expression statement.
        """
        self.advance()  # consume "int"
        name = self.expect(TokenType.IDENTIFIER).value
        initializer = self._parse_optional_initializer()
        self.expect(TokenType.SEMICOLON)
        if initializer is None:
            return ExprStmt(None)
        return ExprStmt(AssignExpr(name, initializer))

    def parse_statement(self) -> Statement:
        """Parse one statement, choosing the form from its first token."""
        by_leading_token = {
            TokenType.LBRACE: self.parse_compound_stmt,
            TokenType.IF: self.parse_if_stmt,
            TokenType.WHILE: self.parse_while_stmt,
            TokenType.RETURN: self.parse_return_stmt,
            TokenType.INT: self._parse_local_declaration,
        }
        handler = by_leading_token.get(self.current().type)
        if handler is not None:
            return handler()

        # Anything else is an expression statement; a bare ";" has no expression.
        expr = None
        if self.current().type != TokenType.SEMICOLON:
            expr = self.parse_expression()
        self.expect(TokenType.SEMICOLON)
        return ExprStmt(expr)

    def parse_if_stmt(self) -> IfStmt:
        """Parse ``if (cond) stmt [else stmt]``."""
        self.expect(TokenType.IF)
        self.expect(TokenType.LPAREN)
        condition = self.parse_expression()
        self.expect(TokenType.RPAREN)
        then_branch = self.parse_statement()

        else_branch = None
        if self.current().type == TokenType.ELSE:
            self.advance()
            else_branch = self.parse_statement()

        return IfStmt(condition, then_branch, else_branch)

    def parse_while_stmt(self) -> WhileStmt:
        """Parse ``while (cond) stmt``."""
        self.expect(TokenType.WHILE)
        self.expect(TokenType.LPAREN)
        condition = self.parse_expression()
        self.expect(TokenType.RPAREN)
        loop_body = self.parse_statement()
        return WhileStmt(condition, loop_body)

    def parse_return_stmt(self) -> ReturnStmt:
        """Parse ``return [expr];``."""
        self.expect(TokenType.RETURN)
        value = None
        if self.current().type != TokenType.SEMICOLON:
            value = self.parse_expression()
        self.expect(TokenType.SEMICOLON)
        return ReturnStmt(value)

    # ------------------------------------------------------------------
    # Expressions
    # ------------------------------------------------------------------

    def parse_expression(self) -> Expression:
        """Parse a full expression (entry point of the precedence chain)."""
        return self.parse_assignment()

    def parse_assignment(self) -> Expression:
        """Parse ``name = expr`` (right-associative) or a comparison."""
        target = self.parse_comparison()
        if self.current().type != TokenType.ASSIGN:
            return target

        if not isinstance(target, VarExpr):
            self.error("Invalid assignment target")
        self.advance()
        value = self.parse_assignment()
        return AssignExpr(target.name, value)

    def _parse_left_assoc(self, parse_operand, operator_types) -> Expression:
        """Parse ``operand (op operand)*`` into a left-leaning BinaryOp tree."""
        node = parse_operand()
        while self.current().type in operator_types:
            symbol = self.advance().value
            rhs = parse_operand()
            node = BinaryOp(symbol, node, rhs)
        return node

    def parse_comparison(self) -> Expression:
        """Parse a chain of ``== != < > <= >=`` over additive operands."""
        return self._parse_left_assoc(self.parse_additive, self._COMPARISON_OPS)

    def parse_additive(self) -> Expression:
        """Parse a chain of ``+``/``-`` over multiplicative operands."""
        return self._parse_left_assoc(self.parse_multiplicative, self._ADDITIVE_OPS)

    def parse_multiplicative(self) -> Expression:
        """Parse a chain of ``*``/``/`` over unary operands."""
        return self._parse_left_assoc(self.parse_unary, self._MULTIPLICATIVE_OPS)

    def parse_unary(self) -> Expression:
        """Parse any number of prefix ``+``/``-`` followed by a primary."""
        if self.current().type not in self._ADDITIVE_OPS:
            return self.parse_primary()
        symbol = self.advance().value
        operand = self.parse_unary()
        return UnaryOp(symbol, operand)

    def parse_primary(self) -> Expression:
        """Parse a number, variable, function call or parenthesised expression.

        Raises:
            SyntaxError: if the current token cannot start an expression.
        """
        token = self.current()
        kind = token.type

        if kind == TokenType.NUMBER:
            self.advance()
            return NumberExpr(int(token.value))

        if kind == TokenType.IDENTIFIER:
            name = self.advance().value
            if self.current().type != TokenType.LPAREN:
                return VarExpr(name)

            # "(" right after an identifier makes it a call.
            self.advance()
            arguments = []
            if self.current().type != TokenType.RPAREN:
                arguments.append(self.parse_expression())
                while self.current().type == TokenType.COMMA:
                    self.advance()
                    arguments.append(self.parse_expression())
            self.expect(TokenType.RPAREN)
            return CallExpr(name, arguments)

        if kind == TokenType.LPAREN:
            self.advance()
            inner = self.parse_expression()
            self.expect(TokenType.RPAREN)
            return inner

        self.error(f"Unexpected token: {token.type.value}")
|
|
|
|
|
|
class CodeGenerator:
    """Translate a Program AST into DSA assembly text.

    Conventions used by the emitted code (all visible in the emitted
    instructions themselves):

    * Every value is a 4-byte word.  A function's frame is addressed
      relative to ``bpr``: saved ``bpr`` at +0, return address at +4, first
      argument at +8, further arguments at +12, +16, ...; locals at -4, -8, ...
    * The caller pushes arguments in reverse order.  The slot at ``bpr+8``
      doubles as the return-value slot, so a call with no arguments pushes
      one placeholder word, exactly as the ``init`` stub does for ``main``.
    * Registers are caller-saved: registers holding live temporaries are
      pushed before a call and popped afterwards, because the callee
      allocates from the same register pool.

    NOTE(review): instruction semantics (operand order, whether ``lli``
    clears the upper half-word or touches flags, what ``lwi`` expands to)
    are not defined in this file and are reproduced as found -- confirm
    against the DSA assembler documentation.
    """

    def __init__(self):
        """Create an empty generator with all sixteen registers free."""
        self.output = []
        self.label_counter = 0
        self.string_counter = 0
        self.functions = {}
        self.current_function = None
        self.local_vars = {}
        self.global_vars = {}
        self.register_pool = [f"rg{i:x}" for i in range(16)]
        self.used_registers = set()

    # ------------------------------------------------------------------
    # Small utilities
    # ------------------------------------------------------------------

    def new_label(self, prefix: str = "L") -> str:
        """Return a label unique within this generator, ``<prefix><n>``."""
        label = f"{prefix}{self.label_counter}"
        self.label_counter += 1
        return label

    def allocate_register(self) -> str:
        """Reserve and return the lowest-numbered free register.

        Raises:
            RuntimeError: if all registers are in use.
        """
        for reg in self.register_pool:
            if reg not in self.used_registers:
                self.used_registers.add(reg)
                return reg
        raise RuntimeError("Out of registers")

    def free_register(self, reg: str):
        """Mark *reg* as available again (no-op if it was not reserved)."""
        self.used_registers.discard(reg)

    def emit(self, code: str):
        """Append one line of assembly to the output."""
        self.output.append(code)

    # ------------------------------------------------------------------
    # Program and function level
    # ------------------------------------------------------------------

    def generate(self, program: Program) -> str:
        """Generate assembly for *program* and return it as one string.

        Globals are emitted first so every function body can resolve them,
        followed by the ``init`` entry stub and then each function.
        """
        self.emit("// Global variables")
        for decl in program.declarations:
            if not isinstance(decl, VarDecl):
                continue
            self.global_vars[decl.name] = f"var_{decl.name}"
            # Only a plain integer literal is honoured as an initializer;
            # any other initializer expression yields 0.
            # NOTE(review): that silently drops e.g. "int x = -1;".
            initial = decl.init.value if isinstance(decl.init, NumberExpr) else 0
            self.emit(f"dw var_{decl.name}: {initial}")

        for line in (
            "",
            "// Entry point",
            "dw stack_bottom: 0x10000",
            "",
            "init:",
            " ldw stack_bottom, spr",
            " mov spr, bpr",
            # Placeholder slot that receives main's return value.
            " push zero",
            " call main",
            " pop rg0",
            " hlt",
            "",
        ):
            self.emit(line)

        for decl in program.declarations:
            if isinstance(decl, FunctionDecl):
                self.generate_function(decl)

        return "\n".join(self.output)

    def generate_function(self, func: FunctionDecl):
        """Emit prologue, body and epilogue for *func*."""
        self.current_function = func.name
        self.functions[func.name] = func

        # Parameters sit above the saved bpr (+0) and return address (+4).
        self.local_vars = {}
        for i, param in enumerate(func.params):
            self.local_vars[param] = 8 + (i * 4)

        self.emit(f"{func.name}:")
        self.emit(" push bpr")
        self.emit(" mov spr, bpr")
        # Locals are discovered while the body is generated, so remember
        # where the frame reservation has to be inserted afterwards.
        frame_at = len(self.output)
        self.emit("")

        self.generate_compound_stmt(func.body)

        # Reserve one word per local by pushing placeholders.  Without this
        # spr stays equal to bpr and the first push inside the body would
        # overwrite the local stored at bpr-4.
        local_count = sum(1 for offset in self.local_vars.values() if offset < 0)
        self.output[frame_at:frame_at] = [" push zero"] * local_count

        # Shared epilogue; "mov bpr, spr" also releases the locals.
        self.emit("// default return")
        self.emit(f"{func.name}_end:")
        self.emit(" mov bpr, spr")
        self.emit(" pop bpr")
        self.emit(" return")
        self.emit("")

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------

    def generate_compound_stmt(self, stmt: CompoundStmt):
        """Emit each statement of a block in order."""
        for inner in stmt.statements:
            self.generate_statement(inner)

    def generate_statement(self, stmt: Statement):
        """Emit code for one statement; unknown statement types emit nothing."""
        if isinstance(stmt, CompoundStmt):
            self.generate_compound_stmt(stmt)
        elif isinstance(stmt, ExprStmt):
            if stmt.expr:
                # The value is discarded; only side effects matter.
                self.free_register(self.generate_expression(stmt.expr))
        elif isinstance(stmt, IfStmt):
            self.generate_if_stmt(stmt)
        elif isinstance(stmt, WhileStmt):
            self.generate_while_stmt(stmt)
        elif isinstance(stmt, ReturnStmt):
            self.generate_return_stmt(stmt)

    def generate_if_stmt(self, stmt: IfStmt):
        """Emit ``if``/``else``: a zero condition jumps past the then-branch."""
        else_label = self.new_label("else")
        end_label = self.new_label("endif")

        cond_reg = self.generate_expression(stmt.condition)
        self.emit(f" cmp {cond_reg}, zero")
        self.free_register(cond_reg)

        false_target = else_label if stmt.else_stmt else end_label
        self.emit(f" jeq {false_target}")

        self.generate_statement(stmt.then_stmt)

        if stmt.else_stmt:
            self.emit(f" jmp {end_label}")
            self.emit(f"{else_label}:")
            self.generate_statement(stmt.else_stmt)

        self.emit(f"{end_label}:")

    def generate_while_stmt(self, stmt: WhileStmt):
        """Emit a ``while`` loop with the test at the top."""
        start_label = self.new_label("while_start")
        end_label = self.new_label("while_end")

        self.emit(f"{start_label}:")

        cond_reg = self.generate_expression(stmt.condition)
        self.emit(f" cmp {cond_reg}, zero")
        self.free_register(cond_reg)
        self.emit(f" jeq {end_label}")

        self.generate_statement(stmt.body)
        self.emit(f" jmp {start_label}")

        self.emit(f"{end_label}:")

    def generate_return_stmt(self, stmt: ReturnStmt):
        """Store the return value (if any) and jump to the epilogue."""
        if stmt.expr:
            reg = self.generate_expression(stmt.expr)
            # The return slot is the first-argument slot of this frame.  It
            # is addressed through bpr because spr sits below the locals.
            self.emit(f" stw {reg}, bpr, 8")
            self.free_register(reg)
        self.emit(f" jmp {self.current_function}_end")

    # ------------------------------------------------------------------
    # Expressions
    # ------------------------------------------------------------------

    def generate_expression(self, expr: Expression) -> str:
        """Emit code computing *expr*; return the register holding the result.

        The caller owns the returned register and must free it.

        Raises:
            RuntimeError: on an undefined variable, an unknown node type, or
                when no register is free.
        """
        if isinstance(expr, NumberExpr):
            return self._generate_number(expr)
        if isinstance(expr, VarExpr):
            return self._generate_variable_read(expr)
        if isinstance(expr, AssignExpr):
            return self._generate_assignment(expr)
        if isinstance(expr, BinaryOp):
            return self.generate_binary_op(expr)
        if isinstance(expr, UnaryOp):
            return self._generate_unary_op(expr)
        if isinstance(expr, CallExpr):
            return self._generate_call(expr)
        raise RuntimeError(f"Unknown expression type: {type(expr)}")

    def _generate_number(self, expr: NumberExpr) -> str:
        """Load an integer literal into a fresh register."""
        reg = self.allocate_register()
        value = expr.value
        if 0 <= value <= 0xFFFF:
            self.emit(f" lli {value}, {reg}")
            if value > 0xFF:
                # Always "lui 0" in this range.
                # NOTE(review): presumably clears the upper half-word -- confirm.
                self.emit(f" lui {value >> 16}, {reg}")
        else:
            # Negative or wider than 16 bits: load both half-words explicitly.
            self.emit(f" lli {value & 0xFFFF}, {reg}")
            self.emit(f" lui {(value >> 16) & 0xFFFF}, {reg}")
        return reg

    def _generate_variable_read(self, expr: VarExpr) -> str:
        """Load a local/parameter (bpr-relative) or a global (by label)."""
        reg = self.allocate_register()
        if expr.name in self.local_vars:
            offset = self.local_vars[expr.name]
            self.emit(f" ldw bpr, {reg}, {offset}")
        elif expr.name in self.global_vars:
            label = self.global_vars[expr.name]
            self.emit(f" ldw {label}, {reg}")
        else:
            raise RuntimeError(f"Undefined variable: {expr.name}")
        return reg

    def _generate_assignment(self, expr: AssignExpr) -> str:
        """Store a value into a variable; the value stays in the returned register.

        Assigning to a name that is neither a known local nor a global
        creates a new local at the next free negative bpr offset.
        """
        value_reg = self.generate_expression(expr.value)

        if expr.name in self.local_vars:
            offset = self.local_vars[expr.name]
            self.emit(f" stw {value_reg}, bpr, {offset}")
        elif expr.name in self.global_vars:
            label = self.global_vars[expr.name]
            self.emit(f" stw {value_reg}, {label}")
        else:
            # Parameters have positive offsets, so counting the negative
            # ones gives the number of locals allocated so far.
            existing = sum(1 for off in self.local_vars.values() if off < 0)
            offset = -(existing + 1) * 4
            self.local_vars[expr.name] = offset
            self.emit(f" stw {value_reg}, bpr, {offset}")

        return value_reg

    def _generate_unary_op(self, expr: UnaryOp) -> str:
        """Emit unary ``-`` as ``0 - operand`` and unary ``+`` as a copy."""
        operand_reg = self.generate_expression(expr.operand)
        result_reg = self.allocate_register()

        if expr.op == "-":
            self.emit(f" lwi 0, {result_reg}")
            self.emit(f" sub {result_reg}, {operand_reg}, {result_reg}")
        else:  # +
            self.emit(f" mov {operand_reg}, {result_reg}")

        self.free_register(operand_reg)
        return result_reg

    def _generate_call(self, expr: CallExpr) -> str:
        """Emit a function call and return the register holding its result."""
        # Save live temporaries first: the callee allocates from the same
        # pool and would otherwise overwrite them.
        live = [reg for reg in self.register_pool if reg in self.used_registers]
        for reg in live:
            self.emit(f" push {reg}")

        # Arguments go on the stack in reverse so the first ends up on top.
        # Each register is released as soon as its value has been pushed.
        for arg in reversed(expr.args):
            reg = self.generate_expression(arg)
            self.emit(f" push {reg}")
            self.free_register(reg)

        # With no arguments there is no slot to carry the return value.
        if not expr.args:
            self.emit(" push zero")

        self.emit(f" call {expr.name}")

        # The top of the stack (the callee's bpr+8) now holds the result.
        result_reg = self.allocate_register()
        self.emit(f" pop {result_reg}")

        # Discard the remaining argument words.
        for _ in range(len(expr.args) - 1):
            self.emit(" pop zero")

        for reg in reversed(live):
            self.emit(f" pop {reg}")

        return result_reg

    def generate_binary_op(self, expr: BinaryOp) -> str:
        """Emit code for a binary arithmetic or comparison expression.

        The left operand's register stays reserved while the right operand
        is evaluated, so the two can never share a register; a call on the
        right-hand side preserves it through the caller-save sequence.
        """
        left_reg = self.generate_expression(expr.left)
        right_reg = self.generate_expression(expr.right)
        result_reg = self.allocate_register()

        if expr.op == "+":
            self.emit(f" add {left_reg}, {right_reg}, {result_reg}")
        elif expr.op == "-":
            self.emit(f" sub {left_reg}, {right_reg}, {result_reg}")
        elif expr.op == "*":
            self._emit_multiply(left_reg, right_reg, result_reg)
        elif expr.op == "/":
            self._emit_divide(left_reg, right_reg, result_reg)
        elif expr.op in ["==", "!=", "<", ">", "<=", ">="]:
            self._emit_comparison(expr.op, left_reg, right_reg, result_reg)

        self.free_register(left_reg)
        self.free_register(right_reg)
        return result_reg

    def _emit_multiply(self, left_reg: str, right_reg: str, result_reg: str):
        """Multiply by repeated addition, counting *right_reg* down to zero.

        NOTE(review): a negative multiplier never reaches zero by
        decrementing; only non-negative right operands terminate promptly.
        """
        loop_label = self.new_label("mult")
        end_label = self.new_label("mult_end")
        self.emit(f" lli 0, {result_reg}")
        self.emit(f"{loop_label}:")
        self.emit(f" cmp {right_reg}, zero")
        self.emit(f" jeq {end_label}")
        self.emit(f" add {result_reg}, {left_reg}, {result_reg}")
        self.emit(f" dec {right_reg}")
        self.emit(f" jmp {loop_label}")
        self.emit(f"{end_label}:")

    def _emit_divide(self, left_reg: str, right_reg: str, result_reg: str):
        """Divide by repeated subtraction, counting the subtractions.

        NOTE(review): a zero divisor never makes the dividend smaller, so
        the emitted loop does not terminate; negative operands are not
        handled either.
        """
        loop_label = self.new_label("div")
        end_label = self.new_label("div_end")
        self.emit(f" lli 0, {result_reg}")
        self.emit(f"{loop_label}:")
        self.emit(f" cmp {left_reg}, {right_reg}")
        self.emit(f" jlt {end_label}")
        self.emit(f" sub {left_reg}, {right_reg}, {left_reg}")
        self.emit(f" inc {result_reg}")
        self.emit(f" jmp {loop_label}")
        self.emit(f"{end_label}:")

    def _emit_comparison(self, op: str, left_reg: str, right_reg: str, result_reg: str):
        """Leave 1 in *result_reg* if ``left op right`` holds, else 0."""
        jump_if_true = {
            "==": "jeq",
            "!=": "jne",
            "<": "jlt",
            ">": "jgt",
            "<=": "jle",
            ">=": "jge",
        }
        self.emit(f" cmp {left_reg}, {right_reg}")
        self.emit(f" lli 0, {result_reg}")
        true_label = self.new_label("cmp_true")
        end_label = self.new_label("cmp_end")
        self.emit(f" {jump_if_true[op]} {true_label}")
        self.emit(f" jmp {end_label}")
        self.emit(f"{true_label}:")
        self.emit(f" lli 1, {result_reg}")
        self.emit(f"{end_label}:")
|
|
|
|
|
|
def compile_c_to_asm(source: str) -> str:
    """Compile C source code to DSA assembly.

    Runs the three stages in order -- lexing, parsing, code generation --
    and returns the generated assembly text.
    """
    token_stream = Lexer(source).tokenize()
    syntax_tree = Parser(token_stream).parse()
    return CodeGenerator().generate(syntax_tree)
|
|
|
|
|
|
def main():
    """Command-line driver: compile one C file to a DSA assembly file.

    Usage: ``python compiler.py <input.c> [output.dsa]``.  When no output
    path is given it is derived from the input path by swapping the file
    extension for ``.dsa``.  Exits with status 1 on a usage error, a
    file-system error or a compilation error.
    """
    import os

    if len(sys.argv) < 2:
        print("Usage: python compiler.py <input.c> [output.dsa]")
        sys.exit(1)

    input_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    else:
        # Replace only the extension.  A plain str.replace(".c", ".dsa")
        # would also rewrite ".c" inside directory names, and would leave
        # the path unchanged when there is no ".c" at all.
        output_file = os.path.splitext(input_file)[0] + ".dsa"
        if output_file == input_file:
            print(f"Compilation error: refusing to overwrite input file {input_file}")
            sys.exit(1)

    try:
        with open(input_file, "r") as f:
            source = f.read()
    except OSError as e:
        print(f"Compilation error: cannot read {input_file}: {e}")
        sys.exit(1)

    try:
        assembly = compile_c_to_asm(source)
    except (SyntaxError, RuntimeError) as e:
        print(f"Compilation error: {e}")
        sys.exit(1)

    try:
        with open(output_file, "w") as f:
            f.write(assembly)
    except OSError as e:
        print(f"Compilation error: cannot write {output_file}: {e}")
        sys.exit(1)

    print(f"Successfully compiled {input_file} to {output_file}")
|
|
|
|
|
|
# Run the command-line driver only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
# # Example usage
|
|
# if len(sys.argv) > 1:
|
|
# example_c = sys.argv[1]
|
|
|
|
# else:
|
|
# example_c = """
|
|
# int factorial(int n) {
|
|
# if (n <= 1) {
|
|
# return 1;
|
|
# }
|
|
# return n * factorial(n - 1);
|
|
# }
|
|
|
|
# int main() {
|
|
# int result;
|
|
# result = factorial(5);
|
|
# return result;
|
|
# }
|
|
# """
|
|
|
|
# print("Example C program:")
|
|
# print(example_c)
|
|
# print("\n" + "="*60 + "\n")
|
|
# print("Generated DSA assembly:")
|
|
# print(compile_c_to_asm(example_c))
|