Magic of State Monad


I have been doing Object Oriented Programming in Java for almost a decade. For the past six months, I have been doing Functional Programming in Scala. This blog post explains the reasoning behind why we need the “State Monad”.

Modifying a variable – SIDE EFFECT

If “Design Patterns: Elements of Reusable Object-Oriented Software” is the bible for Object Oriented Design Patterns, I would say the “Functional Programming in Scala” book is the bible for learning Functional Programming.

The “FP in Scala” book starts with this definition: “Functional programming (FP) is based on a simple premise with far-reaching implications: we construct our programs using only pure functions – in other words, functions that have no side effects. What are side effects? A function has a side effect if it does something other than simply return a result, for example: Modifying a variable.” So modifying a variable is a “SIDE EFFECT”.

When I first read that modifying a variable is a “SIDE EFFECT”, I was surprised and puzzled, since we do it all the time in OOP. For example, i++ is a side-effecting operation. I wondered: how do we make state changes in FP? The answer is the State Monad.

State Monad

Chapter 6 of “FP in Scala”, “Purely functional state”, starts with the problem of random number generation.

Random Number with Side effect:

val rng = new scala.util.Random // Create an instance of Random
rng.nextInt  // Call nextInt function to get a random integer
rng.nextInt // Call nextInt function to get a random integer

Every time you call nextInt, it returns a new random value. That works because rng has some internal state that gets updated after each invocation; otherwise we would get the same value each time we called nextInt.

Basically, if you think about the implementation, it holds an internal STATE which it uses to generate a NEW RANDOM VALUE when you call the function. To make the implementation pure, accept the STATE as a function parameter that the caller must pass in every time they need a value, and return the NEXT STATE along with the value.

Random Number without Side effect:

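trait RNG {
  // Returns a random Int together with the next generator state
  def nextInt: (Int, RNG)
}

case class SimpleRNG(seed: Long) extends RNG {
  def nextInt: (Int, RNG) = {
    // Linear congruential generator: compute the new seed from the current one
    val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
    // The next state is simply a new SimpleRNG wrapping the new seed
    val nextRNG = SimpleRNG(newSeed)
    // Use the high 32 of the 48 seed bits as the random value
    val n = (newSeed >>> 16).toInt
    (n, nextRNG)
  }
}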
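Because SimpleRNG is pure, calling nextInt twice on the same value returns the same pair, and getting a second number means passing the returned RNG into the next call yourself. Here is a small sketch of that. The definitions above are repeated so the block compiles on its own, and the object PureRngDemo and the helper twoInts are illustrative names, not from the book.

```scala
trait RNG {
  def nextInt: (Int, RNG)
}

// Same linear congruential generator as above: the seed is the whole state
case class SimpleRNG(seed: Long) extends RNG {
  def nextInt: (Int, RNG) = {
    val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
    ((newSeed >>> 16).toInt, SimpleRNG(newSeed))
  }
}

object PureRngDemo {
  // Hypothetical helper: two ints, threading the next RNG along by hand
  def twoInts(rng: RNG): ((Int, Int), RNG) = {
    val (a, rng1) = rng.nextInt
    val (b, rng2) = rng1.nextInt
    ((a, b), rng2)
  }

  def main(args: Array[String]): Unit = {
    val rng = SimpleRNG(42)
    // Same input state, same output: there is no hidden mutation
    println(rng.nextInt == rng.nextInt) // true
    println(rng.nextInt._1)             // 16159453
    println(twoInts(rng)._1)
  }
}
```

The caller is responsible for never reusing an old RNG when it wants a fresh number; forgetting to thread rng1 into the second call would silently return the same value twice.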

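This pattern of taking the current state and returning a value together with the next state is the common abstraction for making stateful APIs pure, and it is the essence of the State Monad. The essence of its signature: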

case class StateMonad[S, A](run: (S => (A, S)))

Basically, it encapsulates a run function which takes a STATE argument and returns a TUPLE capturing the (VALUE, NEXT STATE). The problem has been inverted: the client now needs to pass in the “NEXT STATE” to generate the next (VALUE, STATE).
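As a tiny illustration of the signature, the i++ example from earlier can be expressed as a StateMonad[Int, Int]: the current value is returned and i + 1 becomes the next state. This sketch assumes map and flatMap like those in the full listing later in this post; Counter, postIncrement and threeTimes are hypothetical names.

```scala
// Minimal StateMonad with just enough to compose steps
case class StateMonad[S, A](run: S => (A, S)) {
  def map[B](f: A => B): StateMonad[S, B] =
    StateMonad { s => val (a, s1) = run(s); (f(a), s1) }

  // Run this step, then hand its resulting state to the next step
  def flatMap[B](f: A => StateMonad[S, B]): StateMonad[S, B] =
    StateMonad { s => val (a, s1) = run(s); f(a).run(s1) }
}

object Counter {
  // Pure analogue of i++: return the current value, with i + 1 as the next state
  val postIncrement: StateMonad[Int, Int] = StateMonad(i => (i, i + 1))

  // Three "i++" in a row; flatMap threads the state, not the caller
  val threeTimes: StateMonad[Int, List[Int]] = for {
    a <- postIncrement
    b <- postIncrement
    c <- postIncrement
  } yield List(a, b, c)

  def main(args: Array[String]): Unit =
    println(threeTimes.run(0)) // (List(0, 1, 2),3)
}
```

No variable is ever modified: the "incremented" counter only exists as the second element of the tuple returned by run.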


Assume there is a CRUD application to create, update, find, and delete employees, which uses a MySQL database for persistence. If you open a SINGLE database terminal and issue CRUD operations against the database, what is essentially happening is a STATE TRANSITION. Say there is an “EMPLOYEE” table with zero records; this can be thought of as the initial state D of the database. Now if you issue INSERT INTO EMPLOYEE VALUES('VMKR'), a new record gets inserted. This can be thought of as a new state D’ of the database, and the value produced is the employee record. So it is a database transition D => (VALUE, D’).

Now assume that you want to write unit tests for this CRUD API. Obviously you would not want to hit the database and would ideally mock it. A simple way to mock a database is to use an in-memory map.

Initial State: Empty Map
Create an Employee: (Empty Map) => (Employee value, Map with 1 record)
Update an Employee: (Map with 1 record) => (Employee value, Map with 1 record)
Find an Employee: (Map with 1 record) => (Option[Employee] value, Map with 1 record)
Delete an Employee: (Map with 1 record) => (Unit, Empty Map)
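Before reaching for the State Monad, the transitions above can be written as plain functions of type Map[Int, Employee] => (A, Map[Int, Employee]) and chained by hand. A sketch under that assumption follows; create, update, find, delete and program are hypothetical, simplified stand-ins for the MEmployee API. Notice how every step has to pass the previous step's Map along explicitly, which is exactly the plumbing the State Monad takes over.

```scala
case class Employee(id: Int, name: String)

object ManualThreading {
  type Db = Map[Int, Employee]

  // Each operation takes the current "database" and returns (value, next database)
  def create(id: Int, name: String)(db: Db): (Employee, Db) = {
    val e = Employee(id, name)
    (e, db + (id -> e))
  }

  // In this simple model an update is an overwrite of the same key
  def update(id: Int, name: String)(db: Db): (Employee, Db) = create(id, name)(db)

  def find(id: Int)(db: Db): (Option[Employee], Db) = (db.get(id), db)

  def delete(id: Int)(db: Db): (Unit, Db) = ((), db - id)

  // Hand-threaded program: each step must be given the previous step's Map
  def program(db0: Db): (Option[Employee], Db) = {
    val (_, db1) = create(1, "Mayakumar Vembunarayanan")(db0)
    val (_, db2) = create(2, "Aarathy Mayakumar")(db1)
    val (_, db3) = update(1, "vmkr")(db2)
    val (found, db4) = find(1)(db3)
    val (_, db5) = delete(1)(db4)
    (found, db5)
  }

  def main(args: Array[String]): Unit =
    println(program(Map.empty))
}
```

Passing db3 where db4 was meant would compile fine and silently use a stale state; that class of mistake disappears once flatMap does the threading.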

BINGO! We can use the “State Monad” abstraction to solve this problem. I have listed the source code below:

package com.fp.statemonad

// S => (A, S): given the current state, produce a value and the next state
case class StateMonad[S, A](run: S => (A, S)) {

  // Transform the value; the state flows through untouched
  def map[B](f: A => B): StateMonad[S, B] =
    StateMonad(s => {
      val (a, s1) = run(s)
      (f(a), s1)
    })

  // Run this step, then feed its value and resulting state into the next step
  def flatMap[B](f: A => StateMonad[S, B]): StateMonad[S, B] =
    StateMonad(s => {
      val (a, s1) = run(s)
      f(a).run(s1)
    })
}

case class Employee(id: Int, name: String)

// CRUD API, abstracted over the effect type M
trait MEmployee[M[_]] {

  def createEmployee(id: Int, name: String): M[Employee]

  def updateEmployee(id: Int, name: String): M[Employee]

  def findEmployee(id: Int): M[Option[Employee]]

  def deleteEmployee(id: Int): M[Unit]
}

// Companion object: forwards each call to the MEmployee[M] instance found implicitly
object MEmployee extends MEmployeeInstances {

  def createEmployee[M[_]](id: Int, name: String)(implicit M: MEmployee[M]): M[Employee] = M.createEmployee(id, name)

  def updateEmployee[M[_]](id: Int, name: String)(implicit M: MEmployee[M]): M[Employee] = M.updateEmployee(id, name)

  def findEmployee[M[_]](id: Int)(implicit M: MEmployee[M]): M[Option[Employee]] = M.findEmployee(id)

  def deleteEmployee[M[_]](id: Int)(implicit M: MEmployee[M]): M[Unit] = M.deleteEmployee(id)
}

trait MEmployeeInstances {

  // Test instance: an in-memory Map[Int, Employee] stands in for the database
  implicit val stateMonadMEmployee: MEmployee[({ type λ[α] = StateMonad[Map[Int, Employee], α] })#λ] =
    new MEmployee[({ type λ[α] = StateMonad[Map[Int, Employee], α] })#λ] {

      def createEmployee(id: Int, name: String): StateMonad[Map[Int, Employee], Employee] =
        StateMonad(m => {
          val e = Employee(id, name)
          (e, m + (id -> e))
        })

      def updateEmployee(id: Int, update: String): StateMonad[Map[Int, Employee], Employee] =
        StateMonad(m => {
          val e = Employee(id, update)
          (e, m + (id -> e))
        })

      // Reading does not change the state: the same Map is returned
      def findEmployee(id: Int): StateMonad[Map[Int, Employee], Option[Employee]] =
        StateMonad(m => (m.get(id), m))

      def deleteEmployee(id: Int): StateMonad[Map[Int, Employee], Unit] =
        StateMonad(m => ((), m - id))
    }
}

object Run extends App {

  type TestState[A] = StateMonad[Map[Int, Employee], A]

  // Builds the program; nothing executes until run is called
  val state = for {
    c1 <- MEmployee.createEmployee[TestState](1, "Mayakumar Vembunarayanan")
    c2 <- MEmployee.createEmployee[TestState](2, "Aarathy Mayakumar")
    u1 <- MEmployee.updateEmployee[TestState](1, "vmkr")
    f1 <- MEmployee.findEmployee[TestState](1)
    _ = println("Found Employee: " + f1)
    _ <- MEmployee.deleteEmployee[TestState](1)
  } yield ()

  // Run the program starting from an empty Map (the initial state)
  println(state.run(Map.empty))
}



The code does the following:

  • case class StateMonad[S, A](run: (S => (A, S))) is the crux of making state transitions pure.
  • It has map and flatMap, which make it a “FUNCTOR” and a “MONAD”. To understand more about “FUNCTOR” and “MONAD”, read here: FUNCTOR_AND_MONAD
  • case class Employee(id: Int, name: String) is the Employee data model with an id and a name.
  • trait MEmployee[M[_]] defines the CRUD APIs. Think of it as an “interface” in Java parlance.
  • object MEmployee is the companion object. Its purpose is to make trait MEmployee easier to use for the client. Clients can simply call MEmployee.createEmployee, and it will work as long as an implementation is implicitly available, as required by the parameter (implicit M: MEmployee[M]).
  • The test implementation, which uses the State Monad with an in-memory Map, is provided by MEmployeeInstances.
  • The cryptic syntax ({ type λ[α] = StateMonad[Map[Int, Employee], α] })#λ is called a type lambda; read more here: TYPE_LAMBDAS. It fixes the state type to Map[Int, Employee] so that StateMonad fits the single-parameter shape M[_] expected by MEmployee.
  • The absolute magic of the State Monad in the above example is that building the for-comprehension executes nothing: the code actually runs only when we run the State Monad, i.e. when state.run is called with an initial state (an empty Map).
  • One other important point: flatMap implicitly passes the state (in this example, the Map) to each subsequent operation after the first createEmployee, because a for-comprehension on a Monad is syntactic sugar for a chain of flatMap calls, finally yielding a value with map.
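To make the last two points concrete, here is a trimmed-down sketch. It repeats a minimal StateMonad and two hypothetical Map-based operations (create and find, simplified stand-ins for the MEmployee instance), writes the same program once as a for-comprehension and once as the explicit flatMap/map chain the compiler generates for it, and runs it against different initial Maps. Building the program only composes functions; nothing touches a Map until run is called.

```scala
case class StateMonad[S, A](run: S => (A, S)) {
  def map[B](f: A => B): StateMonad[S, B] =
    StateMonad { s => val (a, s1) = run(s); (f(a), s1) }
  def flatMap[B](f: A => StateMonad[S, B]): StateMonad[S, B] =
    StateMonad { s => val (a, s1) = run(s); f(a).run(s1) }
}

case class Employee(id: Int, name: String)

object Desugared {
  type Db = Map[Int, Employee]

  // Hypothetical simplified operations in the same shape as the MEmployee instance
  def create(id: Int, name: String): StateMonad[Db, Employee] =
    StateMonad { m => val e = Employee(id, name); (e, m + (id -> e)) }

  def find(id: Int): StateMonad[Db, Option[Employee]] =
    StateMonad(m => (m.get(id), m))

  // for-comprehension version
  val sugared: StateMonad[Db, Option[Employee]] = for {
    _ <- create(1, "vmkr")
    f <- find(1)
  } yield f

  // What the compiler turns it into: flatMap for each step, map at the end
  val desugared: StateMonad[Db, Option[Employee]] =
    create(1, "vmkr").flatMap(_ => find(1).map(f => f))

  def main(args: Array[String]): Unit = {
    // Only now do the wrapped functions execute
    println(sugared.run(Map.empty))
    println(desugared.run(Map.empty))
    // The same program value can be re-run from a different initial state
    println(sugared.run(Map(2 -> Employee(2, "Aarathy Mayakumar"))))
  }
}
```

Because a program is just a value wrapping a function, the same description can be run against an empty Map in a test and, with a different instance of MEmployee, against something else in production.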

Output of running the program:

Found Employee: Some(Employee(1,vmkr))

((),Map(2 -> Employee(2,Aarathy Mayakumar)))

