from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .master('local[4]') \
    .config("spark.ui.showConsoleProgress", "false") \
    .config("spark.mongodb.input.uri", "mongodb://127.0.0.1/test.coll") \
    .config("spark.mongodb.output.uri", "mongodb://127.0.0.1/test.coll") \
    .appName("mySparkApp") \
    .getOrCreate()

# ------ Test 1: small dataset 1k x 100 (OK: works as expected)
spark.createDataFrame([(i, [k for k in range(100)]) for i in range(1000)], ["seq", "data"]) \
    .write.format("com.mongodb.spark.sql.DefaultSource") \
    .mode("overwrite").save()

test1 = spark.read.format("com.mongodb.spark.sql.DefaultSource") \
    .option("sampleSize", 100) \
    .option("pipeline", [{'$limit': 1}]) \
    .load()

print('Test 1: Expected 1 row, got', test1.count(), 'row:')
test1.show()

# ------ Test 2: larger dataset 100k x 100 (FAILS: limit stage returns 3 rows instead of 1)
spark.createDataFrame([(i, [k for k in range(100)]) for i in range(100000)], ["seq", "data"]) \
    .write.format("com.mongodb.spark.sql.DefaultSource") \
    .mode("overwrite").save()

test2 = spark.read.format("com.mongodb.spark.sql.DefaultSource") \
    .option("sampleSize", 100) \
    .option("pipeline", [{'$limit': 1}]) \
    .load()

print('Test 2: Expected 1 row, got', test2.count(), 'rows:')
test2.show()

spark.stop()
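
# ------ Possible workaround (a sketch, assuming the mongo-spark v2.x
# connector, where the `pipeline` option is pushed down to each partition's
# aggregation query, so `$limit: 1` can return one row per partition rather
# than one row overall). Forcing a single input partition should make the
# limit global; `MongoSinglePartitioner` is one of the connector's built-in
# partitioners, but verify the option name against your connector version.
# Run this before spark.stop() above, or in a fresh session:
workaround = spark.read.format("com.mongodb.spark.sql.DefaultSource") \
    .option("partitioner", "MongoSinglePartitioner") \
    .option("pipeline", [{'$limit': 1}]) \
    .load()
print('Workaround: got', workaround.count(), 'row(s)')

# Alternatively, drop the pipeline stage and apply the limit on the Spark
# side after the load, which is always global:
# spark.read.format("com.mongodb.spark.sql.DefaultSource").load().limit(1)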