Files
damn_simple_architecture/c_compiler/compiler.py
T
2025-11-14 23:36:51 +00:00

927 lines
28 KiB
Python

#!/usr/bin/env python3
"""
Simple C to DSA Assembly Compiler
Supports a subset of C including:
- int variables and functions
- Arithmetic operations (+, -, *, /)
- Comparisons (==, !=, <, >, <=, >=)
- If/else statements
- While loops
- Function calls
- Return statements
"""
import re
import sys
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
from pprint import pprint
import json
class TokenType(Enum):
    """Every kind of token the lexer can emit.

    For reserved words, operators and delimiters the member value is the exact
    source spelling; IDENTIFIER, NUMBER and EOF carry their own name as a
    placeholder value.
    """

    # ---- reserved words ----
    INT = "int"
    IF = "if"
    ELSE = "else"
    WHILE = "while"
    RETURN = "return"

    # ---- names and integer literals ----
    IDENTIFIER = "IDENTIFIER"
    NUMBER = "NUMBER"

    # ---- arithmetic, assignment and comparison operators ----
    PLUS = "+"
    MINUS = "-"
    STAR = "*"
    SLASH = "/"
    ASSIGN = "="
    EQ = "=="
    NE = "!="
    LT = "<"
    GT = ">"
    LE = "<="
    GE = ">="

    # ---- punctuation ----
    LPAREN = "("
    RPAREN = ")"
    LBRACE = "{"
    RBRACE = "}"
    SEMICOLON = ";"
    COMMA = ","

    # ---- end-of-input sentinel ----
    EOF = "EOF"


@dataclass
class Token:
    """One lexeme together with where it starts in the source."""

    type: TokenType  # which kind of token this is
    value: str  # the exact source text ("" for the EOF sentinel)
    line: int  # 1-based line of the token's first character
    col: int  # 1-based column of the token's first character
class Lexer:
    """Turn C source text into a list of ``Token`` objects.

    Recognised input: ASCII decimal integer literals, identifiers/keywords,
    the operators and delimiters of ``TokenType``, whitespace, and ``//``
    line comments. Token positions are 1-based (line, column) and refer to
    the token's first character. A trailing EOF token is always appended.
    """

    def __init__(self, source: str):
        self.source = source
        self.pos = 0  # index of the next unread character
        self.line = 1
        self.col = 1
        self.tokens = []

    def error(self, msg: str):
        """Raise ``SyntaxError`` tagged with the current line/column."""
        raise SyntaxError(f"Lexer error at line {self.line}, col {self.col}: {msg}")

    def peek(self, offset: int = 0) -> Optional[str]:
        """Return the character ``offset`` places ahead without consuming it (``None`` past the end)."""
        pos = self.pos + offset
        return self.source[pos] if pos < len(self.source) else None

    def advance(self) -> Optional[str]:
        """Consume one character, updating line/column; return it, or ``None`` at end of input."""
        if self.pos >= len(self.source):
            return None
        char = self.source[self.pos]
        self.pos += 1
        if char == "\n":
            self.line += 1
            self.col = 1
        else:
            self.col += 1
        return char

    @staticmethod
    def _is_digit(char: Optional[str]) -> bool:
        """True only for ASCII ``0``-``9``.

        ``str.isdigit`` also accepts characters such as ``'²'``; those would be
        lexed as NUMBER and then make ``int()`` fail in the parser.
        """
        return char is not None and "0" <= char <= "9"

    def skip_whitespace(self):
        """Consume spaces, tabs, carriage returns and newlines."""
        while self.peek() is not None and self.peek() in " \t\n\r":
            self.advance()

    def skip_comment(self):
        """If a ``//`` comment starts here, consume it and its terminating newline."""
        if self.peek() == "/" and self.peek(1) == "/":
            while self.peek() is not None and self.peek() != "\n":
                self.advance()
            self.advance()  # the newline itself; a no-op at end of input

    def skip_trivia(self):
        """Consume any interleaving of whitespace and ``//`` comments.

        A single whitespace pass followed by a single comment pass is not
        enough: indentation after a comment line, or two comment lines in a
        row, would otherwise reach the token dispatcher.
        """
        while True:
            self.skip_whitespace()
            if self.peek() != "/" or self.peek(1) != "/":
                return
            self.skip_comment()

    def read_number(self) -> str:
        """Consume a run of ASCII digits and return it."""
        start = self.pos
        while self._is_digit(self.peek()):
            self.advance()
        return self.source[start:self.pos]

    def read_identifier(self) -> str:
        """Consume a run of alphanumerics/underscores and return it."""
        start = self.pos
        while self.peek() is not None and (self.peek().isalnum() or self.peek() == "_"):
            self.advance()
        return self.source[start:self.pos]

    def tokenize(self) -> List[Token]:
        """Lex the whole source and return the token list (ending with EOF).

        Raises:
            SyntaxError: on a character that starts no valid token.
        """
        keywords = {
            "int": TokenType.INT,
            "if": TokenType.IF,
            "else": TokenType.ELSE,
            "while": TokenType.WHILE,
            "return": TokenType.RETURN,
        }
        # Two-character operators are tried before single characters so that
        # "<=" is not lexed as "<" followed by "=".
        two_char_ops = {
            "==": TokenType.EQ,
            "!=": TokenType.NE,
            "<=": TokenType.LE,
            ">=": TokenType.GE,
        }
        one_char_ops = {
            "+": TokenType.PLUS,
            "-": TokenType.MINUS,
            "*": TokenType.STAR,
            "/": TokenType.SLASH,
            "=": TokenType.ASSIGN,
            "<": TokenType.LT,
            ">": TokenType.GT,
            "(": TokenType.LPAREN,
            ")": TokenType.RPAREN,
            "{": TokenType.LBRACE,
            "}": TokenType.RBRACE,
            ";": TokenType.SEMICOLON,
            ",": TokenType.COMMA,
        }
        while True:
            self.skip_trivia()
            if self.pos >= len(self.source):
                break
            line, col = self.line, self.col
            char = self.peek()
            pair = char + (self.peek(1) or "")
            if self._is_digit(char):
                self.tokens.append(Token(TokenType.NUMBER, self.read_number(), line, col))
            elif char.isalpha() or char == "_":
                ident = self.read_identifier()
                token_type = keywords.get(ident, TokenType.IDENTIFIER)
                self.tokens.append(Token(token_type, ident, line, col))
            elif pair in two_char_ops:
                self.advance()
                self.advance()
                self.tokens.append(Token(two_char_ops[pair], pair, line, col))
            elif char in one_char_ops:
                self.advance()
                self.tokens.append(Token(one_char_ops[char], char, line, col))
            else:
                self.error(f"Unexpected character: {char}")
        self.tokens.append(Token(TokenType.EOF, "", self.line, self.col))
        return self.tokens
# ---------------------------------------------------------------------------
# AST node classes
#
# Plain dataclasses. The three intermediate bases (Declaration, Statement,
# Expression) exist so that later stages can dispatch with isinstance().
# ---------------------------------------------------------------------------


@dataclass
class ASTNode:
    """Root of the syntax-tree class hierarchy."""


@dataclass
class Program(ASTNode):
    """A whole translation unit: top-level declarations in source order."""

    declarations: List["Declaration"]


@dataclass
class Declaration(ASTNode):
    """Base for top-level declarations (functions and global variables)."""


@dataclass
class FunctionDecl(Declaration):
    """``int name(int p0, int p1, ...) { body }``."""

    name: str
    params: List[str]  # parameter names in declaration order
    body: "CompoundStmt"


@dataclass
class VarDecl(Declaration):
    """A global ``int name [= init];`` declaration."""

    name: str
    init: Optional["Expression"] = None  # None when there is no initializer


@dataclass
class Statement(ASTNode):
    """Base for statements inside a function body."""


@dataclass
class CompoundStmt(Statement):
    """A ``{ ... }`` block."""

    statements: List[Statement]


@dataclass
class ExprStmt(Statement):
    """An expression evaluated for its side effects; ``expr`` is None for ``;``."""

    expr: Optional["Expression"]


@dataclass
class IfStmt(Statement):
    """``if (condition) then_stmt [else else_stmt]``."""

    condition: "Expression"
    then_stmt: Statement
    else_stmt: Optional[Statement] = None


@dataclass
class WhileStmt(Statement):
    """``while (condition) body``."""

    condition: "Expression"
    body: Statement


@dataclass
class ReturnStmt(Statement):
    """``return [expr];``; ``expr`` is None for a bare ``return;``."""

    expr: Optional["Expression"]


@dataclass
class Expression(ASTNode):
    """Base for every expression node."""


@dataclass
class BinaryOp(Expression):
    """``left op right``; ``op`` is the operator's source spelling, e.g. ``"+"``."""

    op: str
    left: Expression
    right: Expression


@dataclass
class UnaryOp(Expression):
    """Prefix ``op operand`` (``op`` is ``"+"`` or ``"-"``)."""

    op: str
    operand: Expression


@dataclass
class AssignExpr(Expression):
    """``name = value``; evaluates to the assigned value."""

    name: str
    value: Expression


@dataclass
class VarExpr(Expression):
    """A reference to a variable or parameter by name."""

    name: str


@dataclass
class NumberExpr(Expression):
    """An integer literal."""

    value: int


@dataclass
class CallExpr(Expression):
    """``name(args...)``; ``args`` in source order."""

    name: str
    args: List[Expression]
class Parser:
    """Recursive-descent parser that builds a ``Program`` from a token list.

    Expression precedence, loosest to tightest (binary levels are
    left-associative, assignment is right-associative)::

        assignment      a = b
        equality        ==  !=
        relational      <  >  <=  >=
        additive        +  -
        multiplicative  *  /
        unary           +x  -x
        primary         literal | name | name(args) | (expr)

    Equality binds more loosely than relational comparison, as in C, so
    ``a == b < c`` parses as ``a == (b < c)``.

    All methods raise ``SyntaxError`` (via ``error``) on malformed input.
    """

    def __init__(self, tokens: List[Token]):
        self.tokens = tokens  # expected to end with an EOF token (Lexer guarantees this)
        self.pos = 0

    def error(self, msg: str):
        """Raise ``SyntaxError`` located at the current token."""
        token = self.current()
        raise SyntaxError(f"Parser error at line {token.line}, col {token.col}: {msg}")

    def current(self) -> Token:
        """Return the current token (the last token once past the end)."""
        return self.peek()

    def peek(self, offset: int = 0) -> Token:
        """Return the token ``offset`` places ahead, clamped to the last token."""
        pos = self.pos + offset
        return self.tokens[pos] if pos < len(self.tokens) else self.tokens[-1]

    def advance(self) -> Token:
        """Consume and return the current token; never moves past the last token."""
        token = self.current()
        if self.pos < len(self.tokens) - 1:
            self.pos += 1
        return token

    def expect(self, token_type: TokenType) -> Token:
        """Consume a token of ``token_type`` or raise ``SyntaxError``."""
        token = self.current()
        if token.type != token_type:
            self.error(f"Expected {token_type.value}, got {token.type.value}")
        return self.advance()

    # ---- declarations -----------------------------------------------------

    def parse(self) -> Program:
        """Parse the whole token stream into a ``Program``."""
        declarations = []
        while self.current().type != TokenType.EOF:
            declarations.append(self.parse_declaration())
        return Program(declarations)

    def parse_declaration(self) -> Declaration:
        """Parse ``int name(params) {...}`` or ``int name [= expr];``."""
        self.expect(TokenType.INT)
        name = self.expect(TokenType.IDENTIFIER).value
        if self.current().type == TokenType.LPAREN:
            self.advance()
            params = self._parse_param_list()
            self.expect(TokenType.RPAREN)
            body = self.parse_compound_stmt()
            return FunctionDecl(name, params, body)
        init = self._parse_optional_initializer()
        self.expect(TokenType.SEMICOLON)
        return VarDecl(name, init)

    def _parse_param_list(self) -> List[str]:
        """Parse ``int a, int b, ...`` up to (not including) ``)``; may be empty."""
        params = []
        if self.current().type == TokenType.RPAREN:
            return params
        while True:
            self.expect(TokenType.INT)
            params.append(self.expect(TokenType.IDENTIFIER).value)
            if self.current().type != TokenType.COMMA:
                return params
            self.advance()

    def _parse_optional_initializer(self) -> Optional[Expression]:
        """Parse ``= expr`` if present; return the expression or ``None``."""
        if self.current().type != TokenType.ASSIGN:
            return None
        self.advance()
        return self.parse_expression()

    # ---- statements -------------------------------------------------------

    def parse_compound_stmt(self) -> CompoundStmt:
        """Parse ``{ statement* }``."""
        self.expect(TokenType.LBRACE)
        statements = []
        while self.current().type != TokenType.RBRACE:
            # Without this check a missing '}' surfaces as a confusing
            # "Unexpected token: EOF" from deep inside expression parsing.
            if self.current().type == TokenType.EOF:
                self.error("Expected '}' before end of input")
            statements.append(self.parse_statement())
        self.expect(TokenType.RBRACE)
        return CompoundStmt(statements)

    def parse_statement(self) -> Statement:
        """Parse one statement, including local ``int`` declarations."""
        token = self.current()
        if token.type == TokenType.LBRACE:
            return self.parse_compound_stmt()
        if token.type == TokenType.IF:
            return self.parse_if_stmt()
        if token.type == TokenType.WHILE:
            return self.parse_while_stmt()
        if token.type == TokenType.RETURN:
            return self.parse_return_stmt()
        if token.type == TokenType.INT:
            # Local declaration: lowered to an assignment when initialised,
            # otherwise to an empty statement.
            self.advance()
            name = self.expect(TokenType.IDENTIFIER).value
            init = self._parse_optional_initializer()
            self.expect(TokenType.SEMICOLON)
            return ExprStmt(AssignExpr(name, init) if init is not None else None)
        expr = None
        if self.current().type != TokenType.SEMICOLON:
            expr = self.parse_expression()
        self.expect(TokenType.SEMICOLON)
        return ExprStmt(expr)

    def parse_if_stmt(self) -> IfStmt:
        """Parse ``if (cond) stmt [else stmt]``; ``else`` binds to the nearest ``if``."""
        self.expect(TokenType.IF)
        self.expect(TokenType.LPAREN)
        condition = self.parse_expression()
        self.expect(TokenType.RPAREN)
        then_stmt = self.parse_statement()
        else_stmt = None
        if self.current().type == TokenType.ELSE:
            self.advance()
            else_stmt = self.parse_statement()
        return IfStmt(condition, then_stmt, else_stmt)

    def parse_while_stmt(self) -> WhileStmt:
        """Parse ``while (cond) stmt``."""
        self.expect(TokenType.WHILE)
        self.expect(TokenType.LPAREN)
        condition = self.parse_expression()
        self.expect(TokenType.RPAREN)
        body = self.parse_statement()
        return WhileStmt(condition, body)

    def parse_return_stmt(self) -> ReturnStmt:
        """Parse ``return [expr];``."""
        self.expect(TokenType.RETURN)
        expr = None
        if self.current().type != TokenType.SEMICOLON:
            expr = self.parse_expression()
        self.expect(TokenType.SEMICOLON)
        return ReturnStmt(expr)

    # ---- expressions ------------------------------------------------------

    def parse_expression(self) -> Expression:
        """Parse a full expression (assignment is the loosest level)."""
        return self.parse_assignment()

    def parse_assignment(self) -> Expression:
        """Parse ``name = expr`` (right-associative) or a plain comparison."""
        expr = self.parse_comparison()
        if self.current().type == TokenType.ASSIGN:
            if not isinstance(expr, VarExpr):
                self.error("Invalid assignment target")
            self.advance()
            value = self.parse_assignment()
            return AssignExpr(expr.name, value)
        return expr

    def _parse_left_assoc(self, parse_operand, op_types) -> Expression:
        """Parse ``operand (op operand)*`` into a left-leaning ``BinaryOp`` chain."""
        expr = parse_operand()
        while self.current().type in op_types:
            op = self.advance().value
            right = parse_operand()
            expr = BinaryOp(op, expr, right)
        return expr

    def parse_comparison(self) -> Expression:
        """Parse any comparison; entry point for both comparison levels."""
        return self.parse_equality()

    def parse_equality(self) -> Expression:
        """Parse ``==`` / ``!=`` chains over relational expressions."""
        return self._parse_left_assoc(
            self.parse_relational, (TokenType.EQ, TokenType.NE)
        )

    def parse_relational(self) -> Expression:
        """Parse ``<``, ``>``, ``<=``, ``>=`` chains over additive expressions."""
        return self._parse_left_assoc(
            self.parse_additive,
            (TokenType.LT, TokenType.GT, TokenType.LE, TokenType.GE),
        )

    def parse_additive(self) -> Expression:
        """Parse ``+`` / ``-`` chains."""
        return self._parse_left_assoc(
            self.parse_multiplicative, (TokenType.PLUS, TokenType.MINUS)
        )

    def parse_multiplicative(self) -> Expression:
        """Parse ``*`` / ``/`` chains."""
        return self._parse_left_assoc(
            self.parse_unary, (TokenType.STAR, TokenType.SLASH)
        )

    def parse_unary(self) -> Expression:
        """Parse prefix ``+``/``-`` (may be repeated) or a primary expression."""
        if self.current().type in (TokenType.PLUS, TokenType.MINUS):
            op = self.advance().value
            operand = self.parse_unary()
            return UnaryOp(op, operand)
        return self.parse_primary()

    def _parse_arg_list(self) -> List[Expression]:
        """Parse comma-separated call arguments up to (not including) ``)``."""
        args = []
        if self.current().type == TokenType.RPAREN:
            return args
        while True:
            args.append(self.parse_expression())
            if self.current().type != TokenType.COMMA:
                return args
            self.advance()

    def parse_primary(self) -> Expression:
        """Parse a literal, variable, call, or parenthesised expression."""
        token = self.current()
        if token.type == TokenType.NUMBER:
            self.advance()
            return NumberExpr(int(token.value))
        if token.type == TokenType.IDENTIFIER:
            self.advance()
            if self.current().type != TokenType.LPAREN:
                return VarExpr(token.value)
            self.advance()
            args = self._parse_arg_list()
            self.expect(TokenType.RPAREN)
            return CallExpr(token.value, args)
        if token.type == TokenType.LPAREN:
            self.advance()
            expr = self.parse_expression()
            self.expect(TokenType.RPAREN)
            return expr
        self.error(f"Unexpected token: {token.type.value}")
class CodeGenerator:
    """Translate a ``Program`` AST into DSA assembly text.

    Calling convention used by the emitted code:

    * The caller pushes arguments in reverse order, so the first argument is
      the word pushed last. A call without arguments pushes one placeholder
      word (``push zero``) instead.
    * The callee runs ``push bpr`` / ``mov spr, bpr``. Inside the callee,
      ``bpr+4`` holds the return address and parameter *i* is at
      ``bpr + 8 + 4*i``. This assumes ``push`` pre-decrements ``spr``.
    * The callee writes its return value to ``bpr+8``, overwriting the first
      argument or placeholder. The caller pops that word into a register and
      discards the remaining argument words.
    * Locals live at ``bpr-4``, ``bpr-8``, ... and the prologue reserves one
      stack word (``push zero``) per local.

    Registers are not preserved across calls. A value that must survive a
    call is spilled to the stack (see ``generate_binary_op``).
    """

    def __init__(self):
        self.output = []  # emitted assembly lines; joined by generate()
        self.label_counter = 0  # suffix counter that keeps new_label() names unique
        self.string_counter = 0  # NOTE(review): not used in this class
        self.functions = {}  # function name -> FunctionDecl, filled as functions are generated
        self.current_function = None  # name of the function being generated (for return jumps)
        self.local_vars = {}  # name -> bpr offset (params positive, locals negative)
        self.global_vars = {}  # name -> data label
        self.register_pool = [f"rg{i:x}" for i in range(16)]  # rg0 .. rgf
        self.used_registers = set()

    def new_label(self, prefix: str = "L") -> str:
        """Return a fresh label ``<prefix><n>``."""
        label = f"{prefix}{self.label_counter}"
        self.label_counter += 1
        return label

    def allocate_register(self) -> str:
        """Reserve and return the lowest-numbered free register."""
        for reg in self.register_pool:
            if reg not in self.used_registers:
                self.used_registers.add(reg)
                return reg
        raise RuntimeError("Out of registers")

    def free_register(self, reg: str):
        """Return ``reg`` to the pool (no-op if it is already free)."""
        self.used_registers.discard(reg)

    def emit(self, code: str):
        """Append one line of assembly to the output."""
        self.output.append(code)

    # ---- program / function level ----------------------------------------

    def generate(self, program: Program) -> str:
        """Emit globals, the ``init`` entry stub, then every function; return the text.

        Raises:
            RuntimeError: for a non-constant global initializer, an undefined
                variable, or register exhaustion.
        """
        self.emit("// Global variables")
        for decl in program.declarations:
            if isinstance(decl, VarDecl):
                self.global_vars[decl.name] = f"var_{decl.name}"
                self.emit(f"dw var_{decl.name}: {self._global_initial_value(decl)}")
        self.emit("")
        self.emit("// Entry point")
        self.emit("dw stack_bottom: 0x10000")
        self.emit("")
        self.emit("init:")
        self.emit(" ldw stack_bottom, spr")
        self.emit(" mov spr, bpr")
        self.emit(" push zero")  # return-value slot for main
        self.emit(" call main")
        self.emit(" pop rg0")
        self.emit(" hlt")
        self.emit("")
        for decl in program.declarations:
            if isinstance(decl, FunctionDecl):
                self.generate_function(decl)
        return "\n".join(self.output)

    def _global_initial_value(self, decl: VarDecl) -> int:
        """Return the word value a global's ``dw`` directive should hold.

        Previously any initializer that was not a bare literal (e.g. ``-5``)
        was silently emitted as 0. It is now folded at compile time, and
        non-constant initializers are rejected, as in C.
        """
        if decl.init is None:
            return 0
        value = self._fold_constant(decl.init)
        if value is None:
            raise RuntimeError(
                f"Initializer of global '{decl.name}' is not a constant expression"
            )
        # NOTE(review): negative values are written as their 32-bit two's-complement
        # bit pattern (registers are loaded 16 bits at a time via lli/lui). Confirm
        # that the assembler does not also accept signed decimals directly.
        return value & 0xFFFFFFFF if value < 0 else value

    def _fold_constant(self, expr: Expression) -> Optional[int]:
        """Evaluate ``expr`` at compile time, or return ``None`` if it is not constant."""
        if isinstance(expr, NumberExpr):
            return expr.value
        if isinstance(expr, UnaryOp):
            operand = self._fold_constant(expr.operand)
            if operand is None:
                return None
            return -operand if expr.op == "-" else operand
        if isinstance(expr, BinaryOp):
            left = self._fold_constant(expr.left)
            right = self._fold_constant(expr.right)
            if left is None or right is None:
                return None
            if expr.op == "/":
                if right == 0:
                    raise RuntimeError("Division by zero in constant expression")
                # C division truncates toward zero; Python's // floors.
                quotient = abs(left) // abs(right)
                return quotient if (left < 0) == (right < 0) else -quotient
            folders = {
                "+": lambda a, b: a + b,
                "-": lambda a, b: a - b,
                "*": lambda a, b: a * b,
                "==": lambda a, b: int(a == b),
                "!=": lambda a, b: int(a != b),
                "<": lambda a, b: int(a < b),
                ">": lambda a, b: int(a > b),
                "<=": lambda a, b: int(a <= b),
                ">=": lambda a, b: int(a >= b),
            }
            fold = folders.get(expr.op)
            return fold(left, right) if fold is not None else None
        return None

    def generate_function(self, func: FunctionDecl):
        """Emit prologue, body and epilogue for one function."""
        self.current_function = func.name
        self.functions[func.name] = func
        self.local_vars = {}
        # Functions are generated one after another from generate(), so no
        # register value is live across this boundary.
        self.used_registers.clear()
        # Parameter i sits at bpr + 8 + 4*i (return address at bpr+4).
        for i, param in enumerate(func.params):
            self.local_vars[param] = 8 + i * 4

        # Generate the body into a scratch buffer first. Locals are discovered
        # while generating, and the prologue must reserve a stack word for each
        # one. Otherwise spr == bpr, and any push in the body (call arguments,
        # spilled operands) would overwrite the locals at bpr-4, bpr-8, ...
        enclosing_output = self.output
        self.output = []
        try:
            self.generate_compound_stmt(func.body)
            body = self.output
        finally:
            self.output = enclosing_output
        num_locals = sum(1 for offset in self.local_vars.values() if offset < 0)

        self.emit(f"{func.name}:")
        self.emit(" push bpr")
        self.emit(" mov spr, bpr")
        for _ in range(num_locals):
            self.emit(" push zero")  # reserve (and zero) one local slot
        self.emit("")
        self.output.extend(body)
        # Fall-through exit; explicit `return` statements jump here as well.
        self.emit("// default return")
        self.emit(f"{func.name}_end:")
        self.emit(" mov bpr, spr")  # drops all locals
        self.emit(" pop bpr")
        self.emit(" return")
        self.emit("")

    # ---- statements -------------------------------------------------------

    def generate_compound_stmt(self, stmt: CompoundStmt):
        """Emit each statement of a block in order."""
        for s in stmt.statements:
            self.generate_statement(s)

    def generate_statement(self, stmt: Statement):
        """Dispatch on statement type.

        Raises:
            RuntimeError: for an unknown statement type (previously ignored silently).
        """
        if isinstance(stmt, CompoundStmt):
            self.generate_compound_stmt(stmt)
        elif isinstance(stmt, ExprStmt):
            if stmt.expr is not None:
                self.free_register(self.generate_expression(stmt.expr))
        elif isinstance(stmt, IfStmt):
            self.generate_if_stmt(stmt)
        elif isinstance(stmt, WhileStmt):
            self.generate_while_stmt(stmt)
        elif isinstance(stmt, ReturnStmt):
            self.generate_return_stmt(stmt)
        else:
            raise RuntimeError(f"Unknown statement type: {type(stmt)}")

    def generate_if_stmt(self, stmt: IfStmt):
        """Emit ``if``/``else``: a zero condition skips the then-branch."""
        else_label = self.new_label("else")
        end_label = self.new_label("endif")
        cond_reg = self.generate_expression(stmt.condition)
        self.emit(f" cmp {cond_reg}, zero")
        self.free_register(cond_reg)
        self.emit(f" jeq {else_label if stmt.else_stmt else end_label}")
        self.generate_statement(stmt.then_stmt)
        if stmt.else_stmt:
            self.emit(f" jmp {end_label}")
            self.emit(f"{else_label}:")
            self.generate_statement(stmt.else_stmt)
        self.emit(f"{end_label}:")

    def generate_while_stmt(self, stmt: WhileStmt):
        """Emit a top-tested loop that exits when the condition is zero."""
        start_label = self.new_label("while_start")
        end_label = self.new_label("while_end")
        self.emit(f"{start_label}:")
        cond_reg = self.generate_expression(stmt.condition)
        self.emit(f" cmp {cond_reg}, zero")
        self.free_register(cond_reg)
        self.emit(f" jeq {end_label}")
        self.generate_statement(stmt.body)
        self.emit(f" jmp {start_label}")
        self.emit(f"{end_label}:")

    def generate_return_stmt(self, stmt: ReturnStmt):
        """Store the return value (if any) into the return slot and jump to the epilogue."""
        if stmt.expr is not None:
            reg = self.generate_expression(stmt.expr)
            # The return slot is the word just above the return address:
            # bpr+8. Address it through bpr, because spr sits below any
            # locals reserved in the prologue.
            self.emit(f" stw {reg}, bpr, 8")
            self.free_register(reg)
        self.emit(f" jmp {self.current_function}_end")

    # ---- expressions ------------------------------------------------------

    def _contains_call(self, expr: Expression) -> bool:
        """True if evaluating ``expr`` performs a function call."""
        if isinstance(expr, CallExpr):
            return True
        if isinstance(expr, BinaryOp):
            return self._contains_call(expr.left) or self._contains_call(expr.right)
        if isinstance(expr, UnaryOp):
            return self._contains_call(expr.operand)
        if isinstance(expr, AssignExpr):
            return self._contains_call(expr.value)
        return False

    def generate_expression(self, expr: Expression) -> str:
        """Emit code for ``expr`` and return the register holding its value.

        The caller owns the returned register and must free it.

        Raises:
            RuntimeError: for an undefined variable or unknown expression type.
        """
        if isinstance(expr, NumberExpr):
            reg = self.allocate_register()
            if 0 <= expr.value <= 0xFFFF:
                self.emit(f" lli {expr.value}, {reg}")
                # NOTE(review): for 0x100..0xFFFF this emits `lui 0`, presumably
                # to clear upper bits left by lli; confirm against the ISA.
                if expr.value > 0xFF:
                    self.emit(f" lui {expr.value >> 16}, {reg}")
            else:
                self.emit(f" lli {expr.value & 0xFFFF}, {reg}")
                self.emit(f" lui {(expr.value >> 16) & 0xFFFF}, {reg}")
            return reg

        if isinstance(expr, VarExpr):
            reg = self.allocate_register()
            if expr.name in self.local_vars:
                self.emit(f" ldw bpr, {reg}, {self.local_vars[expr.name]}")
            elif expr.name in self.global_vars:
                self.emit(f" ldw {self.global_vars[expr.name]}, {reg}")
            else:
                raise RuntimeError(f"Undefined variable: {expr.name}")
            return reg

        if isinstance(expr, AssignExpr):
            value_reg = self.generate_expression(expr.value)
            if expr.name in self.local_vars:
                self.emit(f" stw {value_reg}, bpr, {self.local_vars[expr.name]}")
            elif expr.name in self.global_vars:
                # NOTE(review): a local `int x = ...;` that shares a global's name
                # also ends up here and writes the global (no shadowing).
                self.emit(f" stw {value_reg}, {self.global_vars[expr.name]}")
            else:
                # First assignment to a new local: take the next slot below bpr.
                # generate_function reserves these slots in the prologue.
                offset = -(sum(1 for v in self.local_vars.values() if v < 0) + 1) * 4
                self.local_vars[expr.name] = offset
                self.emit(f" stw {value_reg}, bpr, {offset}")
            return value_reg

        if isinstance(expr, BinaryOp):
            return self.generate_binary_op(expr)

        if isinstance(expr, UnaryOp):
            operand_reg = self.generate_expression(expr.operand)
            result_reg = self.allocate_register()
            if expr.op == "-":
                # 0 - operand; `lli 0` is how the rest of this generator zeroes
                # a register (the previous `lwi` appeared nowhere else).
                self.emit(f" lli 0, {result_reg}")
                self.emit(f" sub {result_reg}, {operand_reg}, {result_reg}")
            else:  # unary +
                self.emit(f" mov {operand_reg}, {result_reg}")
            self.free_register(operand_reg)
            return result_reg

        if isinstance(expr, CallExpr):
            if not expr.args:
                # The callee writes its result to the word pushed last before
                # `call`. With no arguments, push a placeholder so that slot
                # exists and the pop below stays balanced.
                self.emit(" push zero")
            # Reverse order puts the first argument at bpr+8 in the callee.
            for arg in reversed(expr.args):
                arg_reg = self.generate_expression(arg)
                self.emit(f" push {arg_reg}")
                self.free_register(arg_reg)  # the value now lives on the stack
            self.emit(f" call {expr.name}")
            result_reg = self.allocate_register()
            self.emit(f" pop {result_reg}")  # return value (was first argument's slot)
            for _ in range(len(expr.args) - 1):
                self.emit(" pop zero")  # discard the remaining argument words
            return result_reg

        raise RuntimeError(f"Unknown expression type: {type(expr)}")

    def generate_binary_op(self, expr: BinaryOp) -> str:
        """Emit ``left op right`` and return the register holding the result.

        Raises:
            RuntimeError: for an unknown operator.
        """
        left_reg = self.generate_expression(expr.left)
        if self._contains_call(expr.right):
            # Registers are clobbered by the callee. Park the left operand on
            # the stack while the right operand (which makes a call) is
            # evaluated, then reload it.
            self.emit(f" push {left_reg}")
            self.free_register(left_reg)
            right_reg = self.generate_expression(expr.right)
            left_reg = self.allocate_register()
            self.emit(f" pop {left_reg}")
        else:
            right_reg = self.generate_expression(expr.right)
        result_reg = self.allocate_register()

        if expr.op == "+":
            self.emit(f" add {left_reg}, {right_reg}, {result_reg}")
        elif expr.op == "-":
            self.emit(f" sub {left_reg}, {right_reg}, {result_reg}")
        elif expr.op == "*":
            # Repeated addition: add left to result while counting right down
            # to zero (right_reg is a temporary owned here, so mutating it is
            # fine). NOTE(review): a negative right operand only terminates once
            # `dec` wraps around.
            loop_label = self.new_label("mult")
            end_label = self.new_label("mult_end")
            self.emit(f" lli 0, {result_reg}")
            self.emit(f"{loop_label}:")
            self.emit(f" cmp {right_reg}, zero")
            self.emit(f" jeq {end_label}")
            self.emit(f" add {result_reg}, {left_reg}, {result_reg}")
            self.emit(f" dec {right_reg}")
            self.emit(f" jmp {loop_label}")
            self.emit(f"{end_label}:")
        elif expr.op == "/":
            # Repeated subtraction, counting how often right fits into left.
            # NOTE(review): assumes non-negative operands. A zero divisor never
            # terminates for a non-negative dividend, and negative operands do
            # not follow C's truncating division.
            loop_label = self.new_label("div")
            end_label = self.new_label("div_end")
            self.emit(f" lli 0, {result_reg}")
            self.emit(f"{loop_label}:")
            self.emit(f" cmp {left_reg}, {right_reg}")
            self.emit(f" jlt {end_label}")
            self.emit(f" sub {left_reg}, {right_reg}, {left_reg}")
            self.emit(f" inc {result_reg}")
            self.emit(f" jmp {loop_label}")
            self.emit(f"{end_label}:")
        elif expr.op in ("==", "!=", "<", ">", "<=", ">="):
            jumps = {"==": "jeq", "!=": "jne", "<": "jlt", ">": "jgt", "<=": "jle", ">=": "jge"}
            true_label = self.new_label("cmp_true")
            end_label = self.new_label("cmp_end")
            # Zero the result before `cmp`, so the conditional jump tests flags
            # produced directly by the compare.
            self.emit(f" lli 0, {result_reg}")
            self.emit(f" cmp {left_reg}, {right_reg}")
            self.emit(f" {jumps[expr.op]} {true_label}")
            self.emit(f" jmp {end_label}")
            self.emit(f"{true_label}:")
            self.emit(f" lli 1, {result_reg}")
            self.emit(f"{end_label}:")
        else:
            raise RuntimeError(f"Unknown binary operator: {expr.op}")

        self.free_register(left_reg)
        self.free_register(right_reg)
        return result_reg
def compile_c_to_asm(source: str) -> str:
    """Compile C source text to DSA assembly text.

    Runs the three stages in sequence: lexing, parsing, code generation.

    Raises:
        SyntaxError: from the lexer or parser on malformed input.
        RuntimeError: from the code generator (e.g. an undefined variable).
    """
    token_stream = Lexer(source).tokenize()
    syntax_tree = Parser(token_stream).parse()
    return CodeGenerator().generate(syntax_tree)
def default_output_path(input_file: str) -> str:
    """Return the output path used when none is given on the command line.

    Only the final extension is replaced: ``prog.c`` -> ``prog.dsa``. A name
    without an extension gets ``.dsa`` appended. A plain substring replace of
    ".c" would also rewrite directory names such as ``src.cfg/``. For names
    without ".c" it would return the input path itself, overwriting the source.
    """
    import os

    root, _ext = os.path.splitext(input_file)
    return root + ".dsa"


def main():
    """Command-line entry point: ``python compiler.py <input.c> [output.dsa]``.

    Exits with status 1 on a usage error, an unreadable input file, a
    compilation error, or an unwritable output file. Diagnostics go to stderr.
    """
    if len(sys.argv) < 2:
        print("Usage: python compiler.py <input.c> [output.dsa]", file=sys.stderr)
        sys.exit(1)
    input_file = sys.argv[1]
    output_file = sys.argv[2] if len(sys.argv) > 2 else default_output_path(input_file)
    try:
        with open(input_file, "r", encoding="utf-8") as f:
            source = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Cannot read {input_file}: {e}", file=sys.stderr)
        sys.exit(1)
    try:
        assembly = compile_c_to_asm(source)
    except (SyntaxError, RuntimeError) as e:
        print(f"Compilation error: {e}", file=sys.stderr)
        sys.exit(1)
    try:
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(assembly)
    except OSError as e:
        print(f"Cannot write {output_file}: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Successfully compiled {input_file} to {output_file}")


if __name__ == "__main__":
    main()
# # Example usage
# if len(sys.argv) > 1:
# example_c = sys.argv[1]
# else:
# example_c = """
# int factorial(int n) {
# if (n <= 1) {
# return 1;
# }
# return n * factorial(n - 1);
# }
# int main() {
# int result;
# result = factorial(5);
# return result;
# }
# """
# print("Example C program:")
# print(example_c)
# print("\n" + "="*60 + "\n")
# print("Generated DSA assembly:")
# print(compile_c_to_asm(example_c))