Comparing TypedDatasets with Spark's Datasets

Goal: This tutorial compares the standard Spark Datasets API with the one provided by frameless' TypedDataset. It shows how TypedDataset allows for an expressive and type-safe API with no compromises on performance.

For this tutorial we first create a simple dataset and save it to disk as a parquet file. Parquet is a popular columnar format that is well supported by Spark. When operating on parquet datasets, Spark knows that each column is stored separately, so if a query needs only a subset of the columns, Spark reads just those columns instead of the entire dataset. This is a simplified view of how Spark and parquet work together, but it is sufficient for this discussion.
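To make the column-pruning idea concrete, here is a toy model in plain Scala. It assumes nothing about how Spark or parquet are really implemented: a table is stored column by column, and a projection only touches the columns it names. The names ColumnarToy, ColumnarTable, project and columnsRead are hypothetical.

```scala
// Toy model of columnar storage and column pruning (not Spark or parquet internals).
object ColumnarToy {
  final case class Foo(i: Long, j: String)

  // Column-oriented storage: one vector of values per column name.
  final class ColumnarTable(columns: Map[String, Vector[Any]]) {
    // Records which columns were actually accessed, to make pruning visible.
    private var touched = Set.empty[String]
    def columnsRead: Set[String] = touched

    // Returns only the requested columns; the other vectors are never accessed.
    def project(names: Seq[String]): Seq[Vector[Any]] =
      names.map { n =>
        touched += n
        columns(n)
      }
  }

  // Splits row-oriented records into one vector per field.
  def fromRows(rows: Seq[Foo]): ColumnarTable =
    new ColumnarTable(Map(
      "i" -> rows.map(_.i).toVector,
      "j" -> rows.map(_.j).toVector
    ))
}
```

Projecting Seq("i") on a table built from the three Foo records returns only the i values, and columnsRead afterwards contains just "i".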

import spark.implicits._
// import spark.implicits._

// Our example case class Foo acting here as a schema
case class Foo(i: Long, j: String)
// defined class Foo

// Assuming Spark is loaded and a SparkSession is bound to spark
val initialDs = spark.createDataset( Foo(1, "Q") :: Foo(10, "W") :: Foo(100, "E") :: Nil )
// initialDs: org.apache.spark.sql.Dataset[Foo] = [i: bigint, j: string]

// Assuming you are on Linux or Mac OS
initialDs.write.parquet("/tmp/foo")

val ds = spark.read.parquet("/tmp/foo").as[Foo]
// ds: org.apache.spark.sql.Dataset[Foo] = [i: bigint, j: string]

ds.show()
// +---+---+
// |  i|  j|
// +---+---+
// |100|  E|
// |  1|  Q|
// | 10|  W|
// +---+---+
//

The value ds holds the contents of initialDs, read back from the parquet file. Let's use only field i of Foo and see how Spark's Catalyst (the query optimizer) optimizes this.

// Using a standard Spark TypedColumn in select()
val filteredDs = ds.filter($"i" === 10).select($"i".as[Long])
// filteredDs: org.apache.spark.sql.Dataset[Long] = [i: bigint]

filteredDs.show()
// +---+
// |  i|
// +---+
// | 10|
// +---+
//

filteredDs is of type Dataset[Long]. Since we only access field i of Foo, the type is correct. Unfortunately, this syntax requires handholding: the column in the select statement has to be converted explicitly into a TypedColumn of type Long (note the as[Long]). We will discuss this limitation in more detail shortly. First, let's take a quick look at the optimized physical plan that Spark's Catalyst generated.

filteredDs.explain()
// == Physical Plan ==
// *Project [i#69L]
// +- *Filter (isnotnull(i#69L) && (i#69L = 10))
//    +- *BatchedScan parquet [i#69L] Format: ParquetFormat, InputPaths: file:/tmp/foo, PartitionFilters: [], PushedFilters: [IsNotNull(i), EqualTo(i,10)], ReadSchema: struct<i:bigint>

The last line is the important one (see ReadSchema): reading the parquet file only required column i, and column j was never accessed. The filter was also pushed down to the parquet reader (see PushedFilters). This is great! We have both an optimized query plan and type-safety!
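Catalyst can prune columns here because $"i" === 10 and $"i" are expressions, that is, data the optimizer can inspect, rather than compiled code. The sketch below is a hypothetical and heavily simplified model of that idea; the classes Col, Lit and EqualTo only loosely mirror the role of Catalyst's expression nodes. Walking the expression tree is enough to compute which columns a query must read.

```scala
// Simplified expression tree (not Catalyst's real classes).
object ExprToy {
  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class Lit(value: Any) extends Expr
  final case class EqualTo(left: Expr, right: Expr) extends Expr

  // Collects every column name an expression refers to.
  def references(e: Expr): Set[String] = e match {
    case Col(n)        => Set(n)
    case Lit(_)        => Set.empty
    case EqualTo(l, r) => references(l) ++ references(r)
  }

  // Columns that must be read: those used by the filter plus the projected ones,
  // analogous to the ReadSchema entry in the physical plan.
  def readSchema(filter: Expr, projection: Seq[Expr]): Set[String] =
    references(filter) ++ projection.flatMap(references)
}
```

For the query above, readSchema(EqualTo(Col("i"), Lit(10L)), Seq(Col("i"))) yields Set("i"), matching ReadSchema: struct<i:bigint>.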

Unfortunately, this syntax is not bulletproof: it fails at run time if we try to access a non-existent column x:

scala> ds.filter($"i" === 10).select($"x".as[Long])
org.apache.spark.sql.AnalysisException: cannot resolve '`x`' given input columns: [i, j];;
'Project ['x]
+- Filter (i#69L = cast(10 as bigint))
   +- Relation[i#69L,j#70] parquet

  at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:77)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:74)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:308)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:308)
  at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
  at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionUp$1(QueryPlan.scala:269)
  at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$2(QueryPlan.scala:279)
  at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$2$1.apply(QueryPlan.scala:283)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.immutable.List.foreach(List.scala:392)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
  at scala.collection.immutable.List.map(List.scala:296)
  at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$2(QueryPlan.scala:283)
  at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$8.apply(QueryPlan.scala:288)
  at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:186)
  at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsUp(QueryPlan.scala:288)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:74)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:67)
  at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:126)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.checkAnalysis(CheckAnalysis.scala:67)
  at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:58)
  at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:49)
  at org.apache.spark.sql.Dataset.<init>(Dataset.scala:161)
  at org.apache.spark.sql.Dataset.<init>(Dataset.scala:167)
  at org.apache.spark.sql.Dataset.select(Dataset.scala:1023)
  ... 450 elided
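The failure only appears at run time because, in the untyped API, a column is identified by a string that is resolved against the schema during analysis, and as[Long] merely attaches a type to it. The toy model below uses hypothetical names and is not Spark's API (Spark's as[U] also needs an Encoder, which is omitted here). It reproduces the weakness: Column("x").as[Long] compiles, and the missing column is detected only when the query is resolved.

```scala
// Toy model of untyped, string-named columns (not Spark's API).
object UntypedColumnToy {
  final case class Column(name: String) {
    // Nothing at compile time checks that the column exists or really has type T.
    def as[T]: TypedColumn[T] = TypedColumn[T](name)
  }
  final case class TypedColumn[T](name: String)

  // Column name -> SQL type name, e.g. "i" -> "bigint".
  final case class Schema(fields: Map[String, String])

  // "Analysis": the name is looked up only when a query is built against a
  // concrete schema, i.e. at run time.
  def resolve[T](schema: Schema, col: TypedColumn[T]): Either[String, String] =
    schema.fields.get(col.name) match {
      case Some(tpe) => Right(tpe)
      case None =>
        Left(s"cannot resolve '${col.name}' given input columns: " +
          schema.fields.keys.mkString("[", ", ", "]"))
    }
}
```

This toy resolver only checks that the name exists; nothing in it relates the T in as[T] to the column's actual type.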

There are two things to improve here. First, we want to avoid the as[Long] cast that we must write to get a typed result; it is an easy place to introduce a bug by casting to an incompatible type. Second, we want a reference to a non-existent column name to fail at compile time. The standard Spark Dataset can achieve both using the following syntax:

ds.filter(_.i == 10).map(_.i).show()
// +-----+
// |value|
// +-----+
// |   10|
// +-----+
//

This looks great! It resembles the familiar Scala collections syntax. The two closures passed to filter and map are functions that operate on Foo, so the compiler helps us catch all the mistakes mentioned above:

scala> ds.filter(_.i == 10).map(_.x).show()
<console>:20: error: value x is not a member of Foo
       ds.filter(_.i == 10).map(_.x).show()
                                  ^

Unfortunately, this syntax does not allow Spark to optimize the code.

ds.filter(_.i == 10).map(_.i).explain()
// == Physical Plan ==
// *SerializeFromObject [input[0, bigint, true] AS value#105L]
// +- *MapElements <function1>, obj#104: bigint
//    +- *DeserializeToObject newInstance(class $line14.$read$$iw$$iw$$iw$$iw$Foo), obj#103: $line14.$read$$iw$$iw$$iw$$iw$Foo
//       +- *Filter <function1>.apply
//          +- *BatchedScan parquet [i#69L,j#70] Format: ParquetFormat, InputPaths: file:/tmp/foo, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<i:bigint,j:string>

As the physical plan shows, Spark was not able to optimize our query as before: reading the parquet file required loading all the fields of Foo (see ReadSchema), and no filters were pushed down (see PushedFilters). This might be fine for small datasets or for datasets with few columns, but it will be extremely slow for most practical applications. Intuitively, Spark has no way to look inside the code we pass in these two closures. It only knows that both take one argument of type Foo; it cannot tell whether we use just one of Foo's fields or all of them.
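The opacity of closures can be illustrated without Spark. In the plain-Scala sketch below (hypothetical names, not Spark code), a predicate of type Foo => Boolean can only be evaluated on a fully built Foo, so every row is materialized with all of its fields even though the predicate reads only i. For contrast, filterOnI stands in for a query that states which column it needs, as a column expression does.

```scala
// Toy illustration of why opaque closures defeat column pruning.
object ClosureOpacityToy {
  final case class Foo(i: Long, j: String)

  // Counts how many values of each column were read.
  final class FieldReads {
    var i = 0
    var j = 0
  }

  // With an opaque predicate, every row has to be rebuilt as a complete Foo
  // (both columns read), because nothing reveals that p only looks at i.
  def filterWithClosure(iCol: Vector[Long], jCol: Vector[String], reads: FieldReads)(
      p: Foo => Boolean): Vector[Long] = {
    val rows = iCol.indices.toVector.map { k =>
      reads.i += 1
      reads.j += 1
      Foo(iCol(k), jCol(k))
    }
    rows.filter(p).map(_.i)
  }

  // When the query names the column it needs, only that column is read.
  def filterOnI(iCol: Vector[Long], reads: FieldReads)(p: Long => Boolean): Vector[Long] = {
    reads.i += iCol.size
    iCol.filter(p)
  }
}
```

Both functions return Vector(10L) for the sample data, but only the closure-based one reads column j.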

The TypedDataset in frameless solves this problem. It allows for a simple and type-safe syntax with a fully optimized query plan.

import frameless.TypedDataset
// import frameless.TypedDataset

val fds = TypedDataset.create(ds)
// fds: frameless.TypedDataset[Foo] = [i: bigint, j: string]

fds.filter( fds('i) === 10 ).select( fds('i) ).show().run()
// +---+
// | _1|
// +---+
// | 10|
// +---+
//

And the optimized Physical Plan:

fds.filter( fds('i) === 10 ).select( fds('i) ).explain()
// == Physical Plan ==
// *Project [i#69L AS _1#176L]
// +- *Filter (isnotnull(i#69L) && (i#69L = 10))
//    +- *BatchedScan parquet [i#69L] Format: ParquetFormat, InputPaths: file:/tmp/foo, PartitionFilters: [], PushedFilters: [IsNotNull(i), EqualTo(i,10)], ReadSchema: struct<i:bigint>

And the compiler is our friend: referring to a non-existent column is now a compile-time error.

scala> fds.filter( fds('i) === 10 ).select( fds('x) )
<console>:21: error: No column Symbol with shapeless.tag.Tagged[String("x")] of type A in Foo
       fds.filter( fds('i) === 10 ).select( fds('x) )
                                               ^
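Roughly speaking, the compile-time check works because looking up a column requires implicit evidence that the record type has a field with that name, and that evidence also fixes the column's type. frameless derives such evidence automatically through shapeless (hence the tagged Symbol in the error message above). The sketch below writes the evidence by hand instead; Exists, Fields, col and TypedCol are hypothetical names, and plain key types stand in for tagged symbols.

```scala
// Hand-written sketch of compile-time column lookup (not frameless' implementation).
object CompileTimeColumns {
  final case class Foo(i: Long, j: String)

  // Key types standing in for field names.
  final class I
  final class J
  object Fields {
    val i: I = new I
    val j: J = new J
  }

  // Evidence that record type T has a field identified by key type K, of type A.
  final class Exists[T, K, A](val name: String, val get: T => A)

  object Exists {
    // Written by hand here; frameless derives this kind of evidence automatically.
    implicit val fooI: Exists[Foo, I, Long] = new Exists[Foo, I, Long]("i", _.i)
    implicit val fooJ: Exists[Foo, J, String] = new Exists[Foo, J, String]("j", _.j)
  }

  final case class TypedCol[T, A](name: String, get: T => A)

  // col[T](key) compiles only if an Exists[T, K, A] instance is found; the column
  // type A is taken from that evidence, so no manual cast is needed.
  final class ColBuilder[T] {
    def apply[K, A](key: K)(implicit ev: Exists[T, K, A]): TypedCol[T, A] =
      TypedCol(ev.name, ev.get)
  }
  def col[T]: ColBuilder[T] = new ColBuilder[T]
}
```

There is no Exists instance for Foo with any other key, so asking for a column Foo does not have is rejected by the compiler, analogous to fds('x) above.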

Differences in Encoders

Encoders in Spark's Datasets are only partially type-safe. If you try to create a Dataset using a type that is neither a Scala Product nor a supported primitive type, you get a compilation error:

class Bar(i: Int)
// defined class Bar

Bar is neither a case class nor a Product, so the following correctly gives a compilation error in Spark:

scala> spark.createDataset(Seq(new Bar(1)))
<console>:21: error: Unable to find encoder for type stored in a Dataset.  Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._  Support for serializing other types will be added in future releases.
       spark.createDataset(Seq(new Bar(1)))
                          ^

However, the compile-time guards implemented in Spark are not sufficient to detect non-encodable members. For example, using the following case class leads to a run-time failure:

case class MyDate(jday: java.util.Date)
// defined class MyDate
val myDateDs = spark.createDataset(Seq(MyDate(new java.util.Date(System.currentTimeMillis))))
// java.lang.UnsupportedOperationException: No Encoder found for java.util.Date
// - field (class: "java.util.Date", name: "jday")
// - root class: "MyDate"
//   at org.apache.spark.sql.catalyst.ScalaReflection$.org$apache$spark$sql$catalyst$ScalaReflection$$serializerFor(ScalaReflection.scala:598)
//   at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:592)
//   at org.apache.spark.sql.catalyst.ScalaReflection$$anonfun$9.apply(ScalaReflection.scala:583)
//   at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
//   at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
//   at scala.collection.immutable.List.foreach(List.scala:392)
//   at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
//   at scala.collection.immutable.List.flatMap(List.scala:355)
//   at org.apache.spark.sql.catalyst.ScalaReflection$.org$apache$spark$sql$catalyst$ScalaReflection$$serializerFor(ScalaReflection.scala:583)
//   at org.apache.spark.sql.catalyst.ScalaReflection$.serializerFor(ScalaReflection.scala:425)
//   at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:61)
//   at org.apache.spark.sql.Encoders$.product(Encoders.scala:274)
//   at org.apache.spark.sql.SQLImplicits.newProductEncoder(SQLImplicits.scala:47)
//   ... 770 elided

In comparison, TypedDataset reports the encoding problem at compile time:

TypedDataset.create(Seq(MyDate(new java.util.Date(System.currentTimeMillis))))
// <console>:22: error: could not find implicit value for parameter encoder: frameless.TypedEncoder[MyDate]
//        TypedDataset.create(Seq(MyDate(new java.util.Date(System.currentTimeMillis))))
//                           ^
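The difference comes from how encoders are found. Spark builds the serializer for a product type through runtime reflection (see ScalaReflection in the stack trace above), while an encoder defined as a type class needs an instance for every field type before the code compiles. The sketch below shows the type-class idea with hypothetical names: Enc is not frameless' TypedEncoder, and the product instances are written by hand rather than derived.

```scala
// Minimal type-class encoder sketch (hypothetical, not frameless' TypedEncoder).
object EncoderToy {
  trait Enc[A] { def schema: String }

  object Enc {
    def apply[A](implicit e: Enc[A]): Enc[A] = e
    implicit val longEnc: Enc[Long] = new Enc[Long] { def schema = "bigint" }
    implicit val stringEnc: Enc[String] = new Enc[String] { def schema = "string" }
    // Deliberately no instance for java.util.Date.
  }

  // Product instances require an instance for every field type.
  final case class Foo(i: Long, j: String)
  object Foo {
    implicit def fooEnc(implicit ei: Enc[Long], ej: Enc[String]): Enc[Foo] =
      new Enc[Foo] { def schema = s"struct<i:${ei.schema},j:${ej.schema}>" }
  }

  final case class MyDate(jday: java.util.Date)
  object MyDate {
    implicit def myDateEnc(implicit ed: Enc[java.util.Date]): Enc[MyDate] =
      new Enc[MyDate] { def schema = s"struct<jday:${ed.schema}>" }
  }
}
```

Here Enc[Foo] resolves, while Enc[MyDate] is rejected by the compiler because no Enc[java.util.Date] instance exists, the same kind of failure as the TypedEncoder error above.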
