started work on c compiler
This commit is contained in:
@@ -0,0 +1,926 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple C to DSA Assembly Compiler
|
||||
Supports a subset of C including:
|
||||
- int variables and functions
|
||||
- Arithmetic operations (+, -, *, /)
|
||||
- Comparisons (==, !=, <, >, <=, >=)
|
||||
- If/else statements
|
||||
- While loops
|
||||
- Function calls
|
||||
- Return statements
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pprint import pprint
|
||||
import json
|
||||
|
||||
|
||||
class TokenType(Enum):
|
||||
# Keywords
|
||||
INT = "int"
|
||||
IF = "if"
|
||||
ELSE = "else"
|
||||
WHILE = "while"
|
||||
RETURN = "return"
|
||||
|
||||
# Identifiers and literals
|
||||
IDENTIFIER = "IDENTIFIER"
|
||||
NUMBER = "NUMBER"
|
||||
|
||||
# Operators
|
||||
PLUS = "+"
|
||||
MINUS = "-"
|
||||
STAR = "*"
|
||||
SLASH = "/"
|
||||
ASSIGN = "="
|
||||
EQ = "=="
|
||||
NE = "!="
|
||||
LT = "<"
|
||||
GT = ">"
|
||||
LE = "<="
|
||||
GE = ">="
|
||||
|
||||
# Delimiters
|
||||
LPAREN = "("
|
||||
RPAREN = ")"
|
||||
LBRACE = "{"
|
||||
RBRACE = "}"
|
||||
SEMICOLON = ";"
|
||||
COMMA = ","
|
||||
|
||||
EOF = "EOF"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
type: TokenType
|
||||
value: str
|
||||
line: int
|
||||
col: int
|
||||
|
||||
|
||||
class Lexer:
|
||||
def __init__(self, source: str):
|
||||
self.source = source
|
||||
self.pos = 0
|
||||
self.line = 1
|
||||
self.col = 1
|
||||
self.tokens = []
|
||||
|
||||
def error(self, msg: str):
|
||||
raise SyntaxError(f"Lexer error at line {self.line}, col {self.col}: {msg}")
|
||||
|
||||
def peek(self, offset: int = 0) -> Optional[str]:
|
||||
pos = self.pos + offset
|
||||
return self.source[pos] if pos < len(self.source) else None
|
||||
|
||||
def advance(self) -> Optional[str]:
|
||||
if self.pos >= len(self.source):
|
||||
return None
|
||||
char = self.source[self.pos]
|
||||
self.pos += 1
|
||||
if char == "\n":
|
||||
self.line += 1
|
||||
self.col = 1
|
||||
else:
|
||||
self.col += 1
|
||||
return char
|
||||
|
||||
def skip_whitespace(self):
|
||||
while self.peek() and self.peek() in " \t\n\r":
|
||||
self.advance()
|
||||
|
||||
def skip_comment(self):
|
||||
if self.peek() == "/" and self.peek(1) == "/":
|
||||
while self.peek() and self.peek() != "\n":
|
||||
self.advance()
|
||||
self.advance() # skip newline
|
||||
|
||||
def read_number(self) -> str:
|
||||
num = ""
|
||||
while self.peek() and self.peek().isdigit():
|
||||
num += self.advance()
|
||||
return num
|
||||
|
||||
def read_identifier(self) -> str:
|
||||
ident = ""
|
||||
while self.peek() and (self.peek().isalnum() or self.peek() == "_"):
|
||||
ident += self.advance()
|
||||
return ident
|
||||
|
||||
def tokenize(self) -> List[Token]:
|
||||
keywords = {
|
||||
"int": TokenType.INT,
|
||||
"if": TokenType.IF,
|
||||
"else": TokenType.ELSE,
|
||||
"while": TokenType.WHILE,
|
||||
"return": TokenType.RETURN,
|
||||
}
|
||||
|
||||
while self.pos < len(self.source):
|
||||
self.skip_whitespace()
|
||||
self.skip_comment()
|
||||
|
||||
if self.pos >= len(self.source):
|
||||
break
|
||||
|
||||
line, col = self.line, self.col
|
||||
char = self.peek()
|
||||
|
||||
# Numbers
|
||||
if char.isdigit():
|
||||
num = self.read_number()
|
||||
self.tokens.append(Token(TokenType.NUMBER, num, line, col))
|
||||
|
||||
# Identifiers and keywords
|
||||
elif char.isalpha() or char == "_":
|
||||
ident = self.read_identifier()
|
||||
token_type = keywords.get(ident, TokenType.IDENTIFIER)
|
||||
self.tokens.append(Token(token_type, ident, line, col))
|
||||
|
||||
# Two-character operators
|
||||
elif char == "=" and self.peek(1) == "=":
|
||||
self.advance()
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.EQ, "==", line, col))
|
||||
elif char == "!" and self.peek(1) == "=":
|
||||
self.advance()
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.NE, "!=", line, col))
|
||||
elif char == "<" and self.peek(1) == "=":
|
||||
self.advance()
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.LE, "<=", line, col))
|
||||
elif char == ">" and self.peek(1) == "=":
|
||||
self.advance()
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.GE, ">=", line, col))
|
||||
|
||||
# Single-character operators
|
||||
elif char == "+":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.PLUS, "+", line, col))
|
||||
elif char == "-":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.MINUS, "-", line, col))
|
||||
elif char == "*":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.STAR, "*", line, col))
|
||||
elif char == "/":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.SLASH, "/", line, col))
|
||||
elif char == "=":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.ASSIGN, "=", line, col))
|
||||
elif char == "<":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.LT, "<", line, col))
|
||||
elif char == ">":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.GT, ">", line, col))
|
||||
elif char == "(":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.LPAREN, "(", line, col))
|
||||
elif char == ")":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.RPAREN, ")", line, col))
|
||||
elif char == "{":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.LBRACE, "{", line, col))
|
||||
elif char == "}":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.RBRACE, "}", line, col))
|
||||
elif char == ";":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.SEMICOLON, ";", line, col))
|
||||
elif char == ",":
|
||||
self.advance()
|
||||
self.tokens.append(Token(TokenType.COMMA, ",", line, col))
|
||||
else:
|
||||
self.error(f"Unexpected character: {char}")
|
||||
|
||||
self.tokens.append(Token(TokenType.EOF, "", self.line, self.col))
|
||||
return self.tokens
|
||||
|
||||
|
||||
# AST Node classes
|
||||
@dataclass
|
||||
class ASTNode:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Program(ASTNode):
|
||||
declarations: List["Declaration"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Declaration(ASTNode):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionDecl(Declaration):
|
||||
name: str
|
||||
params: List[str]
|
||||
body: "CompoundStmt"
|
||||
|
||||
|
||||
@dataclass
|
||||
class VarDecl(Declaration):
|
||||
name: str
|
||||
init: Optional["Expression"] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Statement(ASTNode):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompoundStmt(Statement):
|
||||
statements: List[Statement]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExprStmt(Statement):
|
||||
expr: Optional["Expression"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IfStmt(Statement):
|
||||
condition: "Expression"
|
||||
then_stmt: Statement
|
||||
else_stmt: Optional[Statement] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhileStmt(Statement):
|
||||
condition: "Expression"
|
||||
body: Statement
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReturnStmt(Statement):
|
||||
expr: Optional["Expression"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Expression(ASTNode):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryOp(Expression):
|
||||
op: str
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnaryOp(Expression):
|
||||
op: str
|
||||
operand: Expression
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssignExpr(Expression):
|
||||
name: str
|
||||
value: Expression
|
||||
|
||||
|
||||
@dataclass
|
||||
class VarExpr(Expression):
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class NumberExpr(Expression):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallExpr(Expression):
|
||||
name: str
|
||||
args: List[Expression]
|
||||
|
||||
|
||||
class Parser:
|
||||
def __init__(self, tokens: List[Token]):
|
||||
self.tokens = tokens
|
||||
self.pos = 0
|
||||
|
||||
def error(self, msg: str):
|
||||
token = self.current()
|
||||
raise SyntaxError(f"Parser error at line {token.line}, col {token.col}: {msg}")
|
||||
|
||||
def current(self) -> Token:
|
||||
return self.tokens[self.pos] if self.pos < len(self.tokens) else self.tokens[-1]
|
||||
|
||||
def peek(self, offset: int = 0) -> Token:
|
||||
pos = self.pos + offset
|
||||
return self.tokens[pos] if pos < len(self.tokens) else self.tokens[-1]
|
||||
|
||||
def advance(self) -> Token:
|
||||
token = self.current()
|
||||
if self.pos < len(self.tokens) - 1:
|
||||
self.pos += 1
|
||||
return token
|
||||
|
||||
def expect(self, token_type: TokenType) -> Token:
|
||||
token = self.current()
|
||||
if token.type != token_type:
|
||||
self.error(f"Expected {token_type.value}, got {token.type.value}")
|
||||
return self.advance()
|
||||
|
||||
def parse(self) -> Program:
|
||||
declarations = []
|
||||
while self.current().type != TokenType.EOF:
|
||||
declarations.append(self.parse_declaration())
|
||||
return Program(declarations)
|
||||
|
||||
def parse_declaration(self) -> Declaration:
|
||||
self.expect(TokenType.INT)
|
||||
name = self.expect(TokenType.IDENTIFIER).value
|
||||
|
||||
if self.current().type == TokenType.LPAREN:
|
||||
# Function declaration
|
||||
self.advance()
|
||||
params = []
|
||||
|
||||
if self.current().type != TokenType.RPAREN:
|
||||
self.expect(TokenType.INT)
|
||||
params.append(self.expect(TokenType.IDENTIFIER).value)
|
||||
|
||||
while self.current().type == TokenType.COMMA:
|
||||
self.advance()
|
||||
self.expect(TokenType.INT)
|
||||
params.append(self.expect(TokenType.IDENTIFIER).value)
|
||||
|
||||
self.expect(TokenType.RPAREN)
|
||||
body = self.parse_compound_stmt()
|
||||
return FunctionDecl(name, params, body)
|
||||
else:
|
||||
# Variable declaration
|
||||
init = None
|
||||
if self.current().type == TokenType.ASSIGN:
|
||||
self.advance()
|
||||
init = self.parse_expression()
|
||||
self.expect(TokenType.SEMICOLON)
|
||||
return VarDecl(name, init)
|
||||
|
||||
def parse_compound_stmt(self) -> CompoundStmt:
|
||||
self.expect(TokenType.LBRACE)
|
||||
statements = []
|
||||
|
||||
while self.current().type != TokenType.RBRACE:
|
||||
statements.append(self.parse_statement())
|
||||
|
||||
self.expect(TokenType.RBRACE)
|
||||
return CompoundStmt(statements)
|
||||
|
||||
def parse_statement(self) -> Statement:
|
||||
token = self.current()
|
||||
|
||||
if token.type == TokenType.LBRACE:
|
||||
return self.parse_compound_stmt()
|
||||
elif token.type == TokenType.IF:
|
||||
return self.parse_if_stmt()
|
||||
elif token.type == TokenType.WHILE:
|
||||
return self.parse_while_stmt()
|
||||
elif token.type == TokenType.RETURN:
|
||||
return self.parse_return_stmt()
|
||||
elif token.type == TokenType.INT:
|
||||
# Local variable declaration
|
||||
self.advance()
|
||||
name = self.expect(TokenType.IDENTIFIER).value
|
||||
init = None
|
||||
if self.current().type == TokenType.ASSIGN:
|
||||
self.advance()
|
||||
init = self.parse_expression()
|
||||
self.expect(TokenType.SEMICOLON)
|
||||
return ExprStmt(AssignExpr(name, init) if init else None)
|
||||
else:
|
||||
expr = (
|
||||
self.parse_expression()
|
||||
if self.current().type != TokenType.SEMICOLON
|
||||
else None
|
||||
)
|
||||
self.expect(TokenType.SEMICOLON)
|
||||
return ExprStmt(expr)
|
||||
|
||||
def parse_if_stmt(self) -> IfStmt:
|
||||
self.expect(TokenType.IF)
|
||||
self.expect(TokenType.LPAREN)
|
||||
condition = self.parse_expression()
|
||||
self.expect(TokenType.RPAREN)
|
||||
then_stmt = self.parse_statement()
|
||||
|
||||
else_stmt = None
|
||||
if self.current().type == TokenType.ELSE:
|
||||
self.advance()
|
||||
else_stmt = self.parse_statement()
|
||||
|
||||
return IfStmt(condition, then_stmt, else_stmt)
|
||||
|
||||
def parse_while_stmt(self) -> WhileStmt:
|
||||
self.expect(TokenType.WHILE)
|
||||
self.expect(TokenType.LPAREN)
|
||||
condition = self.parse_expression()
|
||||
self.expect(TokenType.RPAREN)
|
||||
body = self.parse_statement()
|
||||
return WhileStmt(condition, body)
|
||||
|
||||
def parse_return_stmt(self) -> ReturnStmt:
|
||||
self.expect(TokenType.RETURN)
|
||||
expr = None
|
||||
if self.current().type != TokenType.SEMICOLON:
|
||||
expr = self.parse_expression()
|
||||
self.expect(TokenType.SEMICOLON)
|
||||
return ReturnStmt(expr)
|
||||
|
||||
def parse_expression(self) -> Expression:
|
||||
return self.parse_assignment()
|
||||
|
||||
def parse_assignment(self) -> Expression:
|
||||
expr = self.parse_comparison()
|
||||
|
||||
if self.current().type == TokenType.ASSIGN:
|
||||
if not isinstance(expr, VarExpr):
|
||||
self.error("Invalid assignment target")
|
||||
self.advance()
|
||||
value = self.parse_assignment()
|
||||
return AssignExpr(expr.name, value)
|
||||
|
||||
return expr
|
||||
|
||||
def parse_comparison(self) -> Expression:
|
||||
expr = self.parse_additive()
|
||||
|
||||
while self.current().type in [
|
||||
TokenType.EQ,
|
||||
TokenType.NE,
|
||||
TokenType.LT,
|
||||
TokenType.GT,
|
||||
TokenType.LE,
|
||||
TokenType.GE,
|
||||
]:
|
||||
op = self.advance().value
|
||||
right = self.parse_additive()
|
||||
expr = BinaryOp(op, expr, right)
|
||||
|
||||
return expr
|
||||
|
||||
def parse_additive(self) -> Expression:
|
||||
expr = self.parse_multiplicative()
|
||||
|
||||
while self.current().type in [TokenType.PLUS, TokenType.MINUS]:
|
||||
op = self.advance().value
|
||||
right = self.parse_multiplicative()
|
||||
expr = BinaryOp(op, expr, right)
|
||||
|
||||
return expr
|
||||
|
||||
def parse_multiplicative(self) -> Expression:
|
||||
expr = self.parse_unary()
|
||||
|
||||
while self.current().type in [TokenType.STAR, TokenType.SLASH]:
|
||||
op = self.advance().value
|
||||
right = self.parse_unary()
|
||||
expr = BinaryOp(op, expr, right)
|
||||
|
||||
return expr
|
||||
|
||||
def parse_unary(self) -> Expression:
|
||||
if self.current().type in [TokenType.PLUS, TokenType.MINUS]:
|
||||
op = self.advance().value
|
||||
operand = self.parse_unary()
|
||||
return UnaryOp(op, operand)
|
||||
|
||||
return self.parse_primary()
|
||||
|
||||
def parse_primary(self) -> Expression:
|
||||
token = self.current()
|
||||
|
||||
if token.type == TokenType.NUMBER:
|
||||
self.advance()
|
||||
return NumberExpr(int(token.value))
|
||||
|
||||
elif token.type == TokenType.IDENTIFIER:
|
||||
name = self.advance().value
|
||||
|
||||
if self.current().type == TokenType.LPAREN:
|
||||
# Function call
|
||||
self.advance()
|
||||
args = []
|
||||
|
||||
if self.current().type != TokenType.RPAREN:
|
||||
args.append(self.parse_expression())
|
||||
while self.current().type == TokenType.COMMA:
|
||||
self.advance()
|
||||
args.append(self.parse_expression())
|
||||
|
||||
self.expect(TokenType.RPAREN)
|
||||
return CallExpr(name, args)
|
||||
else:
|
||||
return VarExpr(name)
|
||||
|
||||
elif token.type == TokenType.LPAREN:
|
||||
self.advance()
|
||||
expr = self.parse_expression()
|
||||
self.expect(TokenType.RPAREN)
|
||||
return expr
|
||||
|
||||
else:
|
||||
self.error(f"Unexpected token: {token.type.value}")
|
||||
|
||||
|
||||
class CodeGenerator:
|
||||
def __init__(self):
|
||||
self.output = []
|
||||
self.label_counter = 0
|
||||
self.string_counter = 0
|
||||
self.functions = {}
|
||||
self.current_function = None
|
||||
self.local_vars = {}
|
||||
self.global_vars = {}
|
||||
self.register_pool = [f"rg{i:x}" for i in range(16)]
|
||||
self.used_registers = set()
|
||||
|
||||
def new_label(self, prefix: str = "L") -> str:
|
||||
label = f"{prefix}{self.label_counter}"
|
||||
self.label_counter += 1
|
||||
return label
|
||||
|
||||
def allocate_register(self) -> str:
|
||||
for reg in self.register_pool:
|
||||
if reg not in self.used_registers:
|
||||
self.used_registers.add(reg)
|
||||
return reg
|
||||
raise RuntimeError("Out of registers")
|
||||
|
||||
def free_register(self, reg: str):
|
||||
self.used_registers.discard(reg)
|
||||
|
||||
def emit(self, code: str):
|
||||
self.output.append(code)
|
||||
|
||||
def generate(self, program: Program) -> str:
|
||||
# Emit data section
|
||||
self.emit("// Global variables")
|
||||
for decl in program.declarations:
|
||||
if isinstance(decl, VarDecl):
|
||||
self.global_vars[decl.name] = f"var_{decl.name}"
|
||||
if decl.init:
|
||||
if isinstance(decl.init, NumberExpr):
|
||||
self.emit(f"dw var_{decl.name}: {decl.init.value}")
|
||||
else:
|
||||
self.emit(f"dw var_{decl.name}: 0")
|
||||
else:
|
||||
self.emit(f"dw var_{decl.name}: 0")
|
||||
|
||||
self.emit("")
|
||||
self.emit("// Entry point")
|
||||
self.emit("dw stack_bottom: 0x10000")
|
||||
self.emit("")
|
||||
self.emit("init:")
|
||||
self.emit(" ldw stack_bottom, spr")
|
||||
self.emit(" mov spr, bpr")
|
||||
|
||||
self.emit(" push zero")
|
||||
self.emit(" call main")
|
||||
self.emit(" pop rg0")
|
||||
self.emit(" hlt")
|
||||
self.emit("")
|
||||
|
||||
# Emit functions
|
||||
for decl in program.declarations:
|
||||
if isinstance(decl, FunctionDecl):
|
||||
self.generate_function(decl)
|
||||
|
||||
return "\n".join(self.output)
|
||||
|
||||
def generate_function(self, func: FunctionDecl):
|
||||
self.current_function = func.name
|
||||
self.functions[func.name] = func
|
||||
self.local_vars = {}
|
||||
|
||||
# Map parameters to stack offsets
|
||||
# Parameters start at bpr+8 (after return addr at bpr+4)
|
||||
for i, param in enumerate(func.params):
|
||||
self.local_vars[param] = 8 + (i * 4)
|
||||
|
||||
self.emit(f"{func.name}:")
|
||||
self.emit(" push bpr")
|
||||
self.emit(" mov spr, bpr")
|
||||
self.emit("")
|
||||
|
||||
# Generate function body
|
||||
self.generate_compound_stmt(func.body)
|
||||
|
||||
# Default return if no explicit return
|
||||
self.emit("// default return")
|
||||
self.emit(f"{func.name}_end:")
|
||||
self.emit(" mov bpr, spr")
|
||||
self.emit(" pop bpr")
|
||||
self.emit(" return")
|
||||
self.emit("")
|
||||
|
||||
def generate_compound_stmt(self, stmt: CompoundStmt):
|
||||
for s in stmt.statements:
|
||||
self.generate_statement(s)
|
||||
|
||||
def generate_statement(self, stmt: Statement):
|
||||
if isinstance(stmt, CompoundStmt):
|
||||
self.generate_compound_stmt(stmt)
|
||||
elif isinstance(stmt, ExprStmt):
|
||||
if stmt.expr:
|
||||
reg = self.generate_expression(stmt.expr)
|
||||
self.free_register(reg)
|
||||
elif isinstance(stmt, IfStmt):
|
||||
self.generate_if_stmt(stmt)
|
||||
elif isinstance(stmt, WhileStmt):
|
||||
self.generate_while_stmt(stmt)
|
||||
elif isinstance(stmt, ReturnStmt):
|
||||
self.generate_return_stmt(stmt)
|
||||
|
||||
def generate_if_stmt(self, stmt: IfStmt):
|
||||
else_label = self.new_label("else")
|
||||
end_label = self.new_label("endif")
|
||||
|
||||
# Evaluate condition
|
||||
cond_reg = self.generate_expression(stmt.condition)
|
||||
self.emit(f" cmp {cond_reg}, zero")
|
||||
self.free_register(cond_reg)
|
||||
|
||||
if stmt.else_stmt:
|
||||
self.emit(f" jeq {else_label}")
|
||||
else:
|
||||
self.emit(f" jeq {end_label}")
|
||||
|
||||
# Then branch
|
||||
self.generate_statement(stmt.then_stmt)
|
||||
|
||||
if stmt.else_stmt:
|
||||
self.emit(f" jmp {end_label}")
|
||||
self.emit(f"{else_label}:")
|
||||
self.generate_statement(stmt.else_stmt)
|
||||
|
||||
self.emit(f"{end_label}:")
|
||||
|
||||
def generate_while_stmt(self, stmt: WhileStmt):
|
||||
start_label = self.new_label("while_start")
|
||||
end_label = self.new_label("while_end")
|
||||
|
||||
self.emit(f"{start_label}:")
|
||||
|
||||
# Evaluate condition
|
||||
cond_reg = self.generate_expression(stmt.condition)
|
||||
self.emit(f" cmp {cond_reg}, zero")
|
||||
self.free_register(cond_reg)
|
||||
self.emit(f" jeq {end_label}")
|
||||
|
||||
# Loop body
|
||||
self.generate_statement(stmt.body)
|
||||
self.emit(f" jmp {start_label}")
|
||||
|
||||
self.emit(f"{end_label}:")
|
||||
|
||||
def generate_return_stmt(self, stmt: ReturnStmt):
|
||||
if stmt.expr:
|
||||
reg = self.generate_expression(stmt.expr)
|
||||
# Store return value at spr+8 according to calling convention
|
||||
self.emit(f" stw {reg}, spr, 8")
|
||||
self.free_register(reg)
|
||||
self.emit(f" jmp {self.current_function}_end")
|
||||
|
||||
def generate_expression(self, expr: Expression) -> str:
|
||||
if isinstance(expr, NumberExpr):
|
||||
reg = self.allocate_register()
|
||||
if expr.value <= 0xFFFF and expr.value >= 0:
|
||||
self.emit(f" lli {expr.value}, {reg}")
|
||||
if expr.value > 0xFF:
|
||||
self.emit(f" lui {expr.value >> 16}, {reg}")
|
||||
else:
|
||||
self.emit(f" lli {expr.value & 0xFFFF}, {reg}")
|
||||
self.emit(f" lui {(expr.value >> 16) & 0xFFFF}, {reg}")
|
||||
return reg
|
||||
|
||||
elif isinstance(expr, VarExpr):
|
||||
reg = self.allocate_register()
|
||||
if expr.name in self.local_vars:
|
||||
offset = self.local_vars[expr.name]
|
||||
self.emit(f" ldw bpr, {reg}, {offset}")
|
||||
elif expr.name in self.global_vars:
|
||||
label = self.global_vars[expr.name]
|
||||
self.emit(f" ldw {label}, {reg}")
|
||||
else:
|
||||
raise RuntimeError(f"Undefined variable: {expr.name}")
|
||||
return reg
|
||||
|
||||
elif isinstance(expr, AssignExpr):
|
||||
value_reg = self.generate_expression(expr.value)
|
||||
|
||||
if expr.name in self.local_vars:
|
||||
offset = self.local_vars[expr.name]
|
||||
self.emit(f" stw {value_reg}, bpr, {offset}")
|
||||
elif expr.name in self.global_vars:
|
||||
label = self.global_vars[expr.name]
|
||||
self.emit(f" stw {value_reg}, {label}")
|
||||
else:
|
||||
# New local variable - allocate after params and return value space
|
||||
# Start local variables at offset -4 from bpr (growing downward)
|
||||
offset = -(len([v for v in self.local_vars.values() if v < 0]) + 1) * 4
|
||||
self.local_vars[expr.name] = offset
|
||||
self.emit(f" stw {value_reg}, bpr, {offset}")
|
||||
|
||||
return value_reg
|
||||
|
||||
elif isinstance(expr, BinaryOp):
|
||||
return self.generate_binary_op(expr)
|
||||
|
||||
elif isinstance(expr, UnaryOp):
|
||||
operand_reg = self.generate_expression(expr.operand)
|
||||
result_reg = self.allocate_register()
|
||||
|
||||
if expr.op == "-":
|
||||
self.emit(f" lwi 0, {result_reg}")
|
||||
self.emit(f" sub {result_reg}, {operand_reg}, {result_reg}")
|
||||
else: # +
|
||||
self.emit(f" mov {operand_reg}, {result_reg}")
|
||||
|
||||
self.free_register(operand_reg)
|
||||
return result_reg
|
||||
|
||||
elif isinstance(expr, CallExpr):
|
||||
# First, make space for return value (must be pushed BEFORE arguments)
|
||||
temp_reg = self.allocate_register()
|
||||
|
||||
# Then push arguments in reverse order
|
||||
arg_regs = []
|
||||
for arg in reversed(expr.args):
|
||||
reg = self.generate_expression(arg)
|
||||
self.emit(f" push {reg}")
|
||||
arg_regs.append(reg)
|
||||
|
||||
# Call function
|
||||
self.emit(f" call {expr.name}")
|
||||
|
||||
# Get return value (it's now on top of stack)
|
||||
self.emit(f" pop {temp_reg}")
|
||||
|
||||
# Clean up remaining args
|
||||
for i in range(len(arg_regs) - 1):
|
||||
self.emit(f" pop zero")
|
||||
|
||||
# Free the arg registers
|
||||
for reg in arg_regs:
|
||||
self.free_register(reg)
|
||||
|
||||
return temp_reg
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown expression type: {type(expr)}")
|
||||
|
||||
def generate_binary_op(self, expr: BinaryOp) -> str:
|
||||
# For operations that might contain function calls, we need to be careful
|
||||
# about register allocation. Evaluate left, save it, evaluate right.
|
||||
left_reg = self.generate_expression(expr.left)
|
||||
|
||||
# If right side contains a function call, we need to save left_reg
|
||||
# For now, always save to be safe
|
||||
saved_reg = self.allocate_register()
|
||||
self.emit(f" mov {left_reg}, {saved_reg}")
|
||||
self.free_register(left_reg)
|
||||
|
||||
right_reg = self.generate_expression(expr.right)
|
||||
result_reg = self.allocate_register()
|
||||
|
||||
if expr.op == "+":
|
||||
self.emit(f" add {left_reg}, {right_reg}, {result_reg}")
|
||||
elif expr.op == "-":
|
||||
self.emit(f" sub {left_reg}, {right_reg}, {result_reg}")
|
||||
elif expr.op == "*":
|
||||
# Simple multiplication using loop
|
||||
temp_label = self.new_label("mult")
|
||||
end_label = self.new_label("mult_end")
|
||||
self.emit(f" lli 0, {result_reg}")
|
||||
self.emit(f"{temp_label}:")
|
||||
self.emit(f" cmp {right_reg}, zero")
|
||||
self.emit(f" jeq {end_label}")
|
||||
self.emit(f" add {result_reg}, {left_reg}, {result_reg}")
|
||||
self.emit(f" dec {right_reg}")
|
||||
self.emit(f" jmp {temp_label}")
|
||||
self.emit(f"{end_label}:")
|
||||
elif expr.op == "/":
|
||||
# Simple division using loop
|
||||
temp_label = self.new_label("div")
|
||||
end_label = self.new_label("div_end")
|
||||
self.emit(f" lli 0, {result_reg}")
|
||||
self.emit(f"{temp_label}:")
|
||||
self.emit(f" cmp {left_reg}, {right_reg}")
|
||||
self.emit(f" jlt {end_label}")
|
||||
self.emit(f" sub {left_reg}, {right_reg}, {left_reg}")
|
||||
self.emit(f" inc {result_reg}")
|
||||
self.emit(f" jmp {temp_label}")
|
||||
self.emit(f"{end_label}:")
|
||||
elif expr.op in ["==", "!=", "<", ">", "<=", ">="]:
|
||||
self.emit(f" cmp {left_reg}, {right_reg}")
|
||||
|
||||
# Result is 1 if condition true, 0 otherwise
|
||||
self.emit(f" lli 0, {result_reg}")
|
||||
true_label = self.new_label("cmp_true")
|
||||
end_label = self.new_label("cmp_end")
|
||||
|
||||
if expr.op == "==":
|
||||
self.emit(f" jeq {true_label}")
|
||||
elif expr.op == "!=":
|
||||
self.emit(f" jne {true_label}")
|
||||
elif expr.op == "<":
|
||||
self.emit(f" jlt {true_label}")
|
||||
elif expr.op == ">":
|
||||
self.emit(f" jgt {true_label}")
|
||||
elif expr.op == "<=":
|
||||
self.emit(f" jle {true_label}")
|
||||
elif expr.op == ">=":
|
||||
self.emit(f" jge {true_label}")
|
||||
|
||||
self.emit(f" jmp {end_label}")
|
||||
self.emit(f"{true_label}:")
|
||||
self.emit(f" lli 1, {result_reg}")
|
||||
self.emit(f"{end_label}:")
|
||||
|
||||
self.free_register(left_reg)
|
||||
self.free_register(right_reg)
|
||||
return result_reg
|
||||
|
||||
|
||||
def compile_c_to_asm(source: str) -> str:
|
||||
"""Compile C source code to DSA assembly."""
|
||||
lexer = Lexer(source)
|
||||
tokens = lexer.tokenize()
|
||||
|
||||
parser = Parser(tokens)
|
||||
ast = parser.parse()
|
||||
|
||||
codegen = CodeGenerator()
|
||||
assembly = codegen.generate(ast)
|
||||
|
||||
return assembly
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python compiler.py <input.c> [output.dsa]")
|
||||
sys.exit(1)
|
||||
|
||||
input_file = sys.argv[1]
|
||||
output_file = sys.argv[2] if len(sys.argv) > 2 else input_file.replace(".c", ".dsa")
|
||||
|
||||
with open(input_file, "r") as f:
|
||||
source = f.read()
|
||||
|
||||
try:
|
||||
assembly = compile_c_to_asm(source)
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
f.write(assembly)
|
||||
|
||||
print(f"Successfully compiled {input_file} to {output_file}")
|
||||
except (SyntaxError, RuntimeError) as e:
|
||||
print(f"Compilation error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# # Example usage
|
||||
# if len(sys.argv) > 1:
|
||||
# example_c = sys.argv[1]
|
||||
|
||||
# else:
|
||||
# example_c = """
|
||||
# int factorial(int n) {
|
||||
# if (n <= 1) {
|
||||
# return 1;
|
||||
# }
|
||||
# return n * factorial(n - 1);
|
||||
# }
|
||||
|
||||
# int main() {
|
||||
# int result;
|
||||
# result = factorial(5);
|
||||
# return result;
|
||||
# }
|
||||
# """
|
||||
|
||||
# print("Example C program:")
|
||||
# print(example_c)
|
||||
# print("\n" + "="*60 + "\n")
|
||||
# print("Generated DSA assembly:")
|
||||
# print(compile_c_to_asm(example_c))
|
||||
Reference in New Issue
Block a user