-- | Expression tree for a small reverse-mode automatic differentiation demo.
data Op
  = Constant Double  -- ^ Fixed value; 'backward' reports no gradient for it.
  | Var Double       -- ^ Leaf holding a value; 'backward' reports its gradient.
  | ReLU Op          -- ^ @max 0 x@.
  | Add Op Op        -- ^ Sum of two sub-expressions.
  | Multiply Op Op   -- ^ Product of two sub-expressions.
  | Negate Op        -- ^ Arithmetic negation.
  deriving (Eq, Show)

-- | Lets expressions be written with ordinary arithmetic syntax.
--
-- Integer literals become 'Var' nodes, so they show up in the result of
-- 'backward'; wrap a value in 'Constant' explicitly to exclude it.
instance Num Op where
  (+) = Add
  (*) = Multiply
  negate = Negate
  -- |x| = relu x + relu (-x): expressible with the existing constructors, so
  -- 'evaluate' and 'backward' need no extra cases.
  abs x = Add (ReLU x) (ReLU (Negate x))
  -- The sign is piecewise constant, so it is frozen as a 'Constant' of the
  -- current value and contributes no gradient entries.
  signum x = Constant (signum (evaluate x))
  fromInteger n = Var (fromInteger n)

-- | Compute the numeric value of an expression.
evaluate :: Op -> Double
evaluate (Constant x) = x
evaluate (Var x) = x
evaluate (Negate x) = negate (evaluate x)
evaluate (ReLU x) = max 0 (evaluate x)
evaluate (Add x y) = evaluate x + evaluate y
evaluate (Multiply x y) = evaluate x * evaluate y

-- | Propagate an upstream gradient through an expression.
--
-- Returns one @(variable, gradient)@ pair per 'Var' leaf reached, in
-- left-to-right order. Contributions are NOT summed: a variable that occurs
-- several times in the tree yields several entries.
backward
  :: Op      -- ^ expression to differentiate
  -> Double  -- ^ gradient flowing in from the parent node
  -> [(Op, Double)]
backward (Constant _) _ = []
backward op@(Var _) grad = [(op, grad)]
backward (Negate x) grad = backward x (negate grad)
backward (ReLU x) grad
  -- Gradient passes through only where the input is strictly positive.
  | evaluate x > 0 = backward x grad
  | otherwise = backward x 0
backward (Add x y) grad = backward x grad ++ backward y grad
backward (Multiply x y) grad =
  -- Product rule: each factor receives the upstream gradient scaled by the
  -- value of the other factor.
  backward x (grad * evaluate y) ++ backward y (grad * evaluate x)

-- | Build @a * b + 3@ and print @a@ followed by the gradients of the result.
main :: IO ()
main = do
  -- Expression-level annotations keep this valid Haskell 2010 (no pattern
  -- type signatures, hence no ScopedTypeVariables requirement).
  let a = -4 :: Op
      b = 2 :: Op
      c = a * b + Constant 3
  print a
  print (backward c 1.0)