From 259746558f256fea80fb2652272ebf8af219e9a2 Mon Sep 17 00:00:00 2001 From: zxq5 Date: Thu, 29 Jan 2026 19:29:48 +0000 Subject: [PATCH] codegen progress --- .gitignore | 4 +- c_compiler/code.c | 11 +- c_compiler/src/codegen.rs | 465 ++++++++++++++++++++++++++++++++---- c_compiler/src/main.rs | 1 + c_compiler/src/parser.rs | 30 +-- c_compiler/src/registers.rs | 324 +++++++++++++++++++++++++ 6 files changed, 767 insertions(+), 68 deletions(-) create mode 100644 c_compiler/src/registers.rs diff --git a/.gitignore b/.gitignore index a851572..b05f24c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target -**/*.env \ No newline at end of file +**/*.env +Cargo.lock +*Cargo.lock diff --git a/c_compiler/code.c b/c_compiler/code.c index 5840a39..d069fa2 100644 --- a/c_compiler/code.c +++ b/c_compiler/code.c @@ -1,6 +1,4 @@ -int x = 5; - -int add(int a, int b) { return a + b; } +int var_x = 5; int factorial(int n) { if (n <= 1) { @@ -10,12 +8,7 @@ int factorial(int n) { } int main() { - int x; - x = 5; - int x = 5; - int result; - int result = 5; - result = x + factorial(5); + int result = var_x + factorial(5); print(result); return 0; } diff --git a/c_compiler/src/codegen.rs b/c_compiler/src/codegen.rs index 201ad29..308fe43 100644 --- a/c_compiler/src/codegen.rs +++ b/c_compiler/src/codegen.rs @@ -1,17 +1,25 @@ +use std::hash::Hash; +use std::sync::atomic::AtomicU32; use std::time::SystemTime; use std::{collections::HashMap, path::PathBuf}; use chrono::{DateTime, Local}; +use crate::registers::RegisterAllocator; use crate::{block, cmd, comment, dsa}; -use crate::parser::{ConstExpr, Declaration, Program}; +use crate::parser::{ + BinaryOperator, ConstExpr, Declaration, Expression, Parameter, Program, Statement, + UnaryOperator, +}; pub struct CodeGenerator { ast: Program, imports: HashMap, globals: Vec, functions: Vec, + allocator: RegisterAllocator, + call_stack: Vec, } fn import(name: &str, path: &str) -> String { @@ -25,6 +33,8 @@ impl CodeGenerator { imports: HashMap::new(), globals: Vec::new(), functions: Vec::new(), + allocator: RegisterAllocator::new(), + call_stack: Vec::new(), } } @@ -37,44 +47,16 @@ impl CodeGenerator { self.include("print", "./lib/io/print.dsa"); for block in self.ast.clone().declarations { - self.generate_block(block.clone()); + self.generate_block(block.clone())?; + } + + for func in &self.functions { + println!("{func}"); } self.generate_layout() } - fn generate_block(&mut self, block: Declaration) { - match block { - Declaration::Variable { name, init } => self.globals.push(format!( - "dw {}: {}", - name, - init.unwrap_or(ConstExpr::Number(0)) - )), - Declaration::Function { - name, - return_type, - params, - body, - } => { - let function_start = format!( - "{name}: \n\t\ - push bpr \n\t\ - mov spr, bpr" - ); - - let function_end = format!( - "\n\t\ - mov bpr, spr \n\t\ - pop bpr \n\t\ - return\n" - ); - - self.functions - .push(format!("{function_start}\n{function_end}")); - } - } - } - fn generate_layout(&mut self) -> Result { let datetime: DateTime = SystemTime::now().into(); Ok(dsa![ @@ -114,19 +96,416 @@ impl CodeGenerator { dsa![pop zero], dsa![hlt] ], - block! [ "main" - dsa![push bpr], - dsa![mov spr, bpr], - dsa![lwi 67, rg1], - dsa![stw rg1, spr, 8], - dsa![mov bpr, spr], - dsa![pop bpr], - dsa![return] - ], + // block! [ "main" + // dsa![push bpr], + // dsa![mov spr, bpr], + // dsa![lwi 67, rg1], + // dsa![stw rg1, spr, 8], + // dsa![mov bpr, spr], + // dsa![pop bpr], + // dsa![return] + // ], "", self.functions.join("\n"), ]) } + + fn generate_global(&mut self, name: &str, init: Option) { + self.globals.push(format!( + "dw {}: {}", + name, + init.unwrap_or(ConstExpr::Number(0)) + )) + } + + fn generate_block(&mut self, block: Declaration) -> Result<(), String> { + match block { + Declaration::Variable { name, init } => self.generate_global(&name, init), + Declaration::Function { + name, + return_type, + params, + body, + } => { + let func = self.generate_function(&name, ¶ms, &body).join("\n"); + + self.functions.push(format!("{func}\n")); + } + }; + + Ok(()) + } + + // Example: Generate code for a function + fn generate_function( + &mut self, + name: &str, + params: &[Parameter], + body: &[Statement], + ) -> Vec { + self.call_stack.push(name.to_string()); + + let mut code = Vec::new(); + + // Reset allocator for new function + self.allocator.reset(); + + // Function prologue + code.push(format!("{}:", name)); + code.push("\tpush bpr".to_string()); + code.push("\tmov spr, bpr".to_string()); + + // Allocate parameters to registers or stack locations + for (i, param) in params.iter().enumerate() { + let offset = 8 + (i as i32 * 4); // Parameters start at bpr+8 + // Track that this parameter is at a stack location + let (reg, mut load_code) = self.allocator.alloc_var(¶m.name).unwrap(); + code.extend(load_code); + code.push(format!("\tldw bpr, {}, {}", reg, offset)); + } + + // Generate code for function body + for stmt in body { + let stmt_code = self.generate_statement(stmt).unwrap(); + code.extend(stmt_code); + } + + // Function epilogue + code.push(format!("_ret_{name}:")); + code.push("\tmov bpr, spr".to_string()); + code.push("\tpop bpr".to_string()); + code.push("\treturn".to_string()); + + self.call_stack.pop(); + code + } + + // Example: Generate code for a statement + fn generate_statement(&mut self, stmt: &Statement) -> Result, String> { + let mut code = Vec::new(); + + match stmt { + Statement::Assign { + name, + declare_type, + value, + } => { + if let Some(expr) = value { + // Evaluate expression + let (result_reg, expr_code) = self.generate_expression(expr)?; + code.extend(expr_code); + + // Store result in variable + let store_code = self.allocator.store_var(name, &result_reg); + code.extend(store_code); + + // Free temporary register + self.allocator.free_temp(&result_reg); + } else { + // Just declaring variable without initialization + self.allocator.alloc_var(name)?; + } + } + + Statement::Return { expr } => { + if let Some(e) = expr { + let (result_reg, expr_code) = self.generate_expression(e)?; + code.extend(expr_code); + code.push(format!("\tstw {}, bpr, 8", result_reg)); + code.push(format!("\tjmp _ret_{}", self.call_stack.last().unwrap())); + self.allocator.free_temp(&result_reg); + } + } + + Statement::If { + condition, + then_stmt, + else_stmt, + } => { + // Generate condition + let (cond_reg, cond_code) = self.generate_expression(condition)?; + code.extend(cond_code); + + // Compare with zero + code.push(format!("\tcmp {}, zero", cond_reg)); + self.allocator.free_temp(&cond_reg); + + // Generate unique labels + let then_label = format!("_then_{}", self.get_unique_label()); + let else_label = format!("_else_{}", self.get_unique_label()); + let end_label = format!("_end_{}", self.get_unique_label()); + + // Jump to else if condition is false (equal to zero) + code.push(format!("\tjeq {}", else_label)); + + // Then block + code.push(format!("{}:", then_label)); + for s in then_stmt { + code.extend(self.generate_statement(s)?); + } + + if then_stmt.len() == 0 { + code.push("\tnop".to_string()); + } + + code.push(format!("\tjmp {}", end_label)); + + // Else block + code.push(format!("{}:", else_label)); + for s in else_stmt { + code.extend(self.generate_statement(s)?); + } + + if else_stmt.len() == 0 { + code.push("\tnop".to_string()); + } + + code.push(format!("{}:", end_label)); + } + + Statement::While { condition, body } => { + let loop_start = format!("_while_start_{}", self.get_unique_label()); + let loop_end = format!("_while_end_{}", self.get_unique_label()); + + code.push(format!("{}:", loop_start)); + + // Generate condition + let (cond_reg, cond_code) = self.generate_expression(condition)?; + code.extend(cond_code); + + code.push(format!("\tcmp {}, zero", cond_reg)); + self.allocator.free_temp(&cond_reg); + + code.push(format!("\tjeq {}", loop_end)); + + // Loop body + for s in body { + code.extend(self.generate_statement(s)?); + } + + code.push(format!("\tjmp {}", loop_start)); + code.push(format!("{}:", loop_end)); + } + + Statement::Expression { expr } => { + let (result_reg, expr_code) = self.generate_expression(expr)?; + code.extend(expr_code); + self.allocator.free_temp(&result_reg); + } + + Statement::Block(statements) => { + for s in statements { + code.extend(self.generate_statement(s)?); + } + } + } + + Ok(code) + } + + // Example: Generate code for an expression + // Returns (register containing result, assembly code) + fn generate_expression( + &mut self, + expr: &Expression, + ) -> Result<(String, Vec), String> { + let mut code = Vec::new(); + + match expr { + Expression::Number { value } => { + let (reg, alloc_code) = self.allocator.alloc_temp()?; + code.extend(alloc_code); + + // Load immediate value + code.push(format!("\tlli {}, {}", value & 0xFFFF, reg)); + if *value > 0xFFFF || *value < 0 { + code.push(format!("\tlui {}, {}", (value >> 16) & 0xFFFF, reg)); + } + + Ok((reg, code)) + } + + Expression::Variable { name, .. } => { + let (reg, load_code) = self.allocator.load_var(name)?; + code.extend(load_code); + Ok((reg, code)) + } + + Expression::Binary { op, left, right } => { + // Evaluate left operand + let (left_reg, left_code) = self.generate_expression(left)?; + code.extend(left_code); + + // Evaluate right operand + let (right_reg, right_code) = self.generate_expression(right)?; + code.extend(right_code); + + // Allocate result register + let (result_reg, result_alloc) = self.allocator.alloc_temp()?; + code.extend(result_alloc); + + // Generate operation + match op { + BinaryOperator::Add => { + code.push(format!( + "\tadd {}, {}, {}", + left_reg, right_reg, result_reg + )); + } + BinaryOperator::Sub => { + code.push(format!( + "\tsub {}, {}, {}", + left_reg, right_reg, result_reg + )); + } + BinaryOperator::Mul => { + self.include("maths", "./lib/maths/core.dsa"); + // Call multiply function + code.push(format!("\tpush {}", right_reg)); + code.push(format!("\tpush {}", left_reg)); + code.push("\tcall maths::multiply".to_string()); + code.push(format!("\tpop {}", result_reg)); + code.push("\tpop zero".to_string()); + } + // Comparison operators - return 1 (true) or 0 (false) + BinaryOperator::Eq => { + code.push(format!("\tcmp {}, {}", left_reg, right_reg)); + code.push(format!("\tlli 0, {}", result_reg)); + let end_label = format!("_cmp_end_{}", self.get_unique_label()); + code.push(format!("\tjne {}", end_label)); // If not equal, skip setting to 1 + code.push(format!("\tlli 1, {}", result_reg)); + code.push(format!("{}:", end_label)); + } + BinaryOperator::Ne => { + code.push(format!("\tcmp {}, {}", left_reg, right_reg)); + code.push(format!("\tlli 0, {}", result_reg)); + let end_label = format!("_cmp_end_{}", self.get_unique_label()); + code.push(format!("\tjeq {}", end_label)); // If equal, skip setting to 1 + code.push(format!("\tlli 1, {}", result_reg)); + code.push(format!("{}:", end_label)); + } + BinaryOperator::Lt => { + code.push(format!("\tcmp {}, {}", left_reg, right_reg)); + code.push(format!("\tlli 0, {}", result_reg)); + let end_label = format!("_cmp_end_{}", self.get_unique_label()); + code.push(format!("\tjge {}", end_label)); // If greater or equal, skip setting to 1 + code.push(format!("\tlli 1, {}", result_reg)); + code.push(format!("{}:", end_label)); + } + BinaryOperator::Le => { + code.push(format!("\tcmp {}, {}", left_reg, right_reg)); + code.push(format!("\tlli 0, {}", result_reg)); + let end_label = format!("_cmp_end_{}", self.get_unique_label()); + code.push(format!("\tjgt {}", end_label)); // If greater than, skip setting to 1 + code.push(format!("\tlli 1, {}", result_reg)); + code.push(format!("{}:", end_label)); + } + BinaryOperator::Gt => { + code.push(format!("\tcmp {}, {}", left_reg, right_reg)); + code.push(format!("\tlli 0, {}", result_reg)); + let end_label = format!("_cmp_end_{}", self.get_unique_label()); + code.push(format!("\tjle {}", end_label)); // If less or equal, skip setting to 1 + code.push(format!("\tlli 1, {}", result_reg)); + code.push(format!("{}:", end_label)); + } + BinaryOperator::Ge => { + code.push(format!("\tcmp {}, {}", left_reg, right_reg)); + code.push(format!("\tlli 0, {}", result_reg)); + let end_label = format!("_cmp_end_{}", self.get_unique_label()); + code.push(format!("\tjlt {}", end_label)); // If less than, skip setting to 1 + code.push(format!("\tlli 1, {}", result_reg)); + code.push(format!("{}:", end_label)); + } + _ => return Err(format!("Unsupported binary operator: {:?}", op)), + } + + // Free operand registers + self.allocator.free_temp(&left_reg); + self.allocator.free_temp(&right_reg); + + Ok((result_reg, code)) + } + + Expression::Call { name, args } => { + // Save caller-saved registers + let save_code = self.allocator.save_caller_saved(); + code.extend(save_code); + + // Evaluate and push arguments in reverse order + let mut arg_regs = Vec::new(); + for arg in args.iter().rev() { + let (arg_reg, arg_code) = self.generate_expression(arg)?; + code.extend(arg_code); + code.push(format!("\tpush {}", arg_reg)); + arg_regs.push(arg_reg); + } + + if self.functions.contains_key(name) { + // Call local function + code.push(format!("\tcall {}", name)); + } + + if self.imports + + + + // Clean up arguments + for _ in 0..args.len() { + code.push("\tpop zero".to_string()); + } + + // Free argument registers + for reg in arg_regs { + self.allocator.free_temp(®); + } + + // Result is in rg0, allocate a register and move it + let (result_reg, result_alloc) = self.allocator.alloc_temp()?; + code.extend(result_alloc); + + if result_reg != "rg0" { + code.push(format!("\tmov rg0, {}", result_reg)); + } + + // Restore caller-saved registers (simplified - you'd track which ones) + + Ok((result_reg, code)) + } + + Expression::Unary { op, operand } => { + let (operand_reg, operand_code) = self.generate_expression(operand)?; + code.extend(operand_code); + + let (result_reg, result_alloc) = self.allocator.alloc_temp()?; + code.extend(result_alloc); + + match op { + UnaryOperator::Minus => { + // Negate: result = 0 - operand + code.push(format!("\tsub zero, {}, {}", operand_reg, result_reg)); + } + UnaryOperator::Plus => { + // Just move + code.push(format!("\tmov {}, {}", operand_reg, result_reg)); + } + } + + self.allocator.free_temp(&operand_reg); + Ok((result_reg, code)) + } + + Expression::Empty => Ok(("zero".to_string(), code)), + } + } + + // Helper for generating unique labels + fn get_unique_label(&mut self) -> String { + // You'd implement a counter here + static COUNTER: AtomicU32 = AtomicU32::new(0); + + let val = COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + (val + 1).to_string() + } } /// Build a single string from any number of arguments. diff --git a/c_compiler/src/main.rs b/c_compiler/src/main.rs index 812d1ba..06758c8 100644 --- a/c_compiler/src/main.rs +++ b/c_compiler/src/main.rs @@ -5,6 +5,7 @@ use crate::{codegen::CodeGenerator, lexer::Lexer, parser::Parser}; pub mod codegen; pub mod lexer; pub mod parser; +mod registers; // ============================================================================ // Main & Tests diff --git a/c_compiler/src/parser.rs b/c_compiler/src/parser.rs index b89aaa1..734cb04 100644 --- a/c_compiler/src/parser.rs +++ b/c_compiler/src/parser.rs @@ -17,7 +17,7 @@ pub enum Declaration { name: String, return_type: Type, params: Vec, - body: Statement, + body: Block, }, Variable { name: String, @@ -44,11 +44,11 @@ pub enum Type { Struct(String), } +pub type Block = Vec; + #[derive(Debug, Clone)] pub enum Statement { - Compound { - statements: Vec, - }, + Block(Block), Assign { // left side name: String, @@ -62,12 +62,12 @@ pub enum Statement { }, If { condition: Expression, - then_stmt: Box, - else_stmt: Option>, + then_stmt: Block, + else_stmt: Block, }, While { condition: Expression, - body: Box, + body: Vec, }, Return { expr: Option, @@ -271,7 +271,7 @@ impl Parser { } self.expect(TokenType::RParen)?; - let body = self.parse_compound_stmt()?; + let body = self.parse_block()?; Ok(Declaration::Function { name, @@ -302,7 +302,7 @@ impl Parser { } } - fn parse_compound_stmt(&mut self) -> Result { + fn parse_block(&mut self) -> Result { self.expect(TokenType::LBrace)?; let mut statements = Vec::new(); @@ -311,12 +311,12 @@ impl Parser { } self.expect(TokenType::RBrace)?; - Ok(Statement::Compound { statements }) + Ok(statements) } fn parse_statement(&mut self) -> Result { match &self.current().token_type { - TokenType::LBrace => Ok(self.parse_compound_stmt()?), + TokenType::LBrace => Ok(Statement::Block(self.parse_block()?)), TokenType::If => self.parse_if_stmt(), TokenType::While => self.parse_while_stmt(), TokenType::Return => self.parse_return_stmt(), @@ -408,13 +408,13 @@ impl Parser { self.expect(TokenType::LParen)?; let condition = self.parse_expression()?; self.expect(TokenType::RParen)?; - let then_stmt = Box::new(self.parse_statement()?); + let then_stmt = self.parse_block()?; let else_stmt = if matches!(self.current().token_type, TokenType::Else) { self.advance(); - Some(Box::new(self.parse_statement()?)) + self.parse_block()? } else { - None + Vec::new() }; Ok(Statement::If { @@ -429,7 +429,7 @@ impl Parser { self.expect(TokenType::LParen)?; let condition = self.parse_expression()?; self.expect(TokenType::RParen)?; - let body = Box::new(self.parse_statement()?); + let body = self.parse_block()?; Ok(Statement::While { condition, body }) } diff --git a/c_compiler/src/registers.rs b/c_compiler/src/registers.rs new file mode 100644 index 0000000..d13babd --- /dev/null +++ b/c_compiler/src/registers.rs @@ -0,0 +1,324 @@ +use std::collections::HashMap; + +/// Register allocator for DSA assembly generation +/// Manages general-purpose registers (rg0-rgf) and handles stack spilling +pub struct RegisterAllocator { + /// Available general-purpose registers + available_registers: Vec, + + /// Maps variable names to their current location (register or stack offset) + variable_locations: HashMap, + + /// Maps registers to the variables they currently hold + register_contents: HashMap, + + /// Current stack offset for local variables (relative to bpr) + /// Starts at -4 (going downward from base pointer) + stack_offset: i32, + + /// Track which registers are currently in use + in_use: HashMap, +} + +#[derive(Debug, Clone)] +pub enum Location { + Register(String), + Stack(i32), // offset from bpr +} + +impl RegisterAllocator { + pub fn new() -> Self { + // Initialize with available GP registers (rg0-rgf = 16 registers) + let registers = vec![ + "rg0", "rg1", "rg2", "rg3", "rg4", "rg5", "rg6", "rg7", "rg8", "rg9", "rga", + "rgb", "rgc", "rgd", "rge", "rgf", + ] + .into_iter() + .map(String::from) + .collect(); + + RegisterAllocator { + available_registers: registers, + variable_locations: HashMap::new(), + register_contents: HashMap::new(), + stack_offset: -4, // Start at -4 (first local below saved bpr) + in_use: HashMap::new(), + } + } + + /// Allocate a temporary register for expression evaluation + /// Returns the register name and optionally assembly code to save it + pub fn alloc_temp(&mut self) -> Result<(String, Vec), String> { + let mut code = Vec::new(); + + // Try to find an unused register + for reg in &self.available_registers { + if !self.in_use.get(reg).unwrap_or(&false) { + self.in_use.insert(reg.clone(), true); + return Ok((reg.clone(), code)); + } + } + + // All registers in use - need to spill one + // Choose the first register with a variable we can spill + // Find a register to spill + let reg_to_spill = self + .available_registers + .iter() + .find(|reg| self.register_contents.contains_key(*reg)) + .cloned(); + + if let Some(reg) = reg_to_spill { + // Spill this variable to stack + let spill_code = self.spill_register(®)?; + code.extend(spill_code); + + self.in_use.insert(reg.clone(), true); + return Ok((reg, code)); + } + + Err("No registers available and nothing to spill".to_string()) + } + + /// Free a temporary register after use + pub fn free_temp(&mut self, reg: &str) { + self.in_use.insert(reg.to_string(), false); + } + + /// Allocate a register for a named variable + /// Returns the register and any necessary assembly code + pub fn alloc_var(&mut self, var_name: &str) -> Result<(String, Vec), String> { + // Check if variable already has a location + if let Some(location) = self.variable_locations.get(var_name).cloned() { + match location { + Location::Register(reg) => { + return Ok((reg.clone(), Vec::new())); + } + Location::Stack(offset) => { + // Variable is on stack, load it into a register + let (reg, mut code) = self.alloc_temp()?; + code.push(format!("\tldw bpr, {}, {}", reg, offset)); + + // Update location to register + self.variable_locations + .insert(var_name.to_string(), Location::Register(reg.clone())); + self.register_contents + .insert(reg.clone(), var_name.to_string()); + + return Ok((reg, code)); + } + } + } + + // Variable doesn't have a location yet, allocate a new register + let (reg, code) = self.alloc_temp()?; + self.variable_locations + .insert(var_name.to_string(), Location::Register(reg.clone())); + self.register_contents + .insert(reg.clone(), var_name.to_string()); + + Ok((reg, code)) + } + + /// Get the current location of a variable + pub fn get_var_location(&self, var_name: &str) -> Option<&Location> { + self.variable_locations.get(var_name) + } + + /// Load a variable into a register (allocating if necessary) + /// Returns the register and assembly code to load it + pub fn load_var(&mut self, var_name: &str) -> Result<(String, Vec), String> { + self.alloc_var(var_name) + } + + /// Store a value from a register into a variable + /// Updates tracking and returns any necessary assembly code + pub fn store_var(&mut self, var_name: &str, source_reg: &str) -> Vec { + let mut code = Vec::new(); + + // Check if variable already has a location + if let Some(location) = self.variable_locations.get(var_name) { + match location { + Location::Register(dest_reg) => { + if dest_reg != source_reg { + code.push(format!("\tmov {}, {}", source_reg, dest_reg)); + } + } + Location::Stack(offset) => { + code.push(format!("\tstw {}, bpr, {}", source_reg, offset)); + } + } + } else { + // Variable doesn't exist yet - try to allocate a register + if let Some(free_reg) = self.find_free_register() { + if &free_reg != source_reg { + code.push(format!("\tmov {}, {}", source_reg, free_reg)); + } + self.variable_locations + .insert(var_name.to_string(), Location::Register(free_reg.clone())); + self.register_contents + .insert(free_reg.clone(), var_name.to_string()); + self.in_use.insert(free_reg, true); + } else { + // No free registers - allocate on stack + code.push(format!("\tstw {}, bpr, {}", source_reg, self.stack_offset)); + self.variable_locations + .insert(var_name.to_string(), Location::Stack(self.stack_offset)); + self.stack_offset -= 4; // Move to next stack slot + } + } + + code + } + + /// Spill a register to the stack + /// Returns assembly code to perform the spill + fn spill_register(&mut self, reg: &str) -> Result, String> { + let mut code = Vec::new(); + + if let Some(var_name) = self.register_contents.get(reg).cloned() { + // Store register content to stack + code.push(format!("\tstw {}, bpr, {}", reg, self.stack_offset)); + + // Update variable location + self.variable_locations + .insert(var_name.clone(), Location::Stack(self.stack_offset)); + + // Remove from register tracking + self.register_contents.remove(reg); + + // Move to next stack slot + self.stack_offset -= 4; + } + + Ok(code) + } + + /// Find a free register (not currently in use) + fn find_free_register(&self) -> Option { + for reg in &self.available_registers { + if !self.in_use.get(reg).unwrap_or(&false) { + return Some(reg.clone()); + } + } + None + } + + /// Spill all registers to stack (useful before function calls) + pub fn spill_all(&mut self) -> Vec { + let mut code = Vec::new(); + + let regs_to_spill: Vec = self.register_contents.keys().cloned().collect(); + + for reg in regs_to_spill { + if let Ok(spill_code) = self.spill_register(®) { + code.extend(spill_code); + } + } + + code + } + + /// Get the total stack space needed for local variables + pub fn get_stack_size(&self) -> i32 { + -self.stack_offset // Convert negative offset to positive size + } + + /// Reset allocator for a new function + pub fn reset(&mut self) { + self.variable_locations.clear(); + self.register_contents.clear(); + self.stack_offset = -4; + self.in_use.clear(); + } + + /// Mark a variable as dead (no longer needed) + /// Frees its register if it's in one + pub fn free_var(&mut self, var_name: &str) { + if let Some(Location::Register(reg)) = self.variable_locations.get(var_name) { + let reg = reg.clone(); + self.register_contents.remove(®); + self.in_use.insert(reg, false); + } + self.variable_locations.remove(var_name); + } + + /// Save caller-saved registers before a function call + /// Returns assembly code to save them + pub fn save_caller_saved(&mut self) -> Vec { + let mut code = Vec::new(); + + // For simplicity, save all currently used registers + // In a more sophisticated compiler, you'd only save registers that are live + for (reg, var_name) in self.register_contents.clone() { + if *self.in_use.get(®).unwrap_or(&false) { + code.push(format!("\tpush {}", reg)); + } + } + + code + } + + /// Restore caller-saved registers after a function call + /// Returns assembly code to restore them + pub fn restore_caller_saved(&mut self, saved_regs: &[String]) -> Vec { + let mut code = Vec::new(); + + // Restore in reverse order (LIFO) + for reg in saved_regs.iter().rev() { + code.push(format!("\tpop {}", reg)); + } + + code + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_allocation() { + let mut allocator = RegisterAllocator::new(); + + let (reg1, code1) = allocator.alloc_temp().unwrap(); + assert_eq!(code1.len(), 0); // No spill needed + assert_eq!(reg1, "rg0"); + + let (reg2, code2) = allocator.alloc_temp().unwrap(); + assert_eq!(code2.len(), 0); + assert_eq!(reg2, "rg1"); + + allocator.free_temp(®1); + + let (reg3, code3) = allocator.alloc_temp().unwrap(); + assert_eq!(code3.len(), 0); + assert_eq!(reg3, "rg0"); // Reuses freed register + } + + #[test] + fn test_variable_allocation() { + let mut allocator = RegisterAllocator::new(); + + let (reg, _) = allocator.alloc_var("x").unwrap(); + assert_eq!(reg, "rg0"); + + // Requesting same variable again should return same register + let (reg2, _) = allocator.alloc_var("x").unwrap(); + assert_eq!(reg2, "rg0"); + } + + #[test] + fn test_stack_allocation() { + let mut allocator = RegisterAllocator::new(); + + // Allocate all 16 registers + for i in 0..16 { + allocator.alloc_var(&format!("var{}", i)).unwrap(); + } + + // Next allocation should spill to stack + let (reg, code) = allocator.alloc_var("var16").unwrap(); + assert!(code.len() > 0); // Should have spill code + } +}