Expressions

Next, let us consider another example of a recursively defined data structure: symbolic arithmetic expressions.

Arithmetic expressions in their symbolic form are used by compilers (including the Scala compiler) internally to manipulate and optimize user code. Similarly, mathematical software (such as Wolfram Mathematica) works with symbolic expressions for simplification and other algebraic manipulation, such as taking the derivative of an expression.

More generally, manipulation of mathematical expressions in their symbolic form is called symbolic computation (as opposed to numerical computation).

Let us start by giving a recursive definition of simple mathematical expressions in symbolic form:

  1. A constant (such as \(3.0\)) is an expression.

  2. A variable (such as \(x\)) is an expression.

  3. If \(e_1\) and \(e_2\) are expressions, then \(-e_1\), \((e_1 + e_2)\), \((e_1 - e_2)\), and \((e_1 * e_2)\) are also expressions.

  4. There are no other expressions.

By the above rules, the following are examples of expressions: \(2.0\), \(x\), \((2.0 * x)\), \(y\), \(((2.0 * x)+y)\), and \(-((2.0 * x)+y)\).
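The four rules map directly onto a class hierarchy, with one case class per kind of expression (rule 3 contributes four of them). As a minimal sketch, using a `sealed` trait (an assumption here, since the assignment's own code uses a plain abstract class), the keyword `sealed` captures rule 4, "there are no other expressions": all subclasses must be defined in the same file, so the compiler can warn when a `match` misses a case. The hypothetical `depth` function illustrates this, and the values build the examples above bottom-up, exactly as the rules allow.

```scala
// Sketch: the recursive definition as a sealed hierarchy.
// `sealed` encodes rule 4: no other subclasses can be added outside this file.
sealed trait Expr
case class Num(value: Double) extends Expr                 // rule 1: constants
case class Var(name: String) extends Expr                  // rule 2: variables
case class Negate(p: Expr) extends Expr                    // rule 3: -e1
case class Add(left: Expr, right: Expr) extends Expr       // rule 3: (e1 + e2)
case class Subtract(left: Expr, right: Expr) extends Expr  // rule 3: (e1 - e2)
case class Multiply(left: Expr, right: Expr) extends Expr  // rule 3: (e1 * e2)

object Examples {
  // Each example is built only from smaller, already valid expressions.
  val two  = Num(2.0)
  val x    = Var("x")
  val twoX = Multiply(two, x)
  val y    = Var("y")
  val sum  = Add(twoX, y)
  val neg  = Negate(sum)
  val all: List[Expr] = List(two, x, twoX, y, sum, neg)

  // Hypothetical helper: nesting depth of an expression.
  // Because Expr is sealed, the compiler warns if one of the cases is missing.
  def depth(e: Expr): Int = e match {
    case Num(_) | Var(_) => 1
    case Negate(t)       => 1 + depth(t)
    case Add(l, r)       => 1 + math.max(depth(l), depth(r))
    case Subtract(l, r)  => 1 + math.max(depth(l), depth(r))
    case Multiply(l, r)  => 1 + math.max(depth(l), depth(r))
  }
}
```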

Based on this definition, we can build the corresponding Scala classes. Again, we define an immutable data structure, and this time we use case classes right from the beginning. (This source code is provided in the Expressions programming assignment, where your task is to extend the classes with new operations.)

/**
 * The abstract base class for expressions.
 * Expressions are immutable.
 */
abstract class Expr {
  /* Defining these operators enables us to construct expressions with infix notation */
  def +(other: Expr) = Add(this, other)
  def -(other: Expr) = Subtract(this, other)
  def *(other: Expr) = Multiply(this, other)
  def unary_- = Negate(this)
}
/** Variable expression, like "x" or "y". */
case class Var(name: String) extends Expr {
  override def toString = name
}
/** Constant expression like "2.1" or "-1.0" */
case class Num(value: Double) extends Expr {
  override def toString = value.toString
}
/** Expression formed by multiplying two expressions, like "x * (y+3)" */
case class Multiply(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + "*" + right + ")"
}
/** Expression formed by adding two expressions, like "x + (y*3)" */
case class Add(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + " + " + right + ")"
}
/** Expression formed by subtracting an expression from another, like "x - (y*3)" */
case class Subtract(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + " - " + right + ")"
}
/** Negation of an expression, like "-((3*y)+z)" */
case class Negate(p: Expr) extends Expr {
  override def toString = "-" + p
}

We can now construct the expression \(-((2.0 * x)+y)\) with:

val e = Negate(Add(Multiply(Num(2.0),Var("x")),Var("y")))

Or, using the operator methods defined in Expr:

val e = -((Num(2.0) * Var("x")) + Var("y"))

In both cases, the objects in memory look like this:

[Figure: objects in memory for the expression e]
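The same object structure can also be printed as text. The sketch below uses a hypothetical helper `showTree` (not part of the assignment code) that prints one node per line, indenting each child two spaces deeper than its parent; the classes are repeated in minimal form so that the example compiles on its own.

```scala
sealed abstract class Expr
case class Var(name: String) extends Expr
case class Num(value: Double) extends Expr
case class Multiply(left: Expr, right: Expr) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Subtract(left: Expr, right: Expr) extends Expr
case class Negate(p: Expr) extends Expr

object TreeView {
  // Renders e as one node per line; children are indented two spaces deeper.
  def showTree(e: Expr, indent: String = ""): String = {
    // Builds this node's line followed by the rendered subtrees of its children.
    def node(label: String, children: Expr*): String =
      (indent + label :: children.toList.map(c => showTree(c, indent + "  "))).mkString("\n")
    e match {
      case Var(n)         => node(s"Var($n)")
      case Num(v)         => node(s"Num($v)")
      case Multiply(l, r) => node("Multiply", l, r)
      case Add(l, r)      => node("Add", l, r)
      case Subtract(l, r) => node("Subtract", l, r)
      case Negate(t)      => node("Negate", t)
    }
  }

  // For the expression e from the text, this yields:
  // Negate
  //   Add
  //     Multiply
  //       Num(2.0)
  //       Var(x)
  //     Var(y)
  val example: String = showTree(Negate(Add(Multiply(Num(2.0), Var("x")), Var("y"))))
}
```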

The file package.scala in the source code of the Expressions assignment includes the following implicit conversions:

implicit def doubleToNum(v: Double): Num = Num(v)  // lets a Double be used where an Expr is expected
implicit def intToNum(v: Int): Num = Num(v)        // the Int is widened to Double

These conversions allow us to write the above expression even more compactly:

val e = -(2.0 * Var("x") + Var("y"))
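What the compiler does here is insert a conversion: `Double` has no `*` method that accepts an `Expr`, so `2.0 * Var("x")` is rewritten to `doubleToNum(2.0) * Var("x")`. The sketch below, assuming a trimmed-down version of the classes and an object `Conversions` standing in for the assignment's package.scala, checks that the compact form builds exactly the same tree as the explicit one. The `import scala.language.implicitConversions` line silences the feature warning that defining an implicit conversion otherwise triggers.

```scala
import scala.language.implicitConversions

sealed abstract class Expr {
  def +(other: Expr): Expr = Add(this, other)
  def *(other: Expr): Expr = Multiply(this, other)
  def unary_- : Expr = Negate(this)
}
case class Var(name: String) extends Expr
case class Num(value: Double) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Multiply(left: Expr, right: Expr) extends Expr
case class Negate(p: Expr) extends Expr

// Stand-in for the conversions in package.scala (explicit result types added).
object Conversions {
  implicit def doubleToNum(v: Double): Num = Num(v)
  implicit def intToNum(v: Int): Num = Num(v)
}

object ImplicitDemo {
  import Conversions._
  val explicitForm: Expr = Negate(Add(Multiply(Num(2.0), Var("x")), Var("y")))
  // 2.0 has no `*` accepting an Expr, so the compiler applies doubleToNum(2.0).
  val compactForm: Expr = -(2.0 * Var("x") + Var("y"))
  // The same conversion, written out by hand.
  val desugared: Expr = -(doubleToNum(2.0) * Var("x") + Var("y"))
  // An Int literal works as well, via intToNum.
  val fromInt: Expr = 3 * Var("z")
}
```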

You can read more on implicit conversions and parameters in Programming in Scala, 1st edition.

Evaluation

We now have the basic data structure for expressions. Next, let us implement a method that computes the value of an expression, given values for the variables that appear in it.

Again, evaluation proceeds recursively: to evaluate an expression, we must first evaluate its sub-expressions, unless the expression is a constant or a variable.

  /**
   * The exception thrown by "evaluate" when
   * a variable in the expression is not assigned a value.
   */
  class VariableNotAssignedException(message: String) extends java.lang.RuntimeException(message)

  /**
   * Evaluates the expression at the point p, where p is a map
   * associating each variable name in the expression with a Double value.
   */
  def evaluate(p: Map[String, Double]): Double = this match {
    case Var(n) => p.get(n) match {
      case Some(v) => v
      case None => throw new VariableNotAssignedException(n)
    }
    case Num(v) => v
    case Multiply(l, r) => l.evaluate(p) * r.evaluate(p)
    case Add(l, r) => l.evaluate(p) + r.evaluate(p)
    case Subtract(l, r) => l.evaluate(p) - r.evaluate(p)
    case Negate(t) => -t.evaluate(p)
  }

Observe that we have also defined our own exception class, allowing the method evaluate to signal the erroneous situation where not all variables have been assigned a value.
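Since the exception is an ordinary `RuntimeException` subclass, a caller that cannot guarantee that every variable is bound can catch it. The sketch below is hypothetical: the helper `tryEvaluate` wraps `evaluate` in a `try`/`catch` and returns `None` when a variable is missing, and the class hierarchy is reduced to the cases needed here.

```scala
class VariableNotAssignedException(message: String) extends RuntimeException(message)

sealed abstract class Expr {
  // Same recursive evaluation as in the text, for a reduced set of cases.
  def evaluate(p: Map[String, Double]): Double = this match {
    case Var(n) => p.get(n) match {
      case Some(v) => v
      case None    => throw new VariableNotAssignedException(n)
    }
    case Num(v)         => v
    case Add(l, r)      => l.evaluate(p) + r.evaluate(p)
    case Multiply(l, r) => l.evaluate(p) * r.evaluate(p)
  }
}
case class Var(name: String) extends Expr
case class Num(value: Double) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Multiply(left: Expr, right: Expr) extends Expr

object SafeEval {
  // Hypothetical helper: turns a missing variable into None instead of an exception.
  def tryEvaluate(e: Expr, p: Map[String, Double]): Option[Double] =
    try Some(e.evaluate(p))
    catch { case _: VariableNotAssignedException => None }
}
```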

We can now test the method:

scala> val e = -((Num(2.0) * Var("x")) + Var("y"))
e: expressions.Negate = -((2.0*x) + y)

scala> e.evaluate(Map("x" -> 2.0, "y" -> -3.5))
res: Double = -0.5
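Substituting \(x = 2.0\) and \(y = -3.5\) by hand confirms this result; as in the method, the innermost sub-expressions are evaluated first:

\[-((2.0 * 2.0) + (-3.5)) = -(4.0 + (-3.5)) = -0.5\,.\]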

Simplification

When manipulating expressions, it would of course be very nice to have them in as simple a form as possible. For example, instead of

\[y * ((1.0 * x) + 0.0)\]

we naturally prefer using the much simpler

\[y * x\,.\]

We can get this simpler form by first applying the simplification rule

\[1.0 * z = z\]

and then

\[z + 0.0 = z\,.\]

The process of transforming expressions into a simpler (but equivalent!) form is, unsurprisingly, called simplification. We next show how to do some elementary simplification recursively by using case classes and pattern matching. This is where case classes really start to show their power: so far we could have done without them, but now pattern matching enables us to write very concise and readable code:

  /**
   * Simplifies the expression with some simple rules like "1.0*r equals r".
   */
  def simplify: Expr = {
    // First, (recursively) simplify sub-expressions
    val subresult = this match {
      case Multiply(l, r) => Multiply(l.simplify, r.simplify)
      case Add(l, r) => Add(l.simplify, r.simplify)
      case Subtract(l, r) => Subtract(l.simplify, r.simplify)
      case Negate(t) => Negate(t.simplify)
      case _ => this // Handles Var and Num
    }
    // Then simplify this sub-expression by applying some simple rules
    subresult match {
      case Multiply(Num(1.0), r) => r
      case Multiply(l, Num(1.0)) => l
      case Multiply(Num(0.0), _) => Num(0.0)
      case Multiply(_, Num(0.0)) => Num(0.0)
      case Multiply(Num(v1), Num(v2)) => Num(v1 * v2)
      case Add(Num(0.0), r) => r
      case Add(l, Num(0.0)) => l
      case Add(Num(v1), Num(v2)) => Num(v1 + v2)
      case Subtract(Num(0.0), r) => Negate(r)
      case Subtract(l, Num(0.0)) => l
      case Subtract(Num(v1), Num(v2)) => Num(v1 - v2)
      case Negate(Num(v)) => Num(-v)
      case _ => subresult // Could not simplify
    }
  }

Observe the two phases in the function: the sub-expressions of an expression are first (recursively) simplified, and only after that is the expression itself considered. For instance, in our example the sub-expression \(1.0 * x\) is first simplified to \(x\), and then \(x + 0.0\), the new version of the sub-expression \((1.0 * x) + 0.0\), is simplified into \(x\).

Now we get, as expected:

scala> val e = Var("y") * ((Num(1.0) * Var("x")) + Num(0.0))
e: expressions.Multiply = (y*((1.0*x) + 0.0))

scala> e.simplify
res: expressions.Expr = (y*x)
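Because each node is rewritten at most once per call, a single call to `simplify` does not always reach the simplest form. For example, `Subtract(Num(0.0), Num(3.0))` matches the rule \(0.0 - z = -z\) and becomes `Negate(Num(3.0))`, which only a second call folds into `Num(-3.0)`. One possible remedy, sketched below with a hypothetical helper `simplifyFully`, is to repeat `simplify` until the result stops changing; case-class equality (`==` compares structure) makes the stopping test a one-liner. The rules are copied from the method above.

```scala
sealed abstract class Expr {
  // The two-phase simplify from the text: children first, then local rules.
  def simplify: Expr = {
    val subresult = this match {
      case Multiply(l, r) => Multiply(l.simplify, r.simplify)
      case Add(l, r)      => Add(l.simplify, r.simplify)
      case Subtract(l, r) => Subtract(l.simplify, r.simplify)
      case Negate(t)      => Negate(t.simplify)
      case _              => this // Var and Num
    }
    subresult match {
      case Multiply(Num(1.0), r)      => r
      case Multiply(l, Num(1.0))      => l
      case Multiply(Num(0.0), _)      => Num(0.0)
      case Multiply(_, Num(0.0))      => Num(0.0)
      case Multiply(Num(v1), Num(v2)) => Num(v1 * v2)
      case Add(Num(0.0), r)           => r
      case Add(l, Num(0.0))           => l
      case Add(Num(v1), Num(v2))      => Num(v1 + v2)
      case Subtract(Num(0.0), r)      => Negate(r)
      case Subtract(l, Num(0.0))      => l
      case Subtract(Num(v1), Num(v2)) => Num(v1 - v2)
      case Negate(Num(v))             => Num(-v)
      case _                          => subresult
    }
  }
}
case class Var(name: String) extends Expr
case class Num(value: Double) extends Expr
case class Multiply(left: Expr, right: Expr) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Subtract(left: Expr, right: Expr) extends Expr
case class Negate(p: Expr) extends Expr

object FullSimplify {
  // Hypothetical helper: repeat simplify until a pass changes nothing.
  // Every rule that fires removes at least one node, so this terminates,
  // assuming no NaN constants (Num(NaN) != Num(NaN) under case-class equality).
  @annotation.tailrec
  def simplifyFully(e: Expr): Expr = {
    val next = e.simplify
    if (next == e) e else simplifyFully(next)
  }
}
```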

Alternative (bad) approaches

As a final note here, let us observe that although case classes are very convenient, they are only “syntactic sugar” in the sense that the same things can be done without them.

If we had wanted to write the simplification function in the abstract base class Expr without using case classes, we could have used the isInstanceOf and asInstanceOf methods of the scala.Any class. However, this gets very ugly:

abstract class Expr {
  def simplify: Expr = {
    ...
    if (subresult.isInstanceOf[Multiply] &&
        subresult.asInstanceOf[Multiply].left.isInstanceOf[Num] &&
        subresult.asInstanceOf[Multiply].left.asInstanceOf[Num].value == 1.0)
      subresult.asInstanceOf[Multiply].right
    ...
  }
}

Another, almost as bad, approach would be to implement type-inspection and field-getter methods in all the subclasses:

abstract class Expr {
  def simplify: Expr = {
    ...
    if (subresult.isMultiply && subresult.getLeft.isNum && subresult.getLeft.getValue == 1.0)
      subresult.getRight
    ...
  }
}
class Multiply(left: Expr, right: Expr) extends Expr {
  override def toString = "(" + left + "*" + right + ")"
  def isMultiply = true
  def isNum = false
  ...
  def getLeft = left
  def getRight = right
  def getValue: Double = throw new java.util.NoSuchElementException("value of Multiply")
}
...

Compared with either of these alternatives, case classes with pattern matching clearly result in far more readable and maintainable code.
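To make the "syntactic sugar" point concrete: for a case class, the compiler generates a companion object with `apply` and `unapply` methods, along with `equals`, `hashCode` and `toString`. The sketch below writes `apply` and `unapply` by hand for ordinary classes (the names `Const` and `Mul` are hypothetical), which is enough to match on them with constructor patterns such as the rule "1.0 * r equals r". What these classes do not get for free is structural equality: two separately created `Const(1.0)` objects are not equal, because `==` falls back to reference equality.

```scala
// Plain classes: no generated apply/unapply/equals/toString.
abstract class Expr
class Const(val value: Double) extends Expr
class Mul(val left: Expr, val right: Expr) extends Expr

// Hand-written companions playing the role of the generated ones.
object Const {
  def apply(value: Double): Const = new Const(value)
  def unapply(c: Const): Option[Double] = Some(c.value)
}
object Mul {
  def apply(left: Expr, right: Expr): Mul = new Mul(left, right)
  def unapply(m: Mul): Option[(Expr, Expr)] = Some((m.left, m.right))
}

object ExtractorDemo {
  // The rule "1.0 * r equals r", matched through the hand-written extractors.
  // The pattern first tests the runtime type (Mul, Const), then calls unapply.
  def dropUnitFactor(e: Expr): Expr = e match {
    case Mul(Const(1.0), r) => r
    case _                  => e
  }
}
```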