A non-mathematician's introduction to monads
library(chronicler)
Introduction
This vignette introduces the functional programming concept of a monad, without going into much technical detail. {chronicler} is an implementation of a logger monad, but in truth, it is not necessary to know what monads are to use this package. However, if you are curious, read on. A monad is a computation device that offers two things:
- the possibility to decorate functions so they can provide additional output without having to touch the function’s core implementation;
- a way to compose these decorated functions;
(This definition is an oversimplification of the actual definition of a monad, but good enough for our purposes.)
To understand what a monad is, I believe it is useful to explain what sort of problem monads solve.
Suppose for instance that you wish for your functions to provide a log when they’re run. If your function looks like this:
my_sqrt <- function(x){
sqrt(x)
}
Then you would need to rewrite this function like this:
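A minimal sketch of such a rewrite, assuming the function returns a list holding the result alongside a log message (the element names `result` and `log` and the `.log` argument are illustrative choices, picked to match the rest of this vignette):

```r
# Hypothetical rewrite: my_sqrt() now returns its result together with a log
my_sqrt <- function(x, .log = NULL){
  list(result = sqrt(x),
       log = c(.log, paste0("Running sqrt with input ", x)))
}

my_sqrt(10)
```

The core computation is still `sqrt(x)`, but the logging code now lives inside the function body.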
There are two problems with such an implementation:
- we need to rewrite every function we need to use so that they provide logs;
- these functions don’t compose.
What do I mean by “these functions don’t compose”? Consider another such function, my_log():
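A minimal sketch of my_log(), assuming it follows the same pattern as a logging version of my_sqrt() and returns a list (the element names and the `.log` argument are illustrative):

```r
# Hypothetical logging version of log(): returns a list, not a number
my_log <- function(x, .log = NULL){
  list(result = log(x),
       log = c(.log, paste0("Running log with input ", x)))
}

my_log(10)
```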
sqrt() and log() compose, or rather, they can be chained:
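For instance, the two base functions can be chained with the native pipe (a minimal illustration; the variable name `chained` is only used here):

```r
# sqrt() returns a number, which log() accepts directly
chained <- 10 |>
  sqrt() |>
  log()

chained
#> [1] 1.151293
```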
while this is not true for my_sqrt() and my_log():
10 |>
my_sqrt() |>
my_log()
Error in log(x) (from #3) : non-numeric argument to mathematical function
This is because my_log() expects a number, not a list, which is what my_sqrt() returns.
A “monad” is what we need to solve these two problems. The first problem, not having to rewrite every function, can be tackled using function factories. Let’s write one for our problem:
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}
We can now create our functions easily:
l_sqrt <- log_it(sqrt)
l_sqrt(10)
#> $result
#> [1] 3.162278
#>
#> $log
#> [1] "Running sqrt with argument 10"
l_log <- log_it(log)
l_log(10)
#> $result
#> [1] 2.302585
#>
#> $log
#> [1] "Running log with argument 10"
We can call l_sqrt() and l_log() decorated functions and the values they return monadic values.
The second issue remains though; l_sqrt() and l_log() can’t be composed/chained. To solve this issue, we need another function, called bind():
bind <- function(.l, .f, ...){
  .f(.l$result, ..., .log = .l$log)
}
Using bind(), it is now possible to compose l_sqrt() and l_log():
10 |>
l_sqrt() |>
bind(l_log)
#> $result
#> [1] 1.151293
#>
#> $log
#> [1] "Running sqrt with argument 10"
#> [2] "Running log with argument 3.16227766016838"
bind() takes care of providing the right arguments to the underlying function. We can check that the result is correct by comparing the $result value of the returned object to log(sqrt(10)):
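One way to run this check is sketched below. The definitions of log_it() and bind() are repeated from above so the block runs on its own; the variable name `chained` is only used for this illustration.

```r
# Definitions repeated from above so this block is self-contained
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}
bind <- function(.l, .f, ...){
  .f(.l$result, ..., .log = .l$log)
}

l_sqrt <- log_it(sqrt)
l_log <- log_it(log)

chained <- 10 |>
  l_sqrt() |>
  bind(l_log)

# Compare the monadic result with the plain computation
all.equal(chained$result, log(sqrt(10)))
#> [1] TRUE
```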
Using a function factory together with a helper function that makes the decorated functions compose is, in essence, what constitutes a monad, although strictly speaking this is not precisely correct. It can be interesting to see the actual definition from the programming language Haskell, which is a pure functional programming language where monads must be used to solve certain issues:
Monads can be viewed as a standard programming interface to various data or control structures, which is captured by Haskell’s Monad class. All the common monads are members of it:
class Monad m where
(>>=) :: m a -> ( a -> m b) -> m b
(>>) :: m a -> m b -> m b
return :: a -> m a
(Source: Monad)
This definition is quite cryptic, especially if you don’t know Haskell, but what this means is that a Monad (in Haskell) is something that has three methods:
- >>=, which is what we called bind();
- >>, which I didn’t bother implementing, because it’s not really needed for understanding what a monad is;
- and return. Don’t be confused by the name, this has nothing to do with the return() we use inside functions to return a value. return is a function that wraps (or converts) a value into a monadic value, so if you consider any object a, return takes a as an input and returns the monadic value m a.
While we didn’t implement return (also called unit, which is also not a good name), our function factory log_it() does return/unit’s job, but it returns m f(a) instead of m a.
Using function factories comes more naturally to R users than using return/unit, which is why I did not focus on return/unit. Also, using our function factory, it is easy to implement return/unit:
unit <- log_it(identity)
so return/unit is just the identity() function that went through the function factory. In a sense, the function factory is even more necessary for defining a monad than return/unit.
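For comparison, return/unit could also be written by hand, without going through the function factory. The sketch below is an assumption made for illustration: the name unit_direct() is hypothetical and appears neither in this vignette's original code nor in {chronicler}. It simply wraps a value a into a monadic value with an empty log, which is the m a described above.

```r
# Hypothetical hand-written return/unit: wrap a value, start with an empty log
unit_direct <- function(a){
  list(result = a, log = NULL)
}

unit_direct(10)
```

The only difference from log_it(identity) is the log: the factory version records a "Running identity" entry, while this version records nothing.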
Finally, you might sometimes read that monads are objects that have a flatmap() method. I think this definition is also not strictly correct and very likely an oversimplification. But what is flatmap() anyway? In practical terms, it is equivalent to bind(), but it is how you get there that’s different. To implement flatmap(), two additional functions are needed: fmap() and flatten() (which is quite often called join(), but this has nothing to do with joining data frames, so I used flatten() instead).
fmap() is a function that takes a monadic value and an undecorated function as arguments and applies the undecorated function to the monadic value:
fmap <- function(m, f, ...){
  fstring <- deparse(substitute(f))
  list(result = f(m$result, ...),
       log = c(m$log,
               paste0("fmapping ", fstring, " with arguments ",
                      paste0(m$result, ..., collapse = ","))))
}
Let’s first define a monadic value:
# Let’s use unit(), which we defined above, for this.
(m <- unit(10))
#> $result
#> [1] 10
#>
#> $log
#> [1] "Running identity with argument 10"
Let’s now use fmap() to apply a non-decorated function to m:
fmap(m, log)
#> $result
#> [1] 2.302585
#>
#> $log
#> [1] "Running identity with argument 10" "fmapping log with arguments 10"
Great, now what about flatten() (or join())? Why is that useful? Suppose that instead of log() we used l_log() with fmap() (so we’re using a decorated function instead of an undecorated one):
fmap(m, l_log)
#> $result
#> $result$result
#> [1] 2.302585
#>
#> $result$log
#> [1] "Running log with argument 10"
#>
#>
#> $log
#> [1] "Running identity with argument 10" "fmapping l_log with arguments 10"
As you can see from the output, this produced a nested list, a monadic value where the value is itself a monadic value. We would like flatten()/join() to take care of this for us. So this could be an implementation of flatten():
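A minimal sketch, assuming flatten() keeps the inner result and the outer log, which is consistent with the outputs shown in this vignette:

```r
# Sketch: unwrap one level of nesting.
# The inner result becomes the result; the outer log is kept as the log.
flatten <- function(m){
  list(result = m$result$result,
       log = m$log)
}
```

Note that this version discards the inner log (`m$result$log`). A variant that keeps both could use `c(m$log, m$result$log)` instead, at the price of a more verbose log.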
Let’s try now:
flatten(fmap(m, l_log))
#> $result
#> [1] 2.302585
#>
#> $log
#> [1] "Running identity with argument 10" "fmapping l_log with arguments 10"
Great! Now, as explained earlier, flatmap() and bind() are the same thing. But we have implemented flatten() and fmap(), so how do these two functions relate to flatmap()? It turns out that flatmap() is the composition of flatten() and fmap():
# I first define a composition operator for functions
`%.%` <- \(f,g)(function(...)(f(g(...))))
# I now compose flatten() and fmap()
# flatten %.% fmap is read as "flatten after fmap"
flatmap <- flatten %.% fmap
So this means that we can now replace:
10 |>
l_sqrt() |>
bind(l_log)
#> $result
#> [1] 1.151293
#>
#> $log
#> [1] "Running sqrt with argument 10"
#> [2] "Running log with argument 3.16227766016838"
by:
10 |>
l_sqrt() |>
flatmap(l_log)
#> $result
#> [1] 1.151293
#>
#> $log
#> [1] "Running sqrt with argument 10"
#> [2] "fmapping l_log with arguments 3.16227766016838"
and we get the same result (well, not quite, since the log is different). I prefer introducing monads using bind(), because bind() comes as a natural solution to the problem of decorated functions not composing. Not so with flatmap(), but in some applications it might be easier to first define fmap() and flatten() and get flatmap() from them instead of trying to write bind() directly, so it’s good to know both approaches.
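The claim that both routes give the same value but different logs can be checked directly. In the sketch below, all definitions are repeated from this vignette so the block runs on its own; flatten() is written under the assumption that it keeps the inner result and the outer log, and the names `via_bind` and `via_flatmap` are only used for this illustration.

```r
# Definitions repeated so this block is self-contained
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}
bind <- function(.l, .f, ...){
  .f(.l$result, ..., .log = .l$log)
}
fmap <- function(m, f, ...){
  fstring <- deparse(substitute(f))
  list(result = f(m$result, ...),
       log = c(m$log,
               paste0("fmapping ", fstring, " with arguments ",
                      paste0(m$result, ..., collapse = ","))))
}
# Assumed implementation: inner result, outer log
flatten <- function(m) list(result = m$result$result, log = m$log)
`%.%` <- \(f, g) (function(...) f(g(...)))
flatmap <- flatten %.% fmap

l_sqrt <- log_it(sqrt)
l_log <- log_it(log)

via_bind <- 10 |> l_sqrt() |> bind(l_log)
via_flatmap <- 10 |> l_sqrt() |> flatmap(l_log)

# Same value...
identical(via_bind$result, via_flatmap$result)
#> [1] TRUE
# ...but the second log entry differs
identical(via_bind$log, via_flatmap$log)
#> [1] FALSE
```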
Before continuing with the final part of this introduction, I just want to share with you that lists are also monads. We have everything we need: as.list() is unit(), purrr::map() is fmap() and purrr::flatten() is flatten(). This means we can obtain flatmap() from composing purrr::flatten() and purrr::map():
# Since I'm using `{purrr}`, might as well use purrr::compose() instead of my own implementation
flatmap_list <- purrr::compose(purrr::flatten, purrr::map)
# Functions that return lists: they don't compose!
# no worries, we implemented `flatmap_list()`
list_sqrt <- \(x)(as.list(sqrt(x)))
list_log <- \(x)(as.list(log(x)))
10 |>
list_sqrt() |>
flatmap_list(list_log)
#> [[1]]
#> [1] 1.151293
(thanks to @armcn_ for showing me this)
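The same idea can be sketched in base R, without {purrr}. This is an assumption made for illustration: lapply() stands in for fmap() and unlist(recursive = FALSE) for flatten(), and the name flatmap_base() is hypothetical. Using an input with two elements also shows that flattening removes exactly one level of nesting.

```r
# Base R stand-ins: lapply() as fmap(), unlist(recursive = FALSE) as flatten()
flatmap_base <- function(x, f){
  unlist(lapply(x, f), recursive = FALSE)
}

# Functions that return lists
list_sqrt <- \(x) as.list(sqrt(x))
list_log <- \(x) as.list(log(x))

res <- c(10, 100) |>
  list_sqrt() |>
  flatmap_base(list_log)

# A flat list with one element per input value
str(res)
```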
In sum, monads are useful when you need values to also carry something more with them. This something can be a log, as shown here, but there are many other examples. For another example of a monad implemented as an R package, see the maybe monad. {chronicler} actually takes advantage of the maybe package and uses the maybe monad to handle cases where functions fail. I provide a short introduction to the maybe monad in the Maybe monad vignette.
Monadic laws
Monads need to satisfy the so-called “monadic laws”. We’re going to verify whether the monad implemented in {chronicler} satisfies these monadic laws.
First law
The first law states that wrapping a value into a monadic value and binding it to a monadic function using bind() (or, in the case of the {chronicler} package, bind_record()) is the same as passing the bare value directly to that monadic function.
# test_that() and expect_equal() come from {testthat}
library(testthat)
a <- as_chronicle(10)
r_sqrt <- record(sqrt)
test_that("first monadic law", {
  expect_equal(bind_record(a, r_sqrt)$value, r_sqrt(10)$value)
})
#> Test passed
Turns out that this is not quite the case here; the logs of the two objects will be slightly different. So I only check the value.
Second law
The second law states that binding a monadic value to return() (called as_chronicle() in this package; in other words, the function that coerces values to chronicler objects) does nothing. Here again we have an issue with the log, which is why I focus on the value:
test_that("second monadic law", {
expect_equal(bind_record(a, as_chronicle)$value, a$value)
})
#> Test passed
Third law
The third law is about associativity; applying monadic functions successively or composing them first gives the same result.
a <- as_chronicle(10)
r_sqrt <- record(sqrt)
r_exp <- record(exp)
r_mean <- record(mean)
test_that("third monadic law", {
  expect_equal(
    (
      (bind_record(a, r_sqrt)) |>
        bind_record(r_exp)
    )$value,
    (
      a |>
        (\(x) bind_record(x, r_sqrt) |> bind_record(r_exp))()
    )$value
  )
})
#> Test passed
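The same three laws can also be checked on the toy logger monad built earlier in this vignette, using only base R. This is a sketch under stated assumptions: log_it() and bind() are repeated so the block runs on its own, the helper sqrt_then_exp() and the names `law1`, `law2`, `law3` are hypothetical, and only the `result` element is compared.

```r
# Definitions repeated so this block is self-contained
log_it <- function(.f, ..., log = NULL){
  fstring <- deparse(substitute(.f))
  function(..., .log = log){
    list(result = .f(...),
         log = c(.log,
                 paste0("Running ", fstring, " with argument ", ...)))
  }
}
bind <- function(.l, .f, ...){
  .f(.l$result, ..., .log = .l$log)
}

unit <- log_it(identity)
l_sqrt <- log_it(sqrt)
l_exp <- log_it(exp)
m <- unit(10)

# First law: binding unit(a) to f gives the same value as f(a)
law1 <- identical(bind(unit(10), l_sqrt)$result, l_sqrt(10)$result)

# Second law: binding a monadic value to unit leaves the value unchanged
law2 <- identical(bind(m, unit)$result, m$result)

# Third law: binding step by step equals binding to the composed function.
# The composed function accepts .log because bind() always passes it along.
sqrt_then_exp <- function(x, .log = NULL){
  bind(l_sqrt(x, .log = .log), l_exp)
}
law3 <- identical(bind(bind(m, l_sqrt), l_exp)$result,
                  bind(m, sqrt_then_exp)$result)

c(law1, law2, law3)
#> [1] TRUE TRUE TRUE
```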
flatmap() for chronicle objects
For completeness’ sake, I check that I can get flatmap_record() by composing flatten_record() and fmap_record():
r_sqrt <- record(sqrt)
r_exp <- record(exp)
r_mean <- record(mean)
a <- 1:10 |>
r_sqrt() |>
bind_record(r_exp) |>
bind_record(r_mean)
flatmap_record <- purrr::compose(flatten_record, fmap_record)
b <- 1:10 |>
r_sqrt() |>
flatmap_record(r_exp) |>
flatmap_record(r_mean)
identical(a$value, b$value)
#> [1] TRUE