use std::collections::HashMap; use crate::model::{ BinaryOperator, // You'll need to add this to your imports CompilerError, Declaration, Dependency, Expression, Program, TypeId, UnaryOperator, }; pub struct Analyser { symbol_table: HashMap, } const NUMERIC_TYPES: &[TypeId] = &[ TypeId::U32, TypeId::I32, TypeId::I16, TypeId::U16, TypeId::I8, TypeId::U8, ]; impl Analyser { pub fn new() -> Self { Self { symbol_table: HashMap::new(), } } pub fn analyse(&mut self, ast: Program) -> Result<(), CompilerError> { // build table of global symbols. for dec in ast.declarations { let name = match dec.clone() { Declaration::Function { name, .. } => name, Declaration::Variable { var, .. } => var.name, Declaration::Dependency(Dependency { name, .. }) => name, }; self.symbol_table.insert(name, dec); } Ok(()) } fn match_type( actual: TypeId, expected: Option, ) -> Result { match expected { Some(id) => { if id != actual { Err(CompilerError::TypeMismatch(id, actual)) } else { Ok(actual) } } None => Ok(actual), } } fn get_type( &mut self, // Changed from &self to &mut self since we modify expr expr: &mut Expression, expected_type: Option, ) -> Result { match expr { // Correct IFF we're expecting a void type Expression::Empty => Self::match_type(TypeId::Void, expected_type), // Correct IFF we're expecting a char type Expression::CharLiteral(_) => Self::match_type(TypeId::Char, expected_type), // Correct IFF we're expecting a string slice type Expression::StringLiteral(_) => { Self::match_type(TypeId::Ptr(Box::new(TypeId::Char)), expected_type) } Expression::Variable { name, expr_type } => { let actual = expr_type.clone().ok_or(CompilerError::UnknownType)?; Self::match_type(actual, expected_type) } Expression::Number { value, type_id } => { // If we already know the TypeId if let Some(id) = type_id { return Self::match_type(id.clone(), expected_type); } // If we're expecting a type id, check it's numeric. // TODO: add checks to make sure it's valid for its size eg u8 cant be // more than 255 if let Some(expected) = expected_type { if NUMERIC_TYPES.contains(&expected) { *type_id = Some(expected.clone()); return Ok(expected); } else { return Err(CompilerError::TypeMismatch(expected, TypeId::U32)); } } // Default to i32 if no type information is available *type_id = Some(TypeId::I32); Ok(TypeId::I32) } Expression::Binary { op, left, right, type_id, } => { // For binary operations, both operands should have compatible types // and the result type depends on the operation let left_type = self.get_type(left, None)?; let right_type = self.get_type(right, Some(left_type.clone()))?; // For numeric operations, result has the same type as operands if NUMERIC_TYPES.contains(&left_type) && NUMERIC_TYPES.contains(&right_type) { *type_id = Some(left_type); Self::match_type(left_type, expected_type) } else { Err(CompilerError::TypeMismatch(left_type, right_type)) } } Expression::Unary { op, operand, type_id, } => { match op { UnaryOperator::Plus | UnaryOperator::Minus => { // Unary +/- require numeric operands let inner_type = self.get_type(operand, None)?; if NUMERIC_TYPES.contains(&inner_type) { *type_id = Some(inner_type.clone()); Self::match_type(inner_type, expected_type) } else { Err(CompilerError::TypeMismatch(inner_type, TypeId::I32)) } } UnaryOperator::Dereference => { // For dereference (*ptr), the operand must be a pointer // and the result type is what the pointer points to let inner_type = self.get_type(operand, None)?; match inner_type { TypeId::Ptr(inner) => { let deref_type = *inner; *type_id = Some(deref_type.clone()); Self::match_type(deref_type, expected_type) } _ => Err(CompilerError::Generic(format!( "Cannot dereference non-pointer type: {:?}", inner_type ))), } } UnaryOperator::Reference => { // For reference (&var), we need to determine what we're taking // a reference to, then wrap it in a Ptr // If expected_type is Ptr(T), then operand should have type T let expected_inner = match expected_type.clone() { Some(TypeId::Ptr(inner)) => Some(*inner), _ => None, }; let inner_type = self.get_type(operand, expected_inner)?; let ref_type = TypeId::Ptr(Box::new(inner_type)); *type_id = Some(ref_type.clone()); Self::match_type(ref_type, expected_type) } } } Expression::Call { name, args, type_id, } => match self.symbol_table.get(&name.name) { Some(Declaration::Function { params, return_type, .. }) => { // check that we've given the right number of arguments. if args.len() != params.len() { return Err(CompilerError::Generic(format!( "Function {} expected {} arguments but received {}", name.name, params.len(), args.len() ))); } for (arg, param) in args.iter_mut().zip(params.iter()) { // check that the argument type matches the parameter type. let provided_type = self.get_type(arg, Some(param.type_id))?; if provided_type != param.type_id { return Err(CompilerError::TypeMismatch( param.type_id, provided_type, )); } } *type_id = Some(return_type.clone()); Self::match_type(return_type.clone(), expected_type) } _ => Err(CompilerError::Generic(format!( "Function {} not found in symbol table", name.name ))), }, } } }