I wanted to read the pubsub topic and write data to BigTable with the dataflow code written in Python. I could find the sample code in JAVA but not in Python.
How can we assign columns in a row from pubsub to different column families and write the data to Bigtable?

To write to Bigtable in a Dataflow pipeline, you'll need to create direct rows and pass them to the WriteToBigTable transform. Here is a brief example that just passes in the row keys and adds one cell for each key — nothing too fancy:
import datetime
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from import WriteToBigTable
from import row
class MyOptions(PipelineOptions):
def _add_argparse_args(cls, parser):
help='The Bigtable project ID, this can be different than your '
'Dataflow project',
help='The Bigtable instance ID',
help='The Bigtable table ID in the instance.',
class CreateRowFn(beam.DoFn):
def process(self, key):
direct_row = row.DirectRow(row_key=key)
return [direct_row]
def run(argv=None):
"""Build and run the pipeline."""
options = MyOptions(argv)
with beam.Pipeline(options=options) as p:
p | beam.Create(["phone#4c410523#20190501",
"phone#4c410523#20190502"]) | beam.ParDo(
CreateRowFn()) | WriteToBigTable(
if __name__ == '__main__':
I am just starting to explore this now and can link to a more polished version on GitHub once it's complete. Hope this helps you get started.

Building on top of what was proposed and adding PubSub, here’s a working version..
Prerequisites
GCS Bucket created (for Dataflow temp/staging files)
PubSub topic created
PubSub subscription created
BigTable instance created
BigTable table created
BigTable column family must be created (no visible error otherwise !)
Example of the latter with cbt:
cbt -instance test-instance createfamily test-table cf1
Define and run the Dataflow pipeline.
# Packages
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from import WriteToBigTable
from import pubsub_v1
# Classes
class CreateRowFn(beam.DoFn):
def __init__(self, pipeline_options):
self.instance_id = pipeline_options.bigtable_instance
self.table_id = pipeline_options.bigtable_table
def process(self, key):
from import row
import datetime
direct_row = row.DirectRow(row_key=key)
yield direct_row
# Options
class XyzOptions(PipelineOptions):
def _add_argparse_args(cls, parser):
parser.add_argument('--bigtable_project', default='nested'),
parser.add_argument('--bigtable_instance', default='instance'),
parser.add_argument('--bigtable_table', default='table')
pipeline_options = XyzOptions(
save_main_session=True, streaming=True,
# Pipeline
def run (argv=None):
with beam.Pipeline(options=pipeline_options) as p:
_ = (p
| 'Read from Pub/Sub' >>
| 'Conversion UTF-8 bytes to string' >> beam.Map(lambda msg: msg.decode('utf-8'))
| 'Conversion string to row object' >> beam.ParDo(CreateRowFn(pipeline_options))
| 'Writing row object to BigTable' >> WriteToBigTable(project_id=pipeline_options.bigtable_project,
if __name__ == '__main__':
Publish a message b"phone#1111" to PubSub topic (e.g. using the Python PublisherClient()).
Table content (using happybase)
b'phone#1111': {b'cf1:field1': b'value1'}
Row length: 1


I am trying to read from a Kafka topic using Apache Beam and Dataflow, print the data to the console, and finally write it to a Pub/Sub topic. But it seems to get stuck in the ReadFromKafka function. A lot of data is written to the Kafka topic, but nothing happens in this pipeline when it runs.
import apache_beam as beam
import argparse
from import ReadFromKafka
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
def run(argv=None, save_main_session=True):
parser = argparse.ArgumentParser()
known_args, pipeline_args = parser.parse_known_args(argv)
class PrintValue(beam.DoFn):
def process(self, element):
return [element]
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
with beam.Pipeline(options=pipeline_options) as pipeline:
_ = (
| 'Read from Kafka' >> ReadFromKafka(
consumer_config={'bootstrap.servers': 'ip:port' },
| 'print' >> beam.ParDo(PrintValue())
| 'write to pubsub' >>'projects/sample/topics/test')
if __name__ == '__main__':
I know there is an Issue
but as I understand it, this problem only affects portable runners. Does anybody know if ReadFromKafka works with unbounded data in Dataflow?
Python 3.8.10
I had a similar issue, and switched to using a beam.Map transform instead (make sure your printValue function is defined within the run function, or you have a proper dependency management method):
| Map(lambda value: printValue(value))
Note that the type of elements you get from ReadFromKafka is an ad hoc class named BeamSchema_xxxxxxxxx, having the following attributes (assuming you configure the reader with with_metadata=True): 'topic', 'value', 'count', 'headers', 'index', 'key', 'offset', 'partition', 'timestamp', 'timestampTypeId', 'timestampTypeName'. It doesn't print nicely, if at all.
So you want to decode your values first, for example:
def decode_kafka_message(record) -> str:
    """Extract the message value from a ReadFromKafka record as a string.

    Record attributes passed from ReadFromKafka transform: 'topic', 'value'
    'count', 'headers', 'index', 'key', 'offset', 'partition',
    'timestamp', 'timestampTypeId', 'timestampTypeName'.

    :param record: Either an object exposing a ``value`` attribute, or a
        tuple whose second item is the value.
    :return: Message value as string
    :raises RuntimeError: If ``record`` has no ``value`` attribute and is
        not a tuple.
    """
    if hasattr(record, 'value'):
        value = record.value
    elif isinstance(record, tuple):
        value = record[1]
    else:
        # Only reject records that match neither supported shape.
        raise RuntimeError('unknown record type: %s' % type(record))
    # Values that are not bytes are returned unchanged.
    return value.decode("UTF-8") if isinstance(value, bytes) else value
That connector could use some work and better docs.

I want to publish messages to a Pub/Sub topic with some attributes thanks to Dataflow Job in batch mode.
My Dataflow pipeline is written with Python 3.8 and apache-beam 2.27.0.
It works with the #Ankur solution here :
But I think it could be more efficient with a shared Pub/Sub Client :
However an error occurred:
return StockUnpickler.find_class(self, module, name) AttributeError:
Can't get attribute 'PublishFn' on <module 'dataflow_worker.start'
Would the shared publisher implementation improve beam pipeline performance?
Is there another way to avoid pickling error on my shared publisher client ?
My Dataflow Pipeline :
import apache_beam as beam
from import bigquery
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from import PublisherClient
import json
import argparse
import re
import logging
class PubsubClient(PublisherClient):
    """PublisherClient subclass that defines how it is pickled.

    ``__reduce__`` tells pickle to rebuild the client by calling this class
    again with the instance's ``batch_settings``.
    """

    def __reduce__(self):
        return self.__class__, (self.batch_settings,)
# The DoFn to perform on each element in the input PCollection.
class PublishFn(beam.DoFn):
def __init__(self):
from import pubsub_v1
batch_settings = pubsub_v1.types.BatchSettings(
max_bytes=1024, # One kilobyte
max_latency=1, # One second
self.publisher = PubsubClient(batch_settings)
def process(self, element, **kwargs):
future = self.publisher.publish(
return future.result()
def run(argv=None, save_main_session=True):
"""Main entry point; defines and runs the pipeline."""
parser = argparse.ArgumentParser()
help="BigQuery source table <project>.<dataset>.<table> with columns (topic, attributes, data)",
known_args, pipeline_args = parser.parse_known_args(argv)
# We use the save_main_session option because one or more DoFn's in this
# workflow rely on global context (e.g., a module imported at module level).
pipeline_options = PipelineOptions(pipeline_args)
# pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
bq_source_table = known_args.source_table_id
bq_table_regex = r"^(?P<PROJECT_ID>[a-zA-Z0-9_-]*)[\.|\:](?P<DATASET_ID>[a-zA-Z0-9_]*)\.(?P<TABLE_ID>[a-zA-Z0-9_-]*)$"
regex_match =, bq_source_table)
if not regex_match:
raise ValueError(
f"Bad BigQuery table id : `{bq_source_table}` please match {bq_table_regex}"
table_ref = bigquery.TableReference("PROJECT_ID"),"DATASET_ID"),"TABLE_ID"),
with beam.Pipeline(options=pipeline_options) as p:
| "ReadFromBqTable" #
>> bigquery.ReadFromBigQuery(table=table_ref, use_json_exports=True) # Each row contains : topic / attributes / data
| "PublishRowsToPubSub" >> beam.ParDo(PublishFn())
if __name__ == "__main__":
After fussing with this a bit, I think I have an answer that works consistently and is, if not world-beatingly performant, at least tolerably usable:
import logging
import apache_beam as beam
from import PubsubMessage
from import PublisherClient
from import (
class PublishClient(PublisherClient):
    """PublisherClient subclass that defines how it is pickled.

    You have to override __reduce__ to make PublisherClient pickleable 😡 😤 🤬
    Props to 'Ankur' and 'Benjamin' on SO for figuring this part out; god knows
    I would not have...
    """

    def __reduce__(self):
        # Rebuild the client from the same settings and options when unpickled.
        return self.__class__, (self.batch_settings, self.publisher_options)
class PubsubWriter(beam.DoFn):
""" does not yet support batch operations, so
we do this the hard way. it's not as performant as the native
pubsubio but it does the job.
def __init__(self, topic: str):
self.topic = topic
self.window = beam.window.GlobalWindow()
self.count = 0
def setup(self):
batch_settings = BatchSettings(
max_bytes=1e6, # 1MB
# by default it is 10 ms, should be less than timeout used in future.result() to avoid timeout
publisher_options = PublisherOptions(
# better to be slow than to drop messages during a recovery...
self.publisher = PublishClient(batch_settings, publisher_options)
def start_bundle(self):
self.futures = []
def process(self, element: PubsubMessage, window=beam.DoFn.WindowParam):
self.window = window
def finish_bundle(self):
"""Iterate over the list of async publish results and block
until all of them have either succeeded or timed out. Yield
a WindowedValue of the success/fail counts."""
results = []
self.count = self.count + len(self.futures)
for fut in self.futures:
# future.result() blocks until success or timeout;
# we've set a max_latency of 60s upstairs in BatchSettings,
# so we should never spend much time waiting here.
except Exception as ex:
res_count = {"success": 0}
for res in results:
if isinstance(res, str):
res_count["success"] += 1
# if it's not a string, it's an exception
msg = str(res)
if msg not in res_count:
res_count[msg] = 1
res_count[msg] += 1"Pubsub publish results: {res_count}")
yield beam.utils.windowed_value.WindowedValue(
def teardown(self):"Published {self.count} messages")
The trick is that if you call future.result() inside the process() method, you will block until that single message is successfully published, so instead collect a list of futures and then at the end of the bundle make sure they're all either published or definitively timed out. Some quick testing with one of our internal pipelines suggested that this approach can publish 1.6M messages in ~200s.

I am writing Apache Beam Python code which reads data from a Pub/Sub subscription and prints it on the console, but it gets stuck and never completes.
import argparse
import logging
import ast
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
class FlattenJson(beam.DoFn):
def process(self, element, *args, **kwargs):
print("Element: {element}")
class DecodeMsgs(beam.DoFn):
def process(self, element):
print("############Before Decode:element",element)"Before Decode: {element}")
return [element]
class PubsubStreamingToBq:
def __init__(self):
def run(self, subscription, pipeline_args=None):
pipeline_options = PipelineOptions(pipeline_args, streaming=True, save_main_session=True)
print("Args: {pipeline_args}", pipeline_args)
with beam.Pipeline(options=pipeline_options) as p:
pubsub_logs = (p | "Read Pubsub Msg" >> | "Decoding" >> beam.ParDo(DecodeMsgs()))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--subscription', required=True)
known_args, pipeline_args = parser.parse_known_args()
print("Known Args: {known_args}", known_args)
PubsubStreamingToBq_obj = PubsubStreamingToBq(), pipeline_args)
Could anyone let me know what is the issue?
Using beam version 2.27.0 and Python version 3.6

New to Python and the IB API and stuck on this simple thing. This application works correctly and prints the IB server reply. However, I cannot figure out how to get this data into a pandas DataFrame, or any other variable for that matter. How do you "get the data out?" Thanks!
Nothing on forums, documentation or youtube that I can find with a useful example. I think the answer must be to return accountSummary to pd.Series, but no idea how.
Expected output would be a data series or variable that can be manipulated outside of the application.
from ibapi import wrapper
from ibapi.client import EClient
from ibapi.utils import iswrapper #just for decorator
from ibapi.common import *
import pandas as pd
class TestApp(wrapper.EWrapper, EClient):
    """IB API app that requests the account summary and prints the replies.

    The instance passes itself to ``EClient`` as the wrapper, so it both
    sends requests and defines the callbacks below.
    """

    def __init__(self):
        EClient.__init__(self, wrapper=self)

    def nextValidId(self, orderId:int):
        """Store the next valid order id, then request the account summary."""
        # %-format the id; passing it as a second print() argument would
        # print the literal "%d" placeholder followed by the number.
        print("setting nextValidOrderId: %d" % orderId)
        self.nextValidOrderId = orderId
        # here is where you start using api
        self.reqAccountSummary(9002, "All", "$LEDGER")

    def error(self, reqId:TickerId, errorCode:int, errorString:str):
        """Print an error reported for request ``reqId``."""
        print("Error. Id: " , reqId, " Code: " , errorCode , " Msg: " , errorString)

    def accountSummary(self, reqId:int, account:str, tag:str, value:str, currency:str):
        """Print one account-summary record."""
        print("Acct Summary. ReqId:" , reqId , "Acct:", account,
              "Tag: ", tag, "Value:", value, "Currency:", currency)
        #IB API data returns here, how to pass it to a variable or pd.series

    def accountSummaryEnd(self, reqId:int):
        """Print the end-of-summary notice and disconnect."""
        print("AccountSummaryEnd. Req Id: ", reqId)
        # now we can disconnect
        # NOTE(review): the snippet announced the disconnect but never made
        # the call — confirm EClient.disconnect() is the intended one.
        self.disconnect()
def main():
app = TestApp()
app.connect("", 4001, clientId=123)
test = app.accountSummary
if __name__ == "__main__":
Hi, I had the same problem and collections did it for me. Here is my code for CFD data. Maybe it will help somebody. You will have your data in app.df. Any suggestions for improvement are more than welcome.
import collections
import datetime as dt
from threading import Timer
from ibapi.client import EClient
from ibapi.wrapper import EWrapper
from ibapi.contract import Contract
import pandas as pd
# get yesterday and put it to correct format yyyymmdd{space}{space}hh:mm:dd
yesterday = str( - dt.timedelta(1))
yesterday = yesterday.replace('-','')
IP = ''
PORT = 7497
class App(EClient, EWrapper):
def __init__(self):
super().__init__(self) = collections.defaultdict(list)
def error(self, reqId, errorCode, errorString):
print(f'Error {reqId}, {errorCode}, {errorString}')
def historicalData(self, reqId, bar):['date'].append(['open'].append(['high'].append(bar.high)['low'].append(bar.low)['close'].append(bar.close)['volume'].append(bar.volume)
self.df = pd.DataFrame.from_dict(
def stop(self):
self.done = True
# create App object
app = App()
print('App created...')
app.connect(IP, PORT, 0)
print('App connected...')
# create contract
contract = Contract()
contract.symbol = 'IBDE30'
contract.secType = 'CFD' = 'SMART'
contract.currency = 'EUR'
print('Contract created...')
# request historical data for contract
durationStr='1 W',
barSizeSetting='15 mins',
Timer(4, app.stop).start()
I'd store the data to a dictionary, create a dataframe from the dictionary, and append the new dataframe to the main dataframe using the concat function. Here's an example:
def accountSummary(self, reqId:int, account:str, tag:str, value:str, currency:str):
    """Append one account-summary record to ``self.main_df``.

    The accumulated DataFrame is kept on the instance because assigning a
    bare ``main_df`` inside the method makes it a local name, which raises
    ``UnboundLocalError`` when read on the right-hand side.

    :param reqId: Request id of the summary request (not stored).
    :param account: Account identifier, stored in column ``account``.
    :param tag: Summary tag (not stored).
    :param value: Reported value, stored in column ``value``.
    :param currency: Currency of the value, stored in column ``currency``.
    """
    acct_dict = {"account": account, "value": value, "currency": currency}
    acct_df = pd.DataFrame([acct_dict], columns=list(acct_dict))
    previous = getattr(self, "main_df", None)
    if previous is None:
        # First record: start the accumulated frame.
        self.main_df = acct_df
    else:
        # ignore_index renumbers rows 0..n-1 without adding the extra
        # "index" column that a bare reset_index() inserts on every call.
        self.main_df = pd.concat([previous, acct_df], axis=0, ignore_index=True)
For more information, you might like Algorithmic Trading with Interactive Brokers

I found the same discussion in comments section of Create a custom Transformer in PySpark ML, but there is no clear answer. There is also an unresolved JIRA corresponding to that:
Given that there is no option provided by Pyspark ML pipeline for saving a custom transformer written in python, what are the other options to get it done? How can I implement the _to_java method in my python class that returns a compatible java object?
As of Spark 2.3.0 there's a much, much better way to do this.
Simply extend DefaultParamsWritable and DefaultParamsReadable and your class will automatically have write and read methods that will save your params and will be used by the PipelineModel serialization system.
The docs were not really clear, and I had to do a bit of source reading to understand that this was the way deserialization worked. PipelineModel.read instantiates a PipelineModelReader
PipelineModelReader loads metadata and checks if language is 'Python'. If it's not, then the typical JavaMLReader is used (what most of these answers are designed for)
Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance
loadParamsInstance will find the class from the saved metadata. It will instantiate that class and call .load(path) on it. You can extend DefaultParamsReadable and get a .load method automatically (it is backed by DefaultParamsReader.load). If you do have specialized deserialization logic you need to implement, I would look at that load method as a starting place.
On the opposite side:
PipelineModel.write will check if all stages are Java (implement JavaMLWritable). If so, the typical JavaMLWriter is used (what most of these answers are designed for)
Otherwise, PipelineWriter is used, which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl
PipelineSharedReadWrite.saveImpl will call .write().save(path) on each stage.
You can extend DefaultParamsWritable to get the DefaultParamsWritable.write method that saves metadata for your class and params in the right format. If you have custom serialization logic you need to implement, I would look at that and DefaultParamsWriter as a starting point.
Ok, so finally, you have a pretty simple transformer that extends Params and all your parameters are stored in the typical Params fashion:
from pyspark import keyword_only
from import Transformer
from import HasOutputCols, Param, Params
from import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit # for the dummy _transform
class SetValueTransformer(
Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
value = Param(
"value to fill",
def __init__(self, outputCols=None, value=0.0):
super(SetValueTransformer, self).__init__()
kwargs = self._input_kwargs
def setParams(self, outputCols=None, value=0.0):
setParams(self, outputCols=None, value=0.0)
Sets params for this SetValueTransformer.
kwargs = self._input_kwargs
return self._set(**kwargs)
def setValue(self, value):
Sets the value of :py:attr:`value`.
return self._set(value=value)
def getValue(self):
Gets the value of :py:attr:`value` or its default value.
return self.getOrDefault(self.value)
def _transform(self, dataset):
for col in self.getOutputCols():
dataset = dataset.withColumn(col, lit(self.getValue()))
return dataset
Now we can use it:
from import Pipeline, PipelineModel
svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)
p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm =
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
|key|value| a| b|
| 1| null|123.0|123.0|
| 2| 1.0|123.0|123.0|
| 3| 0.5|123.0|123.0|
matches? True
|key|value| a| b|
| 1| null|123.0|123.0|
| 2| 1.0|123.0|123.0|
| 3| 0.5|123.0|123.0|
I am not sure this is the best approach, but I too need the ability to save custom Estimators, Transformers and Models that I have created in Pyspark, and also to support their use in the Pipeline API with persistence. Custom Pyspark Estimators, Transformers and Models may be created and used in the Pipeline API but cannot be saved. This poses an issue in production when the model training takes longer than an event prediction cycle.
In general, Pyspark Estimators, Transformers and Models are just wrappers around the Java or Scala equivalents and the Pyspark wrappers just marshal the parameters to and from Java via py4j. Any persisting of the model is then done on the Java side. Because of this current structure, this limits Custom Pyspark Estimators, Transformers and Models to living only in the python world.
In a previous attempt, I was able to save a single Pyspark model by using Pickle/dill serialization. This worked well, but still did not allow saving or loading back such from within the Pipeline API. But, pointed to by another SO post I was directed to the OneVsRest classifier, and inspected the _to_java and _from_java methods. They do all the heavy lifting on the Pyspark side. After looking I thought, if one had a way to save the pickle dump to an already made and supported savable java object, then it should be possible to save a Custom Pyspark Estimator, Transformer and Model with the Pipeline API.
To that end, I found the StopWordsRemover to be the ideal object to hijack because it has an attribute, stopwords, that is a list of strings. The dill.dumps method returns a pickled representation of the object as a string. The plan was to turn the string into a list and then set the stopwords parameter of a StopWordsRemover to this list. Though it is a list of strings, I found that some of the characters would not marshal to the Java object. So the characters get converted to integers, then the integers to strings. This all works great for saving a single instance, and also when saving within a Pipeline, because the Pipeline dutifully calls the _to_java method of my Python class (we are still on the Pyspark side so this works). But coming back to Pyspark from Java did not work in the Pipeline API.
Because I am hiding my Python object in a StopWordsRemover instance, the Pipeline, when coming back to Pyspark, does not know anything about my hidden class object; it knows only that it has a StopWordsRemover instance. Ideally, it would be great to subclass Pipeline and PipelineModel, but alas this brings us back to trying to serialize a Python object. To combat this, I created a PysparkPipelineWrapper that takes a Pipeline or PipelineModel and just scans the stages, looking for a coded ID in the stopwords list (remember, this is just the pickled bytes of my Python object) that tells it to unwrap the list into my instance and store it back in the stage it came from. Below is code that shows how this all works.
For any Custom Pyspark Estimator, Transformer and Model, just inherit from Identifiable, PysparkReaderWriter, MLReadable, MLWritable. Then when loading a Pipeline and PipelineModel, pass such through PysparkPipelineWrapper.unwrap(pipeline).
This method does not address using the Pyspark code in Java or Scala, but at least we can save and load Custom Pyspark Estimators, Transformers and Models and work with Pipeline API.
import dill
from import Transformer, Pipeline, PipelineModel
from import Param, Params
from import Identifiable, MLReadable, MLWritable, JavaMLReader, JavaMLWriter
from import StopWordsRemover
from import JavaParams
from pyspark.context import SparkContext
from pyspark.sql import Row
class PysparkObjId(object):
A class to specify constants used to idenify and setup python
Estimators, Transformers and Models so they can be serialized on there
own and from within a Pipline or PipelineModel.
def __init__(self):
super(PysparkObjId, self).__init__()
def _getPyObjId():
return '4c1740b00d3c4ff6806a1402321572cb'
def _getCarrierClass(javaName=False):
return '' if javaName else StopWordsRemover
class PysparkPipelineWrapper(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """

    def __init__(self):
        super(PysparkPipelineWrapper, self).__init__()

    @staticmethod
    def unwrap(pipeline):
        """Replace carrier stages in ``pipeline`` with the objects they hide.

        Nested Pipeline / PipelineModel stages are unwrapped recursively. A
        stage is treated as a carrier when it is an instance of
        ``PysparkObjId._getCarrierClass()`` and the last entry of its stop
        words equals ``PysparkObjId._getPyObjId()``.

        :param pipeline: A Pipeline or PipelineModel.
        :return: The same ``pipeline`` object, with its stages replaced.
        :raises TypeError: If ``pipeline`` is neither a Pipeline nor a
            PipelineModel.
        """
        if not isinstance(pipeline, (Pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        is_pipeline = isinstance(pipeline, Pipeline)
        stages = pipeline.getStages() if is_pipeline else pipeline.stages
        carrier_class = PysparkObjId._getCarrierClass()
        marker = PysparkObjId._getPyObjId()

        for i, stage in enumerate(stages):
            if isinstance(stage, (Pipeline, PipelineModel)):
                stages[i] = PysparkPipelineWrapper.unwrap(stage)
            elif isinstance(stage, carrier_class):
                stop_words = stage.getStopWords()
                # An empty stop-word list cannot be a carrier; checking
                # emptiness first avoids an IndexError on [-1].
                if stop_words and stop_words[-1] == marker:
                    swords = stop_words[:-1]  # strip the id
                    # Each entry is the decimal value of one byte of the dump;
                    # bytes(bytearray(...)) yields a byte string on both
                    # Python 2 and Python 3.
                    dmp = bytes(bytearray(int(d) for d in swords))
                    # SECURITY: dill.loads can execute arbitrary code; only
                    # load pipelines saved from a trusted source.
                    stages[i] = dill.loads(dmp)

        if is_pipeline:
            # `stages` is a Param on Pipeline, so go through the setter
            # instead of overwriting the attribute. A PipelineModel's list
            # was already updated in place above.
            pipeline.setStages(stages)
        return pipeline
class PysparkReaderWriter(object):
A mixin class so custom pyspark Estimators, Transformers and Models may
support saving and loading directly or be saved within a Pipline or PipelineModel.
def __init__(self):
super(PysparkReaderWriter, self).__init__()
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
def read(cls):
"""Returns an MLReader instance for our clarrier class."""
return JavaMLReader(PysparkObjId._getCarrierClass())
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
swr_java_obj =
return cls._from_java(swr_java_obj)
def _from_java(cls, java_obj):
Get the dumby the stopwords that are the characters of the dills dump plus our guid
and convert, via dill, back to our python instance.
swords = java_obj.getStopWords()[:-1] # strip the id
lst = [chr(int(d)) for d in swords] # convert from string integer list to bytes
dmp = ''.join(lst)
py_obj = dill.loads(dmp)
return py_obj
def _to_java(self):
Convert this instance to a dill dump, then to a list of strings with the unicode integer values of each character.
Use this list as a set of dumby stopwords and store in a StopWordsRemover instance
:return: Java object equivalent to this instance.
dmp = dill.dumps(self)
pylist = [str(ord(d)) for d in dmp] # convert byes to string integer list
pylist.append(PysparkObjId._getPyObjId()) # add our id so PysparkPipelineWrapper can id us.
sc = SparkContext._active_spark_context
java_class =
java_array = sc._gateway.new_array(java_class, len(pylist))
for i in xrange(len(pylist)):
java_array[i] = pylist[i]
_java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
return _java_obj
class HasFake(Params):
    """Params mixin that declares a single param named ``fake``."""

    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        """Return the value of ``fake`` via ``getOrDefault``."""
        return self.getOrDefault(self.fake)
class MockTransformer(Transformer, HasFake, Identifiable):
    """Pass-through transformer that records the row count of its input."""

    def __init__(self):
        super(MockTransformer, self).__init__()
        # Updated on every _transform call.
        self.dataset_count = 0

    def _transform(self, dataset):
        """Store ``dataset.count()`` and return ``dataset`` unchanged."""
        self.dataset_count = dataset.count()
        return dataset
class MyTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    """MockTransformer combined with the PysparkReaderWriter mixin."""

    def __init__(self):
        super(MyTransformer, self).__init__()
def make_a_dataframe(sc):
    """Build a three-row DataFrame with columns name, age and height.

    :param sc: Object whose ``parallelize`` result supports ``toDF()``
        (presumably the active SparkContext — confirm against the caller).
    :return: The DataFrame produced by ``toDF()``.
    """
    rows = [
        Row(name='Alice', age=5, height=80),
        Row(name='Alice', age=5, height=80),
        Row(name='Alice', age=10, height=80),
    ]
    df = sc.parallelize(rows).toDF()
    return df
def test1():
trA = MyTransformer()
trA.dataset_count = 999
print trA.dataset_count'test.trans')
trB = MyTransformer.load('test.trans')
print trB.dataset_count
def test2():
trA = MyTransformer()
pipeA = Pipeline(stages=[trA])
print type(pipeA)'testA.pipe')
pipeAA = PysparkPipelineWrapper.unwrap(Pipeline.load('testA.pipe'))
stagesAA = pipeAA.getStages()
trAA = stagesAA[0]
print trAA.dataset_count
def test3():
dfA = make_a_dataframe(sc)
trA = MyTransformer()
pipeA = Pipeline(stages=[trA]).fit(dfA)
print type(pipeA)'testB.pipe')
pipeAA = PysparkPipelineWrapper.unwrap(PipelineModel.load('testB.pipe'))
stagesAA = pipeAA.stages
trAA = stagesAA[0]
print trAA.dataset_count
dfB = pipeAA.transform(dfA)
I couldn't get #dmbaker's ingenious solution to work using Python 2 on Spark 2.2.0; I kept getting pickling errors. After several blind alleys I got a working solution by modifying his (her?) idea to write and read the parameter values as strings into StopWordsRemover's stop words directly.
Here's the base class you need if you want to save and load your own estimators or transformers:
from pyspark import SparkContext
from import StopWordsRemover
from import Identifiable, MLWritable, JavaMLWriter, MLReadable, JavaMLReader
from import JavaWrapper, JavaParams
class PysparkReaderWriter(Identifiable, MLReadable, MLWritable):
A base class for custom pyspark Estimators and Models to support saving and loading directly
or within a Pipeline or PipelineModel.
def __init__(self):
super(PysparkReaderWriter, self).__init__()
def _getPyObjIdPrefix():
return "_ThisIsReallyA_"
def _getPyObjId(cls):
return PysparkReaderWriter._getPyObjIdPrefix() + cls.__name__
def getParamsAsListOfStrings(self):
raise NotImplementedError("PysparkReaderWriter.getParamsAsListOfStrings() not implemented for instance: %r" % self)
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
def _to_java(self):
# Convert all our parameters to strings:
paramValuesAsStrings = self.getParamsAsListOfStrings()
# Append our own type-specific id so PysparkPipelineLoader can detect this algorithm when unwrapping us.
# Convert the parameter values to a Java array:
sc = SparkContext._active_spark_context
java_array = JavaWrapper._new_java_array(paramValuesAsStrings,
# Create a Java (Scala) StopWordsRemover and give it the parameters as its stop words.
_java_obj = JavaParams._new_java_obj("", self.uid)
return _java_obj
def _from_java(cls, java_obj):
# Get the stop words, ignoring the id at the end:
stopWords = java_obj.getStopWords()[:-1]
return cls.createAndInitialisePyObj(stopWords)
def createAndInitialisePyObj(cls, paramsAsListOfStrings):
raise NotImplementedError("PysparkReaderWriter.createAndInitialisePyObj() not implemented for type: %r" % cls)
def read(cls):
"""Returns an MLReader instance for our clarrier class."""
return JavaMLReader(StopWordsRemover)
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
swr_java_obj =
return cls._from_java(swr_java_obj)
Your own pyspark algorithm must then inherit from PysparkReaderWriter and override the getParamsAsListOfStrings() method which saves your parameters to a list of strings. Your algorithm must also override the createAndInitialisePyObj() method for converting a list of strings back into your parameters. Behind the scenes the parameters are converted to and from the stop words used by StopWordsRemover.
Example estimator with 3 parameters of different type:
from import Param, Params, TypeConverters
from import Estimator
class MyEstimator(Estimator, PysparkReaderWriter):
    """Example estimator with three parameters of different types.

    The parameters are saved as a list of strings by
    ``getParamsAsListOfStrings`` and restored, in the same order, by
    ``createAndInitialisePyObj``.
    """

    # 3 sample parameters, deliberately of different types:
    stringParam = Param(Params._dummy(), "stringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)
    listOfStringsParam = Param(Params._dummy(), "listOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)
    intParam = Param(Params._dummy(), "intParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)

    def __init__(self):
        super(MyEstimator, self).__init__()

    def setStringParam(self, value):
        """Set ``stringParam`` and return ``self``."""
        return self._set(stringParam=value)

    def getStringParam(self):
        """Return the current (or default) value of ``stringParam``."""
        return self.getOrDefault(self.stringParam)

    def setListOfStringsParam(self, value):
        """Set ``listOfStringsParam`` and return ``self``."""
        return self._set(listOfStringsParam=value)

    def getListOfStringsParam(self):
        """Return the current (or default) value of ``listOfStringsParam``."""
        return self.getOrDefault(self.listOfStringsParam)

    def setIntParam(self, value):
        """Set ``intParam`` and return ``self``."""
        return self._set(intParam=value)

    def getIntParam(self):
        """Return the current (or default) value of ``intParam``."""
        return self.getOrDefault(self.intParam)

    def _fit(self, dataset):
        """Return a ``MyModel`` whose parameters are derived from this estimator's."""
        model = MyModel()
        # Just some changes to verify we can modify the model (and also it's
        # something we can expect to see when restoring it later):
        model.setAnotherStringParam(self.getStringParam() + " World!")
        model.setAnotherListOfStringsParam(self.getListOfStringsParam() + ["E", "F"])
        model.setAnotherIntParam(self.getIntParam() + 10)
        return model

    def getParamsAsListOfStrings(self):
        """Return the three parameter values encoded as strings, in a fixed order."""
        return [
            self.getStringParam(),  # Parameter is already a string
            ','.join(self.getListOfStringsParam()),  # ...convert from a list of strings
            str(self.getIntParam()),  # ...convert from an int
        ]

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        """Create an instance and restore its parameters from their string forms.

        :param paramsAsListOfStrings: Values in the order produced by
            ``getParamsAsListOfStrings``: string, comma-joined list, int.
        :return: A new, initialised instance of ``cls``.
        """
        # Convert back into our parameters. Make sure you do this in the same
        # order you saved them!
        py_obj = cls()
        py_obj.setStringParam(paramsAsListOfStrings[0])
        joined = paramsAsListOfStrings[1]
        # ''.split(',') would yield [''], so an empty list needs its own case.
        # Items that themselves contain a comma cannot survive this encoding.
        py_obj.setListOfStringsParam(joined.split(",") if joined else [])
        py_obj.setIntParam(int(paramsAsListOfStrings[2]))
        return py_obj
Example Model (also a Transformer) which has 3 different parameters:
from import Model
class MyModel(Model, PysparkReaderWriter):
    """Example model (also a transformer) with three parameters of different types.

    The parameters are saved as a list of strings by
    ``getParamsAsListOfStrings`` and restored, in the same order, by
    ``createAndInitialisePyObj``.
    """

    # 3 sample parameters, deliberately of different types:
    anotherStringParam = Param(Params._dummy(), "anotherStringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)
    anotherListOfStringsParam = Param(Params._dummy(), "anotherListOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)
    anotherIntParam = Param(Params._dummy(), "anotherIntParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)

    def __init__(self):
        super(MyModel, self).__init__()

    def setAnotherStringParam(self, value):
        """Set ``anotherStringParam`` and return ``self``."""
        return self._set(anotherStringParam=value)

    def getAnotherStringParam(self):
        """Return the current (or default) value of ``anotherStringParam``."""
        return self.getOrDefault(self.anotherStringParam)

    def setAnotherListOfStringsParam(self, value):
        """Set ``anotherListOfStringsParam`` and return ``self``."""
        return self._set(anotherListOfStringsParam=value)

    def getAnotherListOfStringsParam(self):
        """Return the current (or default) value of ``anotherListOfStringsParam``."""
        return self.getOrDefault(self.anotherListOfStringsParam)

    def setAnotherIntParam(self, value):
        """Set ``anotherIntParam`` and return ``self``."""
        return self._set(anotherIntParam=value)

    def getAnotherIntParam(self):
        """Return the current (or default) value of ``anotherIntParam``."""
        return self.getOrDefault(self.anotherIntParam)

    def _transform(self, dataset):
        """Add an ``age2`` column equal to ``age`` plus ``anotherIntParam``."""
        # Dummy transform code:
        return dataset.withColumn('age2', dataset.age + self.getAnotherIntParam())

    def getParamsAsListOfStrings(self):
        """Return the three parameter values encoded as strings, in a fixed order."""
        return [
            self.getAnotherStringParam(),  # Parameter is already a string
            ','.join(self.getAnotherListOfStringsParam()),  # ...convert from a list of strings
            str(self.getAnotherIntParam()),  # ...convert from an int
        ]

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        """Create an instance and restore its parameters from their string forms.

        :param paramsAsListOfStrings: Values in the order produced by
            ``getParamsAsListOfStrings``: string, comma-joined list, int.
        :return: A new, initialised instance of ``cls``.
        """
        # Convert back into our parameters. Make sure you do this in the same
        # order you saved them!
        py_obj = cls()
        py_obj.setAnotherStringParam(paramsAsListOfStrings[0])
        joined = paramsAsListOfStrings[1]
        # ''.split(',') would yield [''], so an empty list needs its own case.
        # Items that themselves contain a comma cannot survive this encoding.
        py_obj.setAnotherListOfStringsParam(joined.split(",") if joined else [])
        py_obj.setAnotherIntParam(int(paramsAsListOfStrings[2]))
        return py_obj
Below is a sample test case showing how you can save and load your model. It's similar for the estimator so I omit that for brevity.
def createAModel():
    """Build a ``MyModel`` with all three parameters set to sample values.

    The values match the expected parameter dump shown after the test
    (``Boo!``, ``['P', 'Q', 'R']`` and ``77``).
    """
    m = MyModel()
    m.setAnotherStringParam("Boo!")
    m.setAnotherListOfStringsParam(["P", "Q", "R"])
    m.setAnotherIntParam(77)
    return m
def testSaveLoadModel():
    """Save a model, load it back and print both parameter sets for comparison."""
    modA = createAModel()
    print(modA.explainParams())

    savePath = "/whatever/path/you/want"  # Can't overwrite, so use a path that does not exist yet.
    # The model has to be written before it can be loaded from savePath.
    modA.write().save(savePath)

    modB = MyModel.load(savePath)
    print(modB.explainParams())
anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: ['P', 'Q', 'R'])
anotherStringParam: A dummy string parameter (current: Boo!)
anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: [u'P', u'Q', u'R'])
anotherStringParam: A dummy string parameter (current: Boo!)
Notice how the parameters have come back in as unicode strings. This may or may not make a difference to your underlying algorithm that you implement in _transform() (or _fit() for the estimator). So be aware of this.
Finally, because the Scala algorithm behind the scenes is really a StopWordsRemover, you need to unwrap it back into your own class when loading the Pipeline or PipelineModel from disk. Here's the utility class that does this unwrapping:
from import Pipeline, PipelineModel
from import StopWordsRemover
class PysparkPipelineLoader(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """

    def __init__(self):
        super(PysparkPipelineLoader, self).__init__()

    @staticmethod
    def unwrap(thingToUnwrap, customClassList):
        """Replace carrier StopWordsRemover stages with the custom objects they encode.

        :param thingToUnwrap: A loaded ``Pipeline`` or ``PipelineModel``.
        :param customClassList: Classes that may be hidden inside the saved
            stages; each is matched on ``__name__`` and must provide
            ``createAndInitialisePyObj``.
        :return: ``thingToUnwrap``, with its stages replaced in place.
        :raises TypeError: If ``thingToUnwrap`` is neither a Pipeline nor a
            PipelineModel, or a stage names a class not in ``customClassList``.
        """
        if not isinstance(thingToUnwrap, (Pipeline, PipelineModel)):
            raise TypeError("Cannot recognize an object of type %s." % type(thingToUnwrap))

        is_pipeline = isinstance(thingToUnwrap, Pipeline)
        stages = thingToUnwrap.getStages() if is_pipeline else thingToUnwrap.stages
        id_prefix = PysparkReaderWriter._getPyObjIdPrefix()
        classes_by_name = {clazz.__name__: clazz for clazz in customClassList}

        for i, stage in enumerate(stages):
            if isinstance(stage, (Pipeline, PipelineModel)):
                # Nested pipelines need the same class list to unwrap their stages.
                stages[i] = PysparkPipelineLoader.unwrap(stage, customClassList)
                continue
            if not isinstance(stage, StopWordsRemover):
                continue

            words = stage.getStopWords()
            # A genuine StopWordsRemover may have no stop words at all, or no id marker.
            if not words or not words[-1].startswith(id_prefix):
                continue

            # The last word is the id prefix followed by the class name.
            className = words[-1][len(id_prefix):]
            clazz = classes_by_name.get(className)
            if clazz is None:
                raise TypeError("I don't know how to create an instance of type: %s" % className)
            # Create and initialise the appropriate class from the words minus the id.
            stages[i] = clazz.createAndInitialisePyObj(words[:-1])

        if is_pipeline:
            # On a Pipeline `stages` is a Param, so go through the setter.
            thingToUnwrap.setStages(stages)
        else:
            # PipelineModel
            thingToUnwrap.stages = stages
        return thingToUnwrap
Test for saving and loading a pipeline:
def testSaveAndLoadUnfittedPipeline():
    """Save an unfitted pipeline, reload and unwrap it, then print the estimator's params."""
    estA = createAnEstimator()
    pipelineA = Pipeline(stages=[estA])

    savePath = "/whatever/path/you/want"  # Can't overwrite, so use a path that does not exist yet.
    # The pipeline has to be written before it can be loaded from savePath.
    pipelineA.write().save(savePath)

    pipelineReloaded = PysparkPipelineLoader.unwrap(Pipeline.load(savePath), [MyEstimator])
    estB = pipelineReloaded.getStages()[0]
    print(estB.explainParams())
intParam: A dummy int parameter. (current: 42)
listOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D'])
stringParam: A dummy string parameter (current: Hello)
Test for saving and loading a pipeline model:
from pyspark.sql import Row
def make_a_dataframe(sc):
    """Return a three-row DataFrame with name, age and height columns, built via ``sc``."""
    people = [
        Row(name='Alice', age=5, height=80),
        Row(name='Bob', age=7, height=85),
        Row(name='Chris', age=10, height=90),
    ]
    return sc.parallelize(people).toDF()
def testSaveAndLoadPipelineModel():
    """Fit and save a pipeline model, reload and unwrap it, then show params and output."""
    dfA = make_a_dataframe(sc)
    estA = createAnEstimator()
    pipelineModelA = Pipeline(stages=[estA]).fit(dfA)

    savePath = "/whatever/path/you/want"  # Can't overwrite, so use a path that does not exist yet.
    # The fitted pipeline has to be written before it can be loaded from savePath.
    pipelineModelA.write().save(savePath)

    pipelineModelReloaded = PysparkPipelineLoader.unwrap(PipelineModel.load(savePath), [MyModel])
    modB = pipelineModelReloaded.stages[0]
    print(modB.explainParams())

    dfB = pipelineModelReloaded.transform(dfA)
    dfB.show()
anotherIntParam: A dummy int parameter. (current: 52)
anotherListOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D', u'E', u'F'])
anotherStringParam: A dummy string parameter (current: Hello World!)
|age|height| name|age2|
| 5| 80|Alice| 57|
| 7| 85| Bob| 59|
| 10| 90|Chris| 62|
When unwrapping a pipeline or pipeline model you have to pass in a list of the classes that correspond to your own pyspark algorithms that are masquerading as StopWordsRemover objects in the saved pipeline or pipeline model. The last stop word in your saved object is used to identify your own class's name and then createAndInitialisePyObj() is called to create an instance of your class and initialise its parameters with the remaining stop words.
Various refinements could be made. But hopefully this will enable you to save and load custom estimators and transformers, both inside and outside pipelines, until SPARK-17025 is resolved and available to you.
Similar to the working answer by @dmbaker, I wrapped my custom transformer, called Aggregator, inside a built-in Spark transformer (in this example, Binarizer), though I'm sure you can inherit from other transformers, too. That allowed my custom transformer to inherit the methods necessary for serialization.
from import Pipeline
from import VectorAssembler, Binarizer
from import LinearRegression
# Custom transformer that subclasses the built-in Binarizer.
class Aggregator(Binarizer):
"""A huge hack to allow serialization of custom transformer."""
# Overrides transform() to return a frame whose avg(foo)/avg(bar) columns
# are renamed to avg_foo/avg_bar.
# NOTE(review): the chained call that should sit between `input_df\` and the
# 'foo'/'bar' entries is missing from this snippet, so it does not parse as
# written. Presumably a grouping followed by an aggregation — TODO restore
# from the original.
def transform(self, input_df):
agg_df = input_df\
'foo': 'avg',
'bar': 'avg',
.withColumnRenamed('avg(foo)', 'avg_foo')\
.withColumnRenamed('avg(bar)', 'avg_bar')
return agg_df
# Create pipeline stages.
aggregator = Aggregator()
# The `...` stands in for the assembler's arguments, which are not shown.
vector_assembler = VectorAssembler(...)
linear_regression = LinearRegression()
# Create pipeline.
pipeline = Pipeline(stages=[aggregator, vector_assembler, linear_regression])
# Train.
# NOTE(review): the right-hand side of this assignment is missing; presumably
# the pipeline is fitted on a training DataFrame — TODO restore.
pipeline_model =
# Save model file to S3.
# NOTE(review): the save call itself is missing; only its argument
# 's3n://example' survived — TODO restore.
The @dmbaker solution didn't work for me. I believe that is because of the Python version (2.x versus 3.x). I made some updates to his solution and now it works on Python 3. My setup is listed below:
python: 3.6.3
spark: 2.2.1
class PysparkObjId(object):
    """
    A class to specify constants used to identify and set up Python
    Estimators, Transformers and Models so they can be serialized on their
    own and from within a Pipeline or PipelineModel.
    """

    def __init__(self):
        super(PysparkObjId, self).__init__()

    @staticmethod
    def _getPyObjId():
        """Return the marker string that identifies a carrier stage."""
        return '4c1740b00d3c4ff6806a1402321572cb'

    @staticmethod
    def _getCarrierClass(javaName=False):
        """Return the carrier class.

        :param javaName: When true, return the fully qualified Java class
            name as a string; otherwise return the Python class itself.
        """
        # An empty class name cannot be instantiated on the JVM side, so the
        # fully qualified name of the carrier is spelled out here.
        return 'org.apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover
class PysparkPipelineWrapper(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """

    def __init__(self):
        super(PysparkPipelineWrapper, self).__init__()

    @staticmethod
    def unwrap(pipeline):
        """Replace carrier stages with the Python objects serialized inside them.

        :param pipeline: A loaded ``Pipeline`` or ``PipelineModel``.
        :return: ``pipeline``, with its stages replaced in place.
        :raises TypeError: If ``pipeline`` is neither a Pipeline nor a PipelineModel.
        """
        if not isinstance(pipeline, (Pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        is_pipeline = isinstance(pipeline, Pipeline)
        stages = pipeline.getStages() if is_pipeline else pipeline.stages
        carrier_class = PysparkObjId._getCarrierClass()
        marker = PysparkObjId._getPyObjId()

        for i, stage in enumerate(stages):
            if isinstance(stage, (Pipeline, PipelineModel)):
                stages[i] = PysparkPipelineWrapper.unwrap(stage)
                continue
            if not isinstance(stage, carrier_class):
                continue

            words = stage.getStopWords()
            # A genuine carrier-class stage may have no stop words, or no marker.
            if not words or words[-1] != marker:
                continue

            # Every remaining word is the decimal value of one byte of the dump.
            dmp = bytes(int(d) for d in words[:-1])
            # NOTE(security): dill.loads can execute arbitrary code embedded in
            # the payload; only unwrap pipelines loaded from trusted storage.
            stages[i] = dill.loads(dmp)

        if is_pipeline:
            # On a Pipeline `stages` is a Param, so go through the setter.
            pipeline.setStages(stages)
        else:
            # PipelineModel
            pipeline.stages = stages
        return pipeline
class PysparkReaderWriter(object):
    """
    A mixin class so custom pyspark Estimators, Transformers and Models may
    support saving and loading directly or be saved within a Pipeline or PipelineModel.
    """

    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our carrier class."""
        return JavaMLReader(PysparkObjId._getCarrierClass())

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)

    @classmethod
    def _from_java(cls, java_obj):
        """
        Get the dummy stop words that are the byte values of the dill dump plus our guid
        and convert, via dill, back to our python instance.
        """
        swords = java_obj.getStopWords()[:-1]  # strip the id
        # The words are decimal strings, so each must be parsed to an int
        # before the byte string can be rebuilt.
        dmp = bytes(int(d) for d in swords)
        # NOTE(security): dill.loads can execute arbitrary code embedded in the
        # payload; only load objects from trusted storage.
        py_obj = dill.loads(dmp)
        return py_obj

    def _to_java(self):
        """
        Convert this instance to a dill dump, then to a list of strings with the integer value of each byte.
        Use this list as a set of dummy stop words and store it in a StopWordsRemover instance.
        :return: Java object equivalent to this instance.
        """
        dmp = dill.dumps(self)
        pylist = [str(int(d)) for d in dmp]  # convert bytes to string integer list
        pylist.append(PysparkObjId._getPyObjId())  # add our id so PysparkPipelineWrapper can id us.

        sc = SparkContext._active_spark_context
        # The list holds Python strings, so the Java array element type is String.
        java_class = sc._gateway.jvm.java.lang.String
        java_array = sc._gateway.new_array(java_class, len(pylist))
        for i, word in enumerate(pylist):
            java_array[i] = word

        _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
        # Without this the carrier would be saved with no payload.
        _java_obj.setStopWords(java_array)
        return _java_obj
class HasFake(Params):
    """Params mixin that declares a single ``fake`` parameter."""

    def __init__(self):
        super(HasFake, self).__init__()
        fake_param = Param(self, "fake", "fake param")
        self.fake = fake_param

    def getFake(self):
        """Return the current (or default) value of ``fake``."""
        value = self.getOrDefault(self.fake)
        return value
# Custom transformer combining the input/output column mixins with
# PysparkReaderWriter and the MLReadable/MLWritable interfaces.
class CleanText(Transformer, HasInputCol, HasOutputCol, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
def __init__(self, inputCol=None, outputCol=None):
super(CleanText, self).__init__()
# NOTE(review): `_input_kwargs` is read but never used in the visible lines,
# and nothing visible here populates it. The rest of the constructor (and
# any transform method) appears to be missing from this snippet — TODO
# restore from the original.
kwargs = self._input_kwargs
I wrote some base classes to make this easier. Basically, I abstract all the complexity of the code and initialisation into base classes that expose a much simpler API for building custom ones. This includes taking care of the serialisation/deserialisation problem and of saving and loading SparkML objects. You can then concentrate on the __init__ and transform/fit functions. You can find a full explanation with examples here.
