Expressions
Next, let us consider another example of a recursively defined data structure: symbolic arithmetic expressions.
Arithmetic expressions in their symbolic form are used by compilers (including the Scala compiler) internally to manipulate and optimize user code. Similarly, mathematical software (such as Wolfram Mathematica) works with symbolic expressions for simplification and other algebraic manipulation, such as taking the derivative of an expression.
More generally, manipulation of mathematical expressions in their symbolic form is called symbolic computation (as opposed to numerical computation).
Let us start by giving a recursive definition of simple mathematical expressions in symbolic form:
A constant (such as \(3.0\)) is an expression.
A variable (such as \(x\)) is an expression.
If \(e_1\) and \(e_2\) are expressions, then \(-e_1\), \((e_1 + e_2)\), \((e_1 - e_2)\), and \((e_1 * e_2)\) are also expressions.
There are no other expressions.
By the above rules, the following are examples of expressions: \(2.0\), \(x\), \((2.0 * x)\), \(y\), \(((2.0 * x)+y)\), and \(-((2.0 * x)+y)\).
Based on the definition, we can build the corresponding Scala classes.
Again, we define an immutable data structure, and now we will use case classes right from the beginning.
(This source code is provided in the Expressions programming assignment, where your task is to extend the classes with new operations.)
/**
 * The abstract base class for expressions.
 * Expressions are immutable.
 */
abstract class Expr {
  /* Defining these operators enables us to construct expressions with infix notation */
  def +(other: Expr) = Add(this, other)
  def -(other: Expr) = Subtract(this, other)
  def *(other: Expr) = Multiply(this, other)
  def unary_- = Negate(this)
}
/** Variable expression, like "x" or "y". */
case class Var(name: String) extends Expr {
  override def toString = name
}
/** Constant expression like "2.1" or "-1.0" */
case class Num(value: Double) extends Expr {
  override def toString = value.toString
}
/** Expression formed by multiplying two expressions, like "x * (y+3)" */
case class Multiply(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + "*" + right + ")"
}
/** Expression formed by adding two expressions, like "x + (y*3)" */
case class Add(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + " + " + right + ")"
}
/** Expression formed by subtracting an expression from another, like "x - (y*3)" */
case class Subtract(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + " - " + right + ")"
}
/** Negation of an expression, like "-((3*y)+z)" */
case class Negate(p: Expr) extends Expr {
  override def toString = "-" + p
}
We can now construct the expression \(-((2.0 * x)+y)\) with:
val e = Negate(Add(Multiply(Num(2.0),Var("x")),Var("y")))
Or, by using the operators defined in Expr, with:
val e = -((Num(2.0) * Var("x")) + Var("y"))
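In both cases, the result is the same tree of objects in memory: a Negate object referring to an Add object, whose left operand is a Multiply object (with operands Num(2.0) and Var("x")) and whose right operand is Var("y").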
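The claim that both notations build the same structure can be checked directly, because case classes compare by structure rather than by reference. The following is a minimal sketch under assumptions: the wrapper object `ConstructionSketch` and the value names are hypothetical, and the classes are trimmed copies of the ones above (no `toString` overrides, no subtraction).

```scala
object ConstructionSketch {
  // Trimmed copies of the classes above, just enough for this check.
  abstract class Expr {
    def +(other: Expr): Expr = Add(this, other)
    def *(other: Expr): Expr = Multiply(this, other)
    def unary_- : Expr = Negate(this)
  }
  case class Var(name: String) extends Expr
  case class Num(value: Double) extends Expr
  case class Multiply(left: Expr, right: Expr) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Negate(p: Expr) extends Expr

  // The same expression built with constructors and with infix operators.
  val viaConstructors: Expr = Negate(Add(Multiply(Num(2.0), Var("x")), Var("y")))
  val viaOperators: Expr = -((Num(2.0) * Var("x")) + Var("y"))

  // Case classes get a structural `equals`, so equal trees compare as equal.
  val sameTree: Boolean = viaConstructors == viaOperators
}
```

Here `ConstructionSketch.sameTree` is `true`: the operators are only a more convenient way of calling the same constructors.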
The file package.scala in the source code of the Expressions assignment includes the following implicit conversions:
implicit def doubleToNum(v: Double) = Num(v)
implicit def intToNum(v: Int) = Num(v)
These conversions allow us to write the above expression even more compactly:
val e = -(2.0 * Var("x") + Var("y"))
You can read more on implicit conversions and parameters in Programming in Scala, 1st edition.
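To see what the compiler does with such a conversion, here is a small self-contained sketch. It is not the assignment's package.scala: the wrapper object `ImplicitSketch` is hypothetical, the classes are trimmed copies of the ones above, and the explicit result type and the `scala.language.implicitConversions` import are additions that newer compilers ask for.

```scala
import scala.language.implicitConversions

object ImplicitSketch {
  // Trimmed copies of the classes above.
  abstract class Expr {
    def +(other: Expr): Expr = Add(this, other)
    def *(other: Expr): Expr = Multiply(this, other)
    def unary_- : Expr = Negate(this)
  }
  case class Var(name: String) extends Expr
  case class Num(value: Double) extends Expr
  case class Multiply(left: Expr, right: Expr) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Negate(p: Expr) extends Expr

  // Same idea as doubleToNum in the text, with an explicit result type.
  implicit def doubleToNum(v: Double): Num = Num(v)

  // Double has no `*` method that accepts an Expr, so the compiler
  // rewrites `2.0 * Var("x")` into `doubleToNum(2.0) * Var("x")`.
  val compact: Expr = -(2.0 * Var("x") + Var("y"))

  // The fully explicit form, for comparison.
  val explicitForm: Expr = Negate(Add(Multiply(Num(2.0), Var("x")), Var("y")))
}
```

Both values denote the same tree, so `ImplicitSketch.compact == ImplicitSketch.explicitForm` holds.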
Evaluation
We now have the basic data structure for expressions. Let us next implement a method that allows us to evaluate the value of an expression when the variables in the expression are given some values.
Again, the evaluation proceeds recursively: to evaluate the value of an expression, we must first evaluate the values of its sub-expressions, unless the expression is a constant or a variable.
/**
 * The exception thrown by "evaluate" when
 * a variable in the expression is not assigned a value.
 */
class VariableNotAssignedException(message: String) extends java.lang.RuntimeException(message)
/**
 * Evaluate the expression at the point p, where p is a map
 * associating each variable name in the expression with a Double value.
 */
def evaluate(p: Map[String, Double]): Double = this match {
  case Var(n) => p.get(n) match {
    case Some(v) => v
    case None => throw new VariableNotAssignedException(n)
  }
  case Num(v) => v
  case Multiply(l, r) => l.evaluate(p) * r.evaluate(p)
  case Add(l, r) => l.evaluate(p) + r.evaluate(p)
  case Subtract(l, r) => l.evaluate(p) - r.evaluate(p)
  case Negate(t) => -t.evaluate(p)
}
Observe that we have also defined our own exception class, allowing the method evaluate to signal the erroneous situation where not all variables have been assigned a value.
We can now test the method:
scala> val e = -((Num(2.0) * Var("x")) + Var("y"))
e: expressions.Negate = -((2.0*x) + y)
scala> e.evaluate(Map("x" -> 2.0, "y" -> -3.5))
res: Double = -0.5
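The error case can be exercised in the same way. The sketch below is an illustration under assumptions: the wrapper object `EvaluateSketch` and the value names are hypothetical, the classes are trimmed copies (subtraction is left out), `getOrElse` replaces the nested match on `p.get(n)` with the same behaviour, and `scala.util.Try` is used only to capture the exception as a value.

```scala
import scala.util.Try

object EvaluateSketch {
  class VariableNotAssignedException(message: String) extends RuntimeException(message)

  abstract class Expr {
    def evaluate(p: Map[String, Double]): Double = this match {
      // getOrElse evaluates its second argument only when the key is missing.
      case Var(n)         => p.getOrElse(n, throw new VariableNotAssignedException(n))
      case Num(v)         => v
      case Multiply(l, r) => l.evaluate(p) * r.evaluate(p)
      case Add(l, r)      => l.evaluate(p) + r.evaluate(p)
      case Negate(t)      => -t.evaluate(p)
    }
  }
  case class Var(name: String) extends Expr
  case class Num(value: Double) extends Expr
  case class Multiply(left: Expr, right: Expr) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Negate(p: Expr) extends Expr

  val e: Expr = Negate(Add(Multiply(Num(2.0), Var("x")), Var("y")))

  // All variables assigned: -((2.0 * 2.0) + (-3.5)) = -0.5
  val complete: Try[Double] = Try(e.evaluate(Map("x" -> 2.0, "y" -> -3.5)))

  // "y" is missing, so evaluate throws and Try records the failure.
  val partial: Try[Double] = Try(e.evaluate(Map("x" -> 2.0)))
}
```

In this sketch `complete` is a `Success` holding `-0.5`, while `partial` is a `Failure` whose exception message is the name of the unassigned variable, `"y"`.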
Simplification
When manipulating expressions, it is of course desirable to have them in as simple a form as possible. For example, instead of
\(y * ((1.0 * x) + 0.0)\)
we naturally prefer using the much simpler
\(y * x\)
We can get this simpler form by first applying the simplification rule
\(1.0 * e = e\)
and then
\(e + 0.0 = e\)
The process of transforming expressions into a simpler (but equivalent!) form is, unsurprisingly, called simplification. We next show how to do some elementary simplification recursively by using case classes and pattern matching. This is where case classes really start to show their power. So far, we could have done without them, but now pattern matching enables us to write very concise and readable code:
/**
 * Simplifies the expression with some simple rules like "1.0*r equals r".
 */
def simplify: Expr = {
  // First, (recursively) simplify sub-expressions
  val subresult = this match {
    case Multiply(l, r) => Multiply(l.simplify, r.simplify)
    case Add(l, r) => Add(l.simplify, r.simplify)
    case Subtract(l, r) => Subtract(l.simplify, r.simplify)
    case Negate(t) => Negate(t.simplify)
    case _ => this // Handles Var and Num
  }
  // Then simplify this sub-expression by applying some simple rules
  subresult match {
    case Multiply(Num(1.0), r) => r
    case Multiply(l, Num(1.0)) => l
    case Multiply(Num(0.0), _) => Num(0.0)
    case Multiply(_, Num(0.0)) => Num(0.0)
    case Multiply(Num(v1), Num(v2)) => Num(v1 * v2)
    case Add(Num(0.0), r) => r
    case Add(l, Num(0.0)) => l
    case Add(Num(v1), Num(v2)) => Num(v1 + v2)
    case Subtract(Num(0.0), r) => Negate(r)
    case Subtract(l, Num(0.0)) => l
    case Subtract(Num(v1), Num(v2)) => Num(v1 - v2)
    case Negate(Num(v)) => Num(-v)
    case _ => subresult // Could not simplify
  }
}
Observe the two phases in the function: the sub-expressions of an expression are (recursively) simplified first and only after that is the expression itself considered. For instance, in our example the sub-expression \(1.0 * x\) is simplified to \(x\) and then the new version \(x + 0.0\) of the sub-expression \((1.0 * x) + 0.0\) is simplified into \(x\).
Now we get, as expected:
scala> val e = Var("y")* ((Num(1.0) * Var("x")) + Num(0.0))
e: expressions.Multiply = (y*((1.0*x) + 0.0))
scala> e.simplify
res: expressions.Expr = (y*x)
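Because each node is rewritten by at most one rule per call, a single call to simplify does not always reach the simplest form that the rules allow. For example, \(0.0 - 3.0\) matches the rule for Subtract(Num(0.0), r) first and becomes Negate(Num(3.0)); only a second pass turns that into Num(-3.0). One possible remedy, sketched here as an assumption and not as part of the assignment code, is to repeat simplify until the expression stops changing. The names `SimplifySketch`, `simplifyFully` and `maxPasses` are hypothetical.

```scala
import scala.annotation.tailrec

object SimplifySketch {
  abstract class Expr {
    // Same two-phase simplification as in the text.
    def simplify: Expr = {
      val subresult = this match {
        case Multiply(l, r) => Multiply(l.simplify, r.simplify)
        case Add(l, r)      => Add(l.simplify, r.simplify)
        case Subtract(l, r) => Subtract(l.simplify, r.simplify)
        case Negate(t)      => Negate(t.simplify)
        case _              => this
      }
      subresult match {
        case Multiply(Num(1.0), r)          => r
        case Multiply(l, Num(1.0))          => l
        case Multiply(Num(0.0), _)          => Num(0.0)
        case Multiply(_, Num(0.0))          => Num(0.0)
        case Multiply(Num(v1), Num(v2))     => Num(v1 * v2)
        case Add(Num(0.0), r)               => r
        case Add(l, Num(0.0))               => l
        case Add(Num(v1), Num(v2))          => Num(v1 + v2)
        case Subtract(Num(0.0), r)          => Negate(r)
        case Subtract(l, Num(0.0))          => l
        case Subtract(Num(v1), Num(v2))     => Num(v1 - v2)
        case Negate(Num(v))                 => Num(-v)
        case _                              => subresult
      }
    }
  }
  case class Var(name: String) extends Expr
  case class Num(value: Double) extends Expr
  case class Multiply(left: Expr, right: Expr) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Subtract(left: Expr, right: Expr) extends Expr
  case class Negate(p: Expr) extends Expr

  // Hypothetical helper: repeat simplify until nothing changes,
  // or until maxPasses passes have been made (a safety bound).
  @tailrec
  def simplifyFully(e: Expr, maxPasses: Int = 100): Expr = {
    val next = e.simplify
    if (maxPasses <= 1 || next == e) next
    else simplifyFully(next, maxPasses - 1)
  }

  val e: Expr = Subtract(Num(0.0), Num(3.0))
  val onePass: Expr = e.simplify       // Negate(Num(3.0))
  val fixedPoint: Expr = simplifyFully(e) // Num(-3.0)
}
```

The comparison `next == e` again relies on the structural equality of case classes. An alternative design would be to reorder or extend the rules so that one pass suffices for this case; the loop is simply the smallest change to the code shown above.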
Alternative (bad) approaches
As a final note here, let us observe that although case classes are very convenient, they are only “syntactic sugar” in the sense that the same things can be done without them.
If we wanted to write the simplification function in the abstract base class Expr without using case classes, we could use the isInstanceOf and asInstanceOf methods of the scala.Any class.
However, this gets very ugly:
abstract class Expr {
  def simplify: Expr = {
    ...
    if(subresult.isInstanceOf[Multiply] &&
       subresult.asInstanceOf[Multiply].left.isInstanceOf[Num] &&
       subresult.asInstanceOf[Multiply].left.asInstanceOf[Num].value == 1.0)
      subresult.asInstanceOf[Multiply].right
    ...
  }
}
Another, almost equally bad approach would be to implement type inspection and field-getting methods in all the sub-classes:
abstract class Expr {
  def simplify: Expr = {
    ...
    if(subresult.isMultiply && subresult.getLeft.isNum && subresult.getLeft.getValue == 1.0)
      subresult.getRight
    ...
  }
}
case class Multiply(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + "*" + right + ")"
  def isMultiply = true
  def isNum = false
  ...
  def getLeft = left
  def getRight = right
  def getValue = throw new java.util.NoSuchElementException("value of Multiply")
}
...
Compared with these alternatives, it is clear that case classes and pattern matching result in far more readable and maintainable code.
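To make the comparison concrete, the following sketch implements the single rule "1.0*r equals r" twice, once with type tests and casts and once with pattern matching. It is only an illustration: the wrapper object `ComparisonSketch` and the function names `dropUnitFactorWithCasts` and `dropUnitFactorWithMatch` are hypothetical, and the classes are trimmed copies of the ones above.

```scala
object ComparisonSketch {
  abstract class Expr
  case class Var(name: String) extends Expr
  case class Num(value: Double) extends Expr
  case class Multiply(left: Expr, right: Expr) extends Expr

  // The rule written with isInstanceOf and asInstanceOf.
  def dropUnitFactorWithCasts(e: Expr): Expr =
    if (e.isInstanceOf[Multiply] &&
        e.asInstanceOf[Multiply].left.isInstanceOf[Num] &&
        e.asInstanceOf[Multiply].left.asInstanceOf[Num].value == 1.0)
      e.asInstanceOf[Multiply].right
    else
      e

  // The same rule written with pattern matching on case classes.
  def dropUnitFactorWithMatch(e: Expr): Expr = e match {
    case Multiply(Num(1.0), r) => r
    case _                     => e
  }
}
```

Both functions return `Var("x")` for `Multiply(Num(1.0), Var("x"))` and leave `Multiply(Num(2.0), Var("x"))` unchanged, but only the second one states the shape being matched in a single line.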