I have a problem using TensorFlow Serving.
I deployed my TensorFlow model as a RESTful API using TensorFlow Serving, but I'm not sure whether the tf-serving server supports multi-threading. I've done some experiments and it does not seem to.
I also noticed that there is a --tensorflow_session_parallelism option for tensorflow_model_server, but using the option makes my server slower.
Is there any reference for using TensorFlow Serving with multi-threading?
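For context, the kind of experiment described above can be reproduced with a handful of client threads hitting the REST endpoint. This is only a sketch; the port, model name and payload shape are assumptions for a default tensorflow_model_server setup:
# Sketch: time N concurrent REST predict calls to see whether they overlap.
import json
import threading
import time
import requests

URL = 'http://localhost:8501/v1/models/my_model:predict'   # hypothetical model name
PAYLOAD = json.dumps({'instances': [[0.0] * 784]})          # hypothetical input shape

def one_request():
    requests.post(URL, data=PAYLOAD)

threads = [threading.Thread(target=one_request) for _ in range(16)]
start = time.time()
for t in threads:
    t.start()
for t in threads:
    t.join()
print('16 concurrent requests took %.3f s' % (time.time() - start))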
Elaborating on the content of the link provided by @ReInvent_IO, in case the link doesn't work in the future.
Code for the same is shown below:
"""A client that talks to tensorflow_model_server loaded with mnist model.
The client downloads test images of mnist data set, queries the service with
such test images to get predictions, and calculates the inference error rate.
Typical usage example:
mnist_client.py --num_tests=100 --server=localhost:9000
"""
from __future__ import print_function
import sys
import threading
# This is a placeholder for a Google-internal import.
import grpc
import numpy
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import mnist_input_data
Setting the value of concurrency to 5 asks the client to issue up to 5 concurrent inference requests:
tf.app.flags.DEFINE_integer('concurrency', 5,
'maximum number of concurrent inference requests')
tf.app.flags.DEFINE_integer('num_tests', 100, 'Number of test images')
tf.app.flags.DEFINE_string('server', '', 'PredictionService host:port')
tf.app.flags.DEFINE_string('work_dir', '/tmp', 'Working directory. ')
FLAGS = tf.app.flags.FLAGS
class _ResultCounter(object):
"""Counter for the prediction results."""
def __init__(self, num_tests, concurrency):
self._num_tests = num_tests
self._concurrency = concurrency
self._error = 0
self._done = 0
self._active = 0
self._condition = threading.Condition()
def inc_error(self):
with self._condition:
self._error += 1
def inc_done(self):
with self._condition:
self._done += 1
self._condition.notify()
def dec_active(self):
with self._condition:
self._active -= 1
self._condition.notify()
def get_error_rate(self):
with self._condition:
while self._done != self._num_tests:
self._condition.wait()
return self._error / float(self._num_tests)
def throttle(self):
with self._condition:
while self._active == self._concurrency:
self._condition.wait()
self._active += 1
def _create_rpc_callback(label, result_counter):
"""Creates RPC callback function.
Args:
label: The correct label for the predicted example.
result_counter: Counter for the prediction result.
Returns:
The callback function.
"""
def _callback(result_future):
"""Callback function.
Calculates the statistics for the prediction result.
Args:
result_future: Result future of the RPC.
"""
exception = result_future.exception()
if exception:
result_counter.inc_error()
print(exception)
else:
sys.stdout.write('.')
sys.stdout.flush()
response = numpy.array(
result_future.result().outputs['scores'].float_val)
prediction = numpy.argmax(response)
if label != prediction:
result_counter.inc_error()
result_counter.inc_done()
result_counter.dec_active()
return _callback
def do_inference(hostport, work_dir, concurrency, num_tests):
"""Tests PredictionService with concurrent requests.
Args:
hostport: Host:port address of the PredictionService.
work_dir: The full path of working directory for test data set.
concurrency: Maximum number of concurrent requests.
num_tests: Number of test images to use.
Returns:
The classification error rate.
Raises:
IOError: An error occurred processing test data set.
"""
test_data_set = mnist_input_data.read_data_sets(work_dir).test
channel = grpc.insecure_channel(hostport)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
result_counter = _ResultCounter(num_tests, concurrency)
for _ in range(num_tests):
request = predict_pb2.PredictRequest()
request.model_spec.name = 'mnist'
request.model_spec.signature_name = 'predict_images'
image, label = test_data_set.next_batch(1)
request.inputs['images'].CopyFrom(
tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))
result_counter.throttle()
result_future = stub.Predict.future(request, 5.0) # 5 seconds
result_future.add_done_callback(
_create_rpc_callback(label[0], result_counter))
return result_counter.get_error_rate()
def main(_):
if FLAGS.num_tests > 10000:
print('num_tests should not be greater than 10k')
return
if not FLAGS.server:
print('please specify server host:port')
return
error_rate = do_inference(FLAGS.server, FLAGS.work_dir,
FLAGS.concurrency, FLAGS.num_tests)
print('\nInference error rate: %s%%' % (error_rate * 100))
if __name__ == '__main__':
tf.app.run()
To quickly sum up the problem: I need to transfer images (size (1920, 1200, 3)) between PyTorch Docker containers and process them. The containers are located on the same system. Speed is very important and a transfer should not take more than 2-3 ms one way. The two containers communicate via IPC, so I have no problem transferring NumPy arrays through shared memory buffers (example: https://docs.python.org/3/library/multiprocessing.shared_memory.html). I am curious whether there is a similar way to do that with PyTorch tensors allocated on the GPU.
From what I've learned, CUDA tensors are already in shared memory. I tried transferring them and PyTorch Tensor Storage objects via a socket, but it takes around 50-60 ms one way, which is way too slow. For testing purposes, I just run the two programs in separate terminals.
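For reference, the CPU-side baseline mentioned above (a NumPy array shared between processes via multiprocessing.shared_memory) looks roughly like this; the shape and dtype are only illustrative:
# Sketch: producer side of the NumPy shared-memory baseline.
import numpy as np
from multiprocessing import shared_memory

frame = np.zeros((1920, 1200, 3), dtype=np.uint8)
shm = shared_memory.SharedMemory(create=True, size=frame.nbytes)
shared = np.ndarray(frame.shape, dtype=frame.dtype, buffer=shm.buf)
shared[:] = frame  # write the image into the shared buffer; the consumer reads it without a copy
# The consumer attaches with shared_memory.SharedMemory(name=shm.name)
# and wraps the same buffer in another np.ndarray.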
Container 1 code:
import torch
import zmq
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect('tcp://0.0.0.0:6000')
x = torch.randn((1, 1920, 1200, 3), device='cuda')
storage = x.storage()
while True:
sock.send_pyobj(storage)
sock.recv()
if __name__ == "__main__":
main()
Container 2 code:
import torch
import zmq
import time
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.REP)
sock.bind('tcp://*:6000')
for i in range(10):
before = time.time()
storage = sock.recv_pyobj()
tensor = torch.tensor((), device=storage.device)
tensor.set_(storage)
after = time.time()
print(after - before)
sock.send_string('')
if __name__ == "__main__":
main()
Edit:
I found a similar topic discussed 4 years ago. There, the person extracts additional information from the storage using the _share_cuda_() function, which gives a cudaIpcMemHandle_t.
Is there a way to reconstruct a Storage/Tensor from the cudaIpcMemHandle_t, or from the information extracted by _share_cuda_(), using PyTorch functions? Or is there a better way to achieve the same result?
I found a function in torch.multiprocessing.reductions that rebuilds tensors from the output generated by _share_cuda_(). Now my code looks something like this:
Container 1 code:
import torch
import zmq
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect('tcp://0.0.0.0:6000')
image = torch.randn((1, 1920, 1200, 3), dtype=torch.float, device='cuda:0')
storage = image.storage()
(storage_device, storage_handle, storage_size_bytes, storage_offset_bytes,
ref_counter_handle, ref_counter_offset, event_handle, event_sync_required) = storage._share_cuda_()
while True:
sock.send_pyobj({
"dtype": image.dtype,
"tensor_size": (1920, 1200, 3),
"tensor_stride": image.stride(),
"tensor_offset": image.storage_offset(), # !Not sure about this one.
"storage_cls": type(storage),
"storage_device": storage_device,
"storage_handle": storage_handle,
"storage_size_bytes": storage_size_bytes,
"storage_offset_bytes":storage_offset_bytes,
"requires_grad": False,
"ref_counter_handle": ref_counter_handle,
"ref_counter_offset": ref_counter_offset,
"event_handle": event_handle,
"event_sync_required": event_sync_required,
})
sock.recv_string()
if __name__ == "__main__":
main()
Container 2 code:
import torch
import zmq
import time
from torch.multiprocessing.reductions import rebuild_cuda_tensor
def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.REP)
sock.bind('tcp://*:6000')
for i in range(10):
before = time.time()
cuda_tensor_info = sock.recv_pyobj()
rebuilt_tensor = rebuild_cuda_tensor(torch.Tensor, **cuda_tensor_info)
after = time.time()
print(after - before)
sock.send_string('')
if __name__ == "__main__":
main()
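One caveat worth adding: the rebuilt tensor aliases the same CUDA allocation as the tensor in container 1, so if container 2 needs to modify the data independently it should take its own copy first. A sketch, not something the code above requires:
# Sketch: detach the consumer from the producer's CUDA memory before mutating it.
local_tensor = rebuilt_tensor.clone()
torch.cuda.synchronize()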
I am running T5-base-grammar-correction for grammar correction on a DataFrame with a text column:
from happytransformer import HappyTextToText
from happytransformer import TTSettings
from tqdm.notebook import tqdm
tqdm.pandas()
happy_tt = HappyTextToText("T5", "./t5-base-grammar-correction")
beam_settings = TTSettings(num_beams=5, min_length=1, max_length=30)
def grammer_pipeline(text):
text = "gec: " + text
result = happy_tt.generate_text(text, args=beam_settings)
return result.text
df['new_text'] = df['original_text'].progress_apply(grammer_pipeline)
The pandas apply function runs and produces the required results, but it is quite slow.
I also get the warning below while executing the code:
/home/.local/lib/python3.6/site-packages/transformers/pipelines/base.py:908: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
UserWarning,
I have access to a GPU. Can somebody provide some pointers to speed up the execution and utilise the full capabilities of the GPU?
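One direction the warning itself points at is batching. Here is a minimal sketch using the transformers text2text pipeline directly, assuming the local checkpoint loads as a standard seq2seq model (the batch size is arbitrary):
# Sketch: let the pipeline batch the inputs on the GPU instead of one call per row.
from transformers import pipeline

corrector = pipeline(
    "text2text-generation",
    model="./t5-base-grammar-correction",
    device=0,  # first GPU
)

texts = ["gec: " + t for t in df["original_text"].tolist()]
outputs = corrector(texts, batch_size=32, num_beams=5, min_length=1, max_length=30)
df["new_text"] = [o["generated_text"] for o in outputs]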
--------------------------------EDIT---------------------------------
I tried using a PyTorch Dataset in the way shown below, but the processing is still slow:
class CustomD(Dataset):
def __init__(self, text):
self.text = text
self.len = text.shape[0]
def __len__(self):
return self.len
def __getitem__(self, idx):
text = self.text[idx]
text = "gec: " + text
result = happy_tt.generate_text(text, args=beam_settings)
return result.text
TD = CustomD(df.original_text)
final_data = DataLoader(dataset=TD,
batch_size=10,
shuffle=False
)
import itertools
list_modified=[]
for (idx, batch) in enumerate(final_data):
list_modified.append(batch)
flat_list = [item for sublist in list_modified for item in sublist]
df["new_text"]=flat_list
I want to publish messages to a Pub/Sub topic with some attributes using a Dataflow job in batch mode.
My Dataflow pipeline is written with Python 3.8 and apache-beam 2.27.0.
It works with the @Ankur solution here: https://stackoverflow.com/a/55824287/9455637
But I think it could be more efficient with a shared Pub/Sub client: https://stackoverflow.com/a/55833997/9455637
However, an error occurred:
return StockUnpickler.find_class(self, module, name) AttributeError:
Can't get attribute 'PublishFn' on <module 'dataflow_worker.start'
from
'/usr/local/lib/python3.8/site-packages/dataflow_worker/start.py'>
Questions:
Would the shared publisher implementation improve beam pipeline performance?
Is there another way to avoid the pickling error on my shared publisher client?
My Dataflow pipeline:
import apache_beam as beam
from apache_beam.io.gcp import bigquery
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from google.cloud.pubsub_v1 import PublisherClient
import json
import argparse
import re
import logging
class PubsubClient(PublisherClient):
def __reduce__(self):
return self.__class__, (self.batch_settings,)
# The DoFn to perform on each element in the input PCollection.
class PublishFn(beam.DoFn):
def __init__(self):
from google.cloud import pubsub_v1
batch_settings = pubsub_v1.types.BatchSettings(
max_bytes=1024, # One kilobyte
max_latency=1, # One second
)
self.publisher = PubsubClient(batch_settings)
super().__init__()
def process(self, element, **kwargs):
future = self.publisher.publish(
topic=element["topic"],
data=json.dumps(element["data"]).encode("utf-8"),
**element["attributes"],
)
return future.result()
def run(argv=None, save_main_session=True):
"""Main entry point; defines and runs the pipeline."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--source_table_id",
dest="source_table_id",
default="",
help="BigQuery source table <project>.<dataset>.<table> with columns (topic, attributes, data)",
)
known_args, pipeline_args = parser.parse_known_args(argv)
# We use the save_main_session option because one or more DoFn's in this
# workflow rely on global context (e.g., a module imported at module level).
pipeline_options = PipelineOptions(pipeline_args)
# pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
bq_source_table = known_args.source_table_id
bq_table_regex = r"^(?P<PROJECT_ID>[a-zA-Z0-9_-]*)[\.|\:](?P<DATASET_ID>[a-zA-Z0-9_]*)\.(?P<TABLE_ID>[a-zA-Z0-9_-]*)$"
regex_match = re.search(bq_table_regex, bq_source_table)
if not regex_match:
raise ValueError(
f"Bad BigQuery table id : `{bq_source_table}` please match {bq_table_regex}"
)
table_ref = bigquery.TableReference(
projectId=regex_match.group("PROJECT_ID"),
datasetId=regex_match.group("DATASET_ID"),
tableId=regex_match.group("TABLE_ID"),
)
with beam.Pipeline(options=pipeline_options) as p:
(
p
| "ReadFromBqTable" #
>> bigquery.ReadFromBigQuery(table=table_ref, use_json_exports=True) # Each row contains : topic / attributes / data
| "PublishRowsToPubSub" >> beam.ParDo(PublishFn())
)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
run()
After fussing with this a bit, I think I have an answer that works consistently and is, if not world-beatingly performant, at least tolerably usable:
import logging
import apache_beam as beam
from apache_beam.io.gcp.pubsub import PubsubMessage
from google.cloud.pubsub_v1 import PublisherClient
from google.cloud.pubsub_v1.types import (
BatchSettings,
LimitExceededBehavior,
PublishFlowControl,
PublisherOptions,
)
class PublishClient(PublisherClient):
"""
You have to override __reduce__ to make PublisherClient pickleable 😡 😤 🤬
Props to 'Ankur' and 'Benjamin' on SO for figuring this part out; god knows
I would not have...
"""
def __reduce__(self):
return self.__class__, (self.batch_settings, self.publisher_options)
class PubsubWriter(beam.DoFn):
"""
beam.io.gcp.pubsub does not yet support batch operations, so
we do this the hard way. it's not as performant as the native
pubsubio but it does the job.
"""
def __init__(self, topic: str):
self.topic = topic
self.window = beam.window.GlobalWindow()
self.count = 0
def setup(self):
batch_settings = BatchSettings(
max_bytes=1e6, # 1MB
# by default it is 10 ms, should be less than timeout used in future.result() to avoid timeout
max_latency=1,
)
publisher_options = PublisherOptions(
enable_message_ordering=False,
# better to be slow than to drop messages during a recovery...
flow_control=PublishFlowControl(limit_exceeded_behavior=LimitExceededBehavior.BLOCK),
)
self.publisher = PublishClient(batch_settings, publisher_options)
def start_bundle(self):
self.futures = []
def process(self, element: PubsubMessage, window=beam.DoFn.WindowParam):
self.window = window
self.futures.append(
self.publisher.publish(
topic=self.topic,
data=element.data,
**element.attributes,
)
)
def finish_bundle(self):
"""Iterate over the list of async publish results and block
until all of them have either succeeded or timed out. Yield
a WindowedValue of the success/fail counts."""
results = []
self.count = self.count + len(self.futures)
for fut in self.futures:
try:
                # future.result() blocks until success or timeout; with the small
                # max_latency set upstairs in BatchSettings, batches flush quickly,
                # so we should never spend much time waiting here.
results.append(fut.result(timeout=60))
except Exception as ex:
results.append(ex)
res_count = {"success": 0}
for res in results:
if isinstance(res, str):
res_count["success"] += 1
else:
# if it's not a string, it's an exception
msg = str(res)
if msg not in res_count:
res_count[msg] = 1
else:
res_count[msg] += 1
logging.info(f"Pubsub publish results: {res_count}")
yield beam.utils.windowed_value.WindowedValue(
value=res_count,
timestamp=0,
windows=[self.window],
)
def teardown(self):
logging.info(f"Published {self.count} messages")
The trick is that if you call future.result() inside the process() method, you will block until that single message is successfully published, so instead collect a list of futures and then at the end of the bundle make sure they're all either published or definitively timed out. Some quick testing with one of our internal pipelines suggested that this approach can publish 1.6M messages in ~200s.
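For completeness, here is roughly how the DoFn above could be wired into the batch pipeline from the question (pipeline_options and table_ref come from the question's run() function, the topic path is a placeholder, and note that this writer takes a single topic for the whole pipeline rather than reading it per row as the original PublishFn did):
# Sketch: BigQuery rows (topic, attributes, data) -> PubsubMessage -> batched publish.
import json
import apache_beam as beam
from apache_beam.io.gcp import bigquery
from apache_beam.io.gcp.pubsub import PubsubMessage

def to_pubsub_message(row):
    return PubsubMessage(
        data=json.dumps(row["data"]).encode("utf-8"),
        attributes=row["attributes"],
    )

with beam.Pipeline(options=pipeline_options) as p:
    (
        p
        | "ReadFromBqTable" >> bigquery.ReadFromBigQuery(table=table_ref, use_json_exports=True)
        | "ToPubsubMessage" >> beam.Map(to_pubsub_message)
        | "PublishBatches" >> beam.ParDo(PubsubWriter(topic="projects/<project>/topics/<topic>"))
    )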
I have the following situation: I am using Python 3.6 and Tornado 5.1 to receive client requests over WebSocket. Some of these requests require invoking external processing, which returns a queue and then periodically deposits results into it. These results are transmitted via the WebSocket to the clients.
The external processing is NOT a coroutine, so I invoke it using run_in_executor.
My problem:
When the response time of the external processing is very large, run_in_executor reaches the maximum number of workers (default: number of processors x 5)!
Is it safe to increase the maximum number of workers?
Or is another solution recommended?
Below is a simplified version of the code.
Thanks in advance!
#########################
## SERVER CODE ##
#########################
from random import randint
import tornado.httpserver
import tornado.websocket
import tornado.ioloop
import tornado.web
from random import randint
from tornado import gen
import threading
import asyncio
import queue
import time
class WSHandler(tornado.websocket.WebSocketHandler):
"""entry point for all WS request"""
def open(self):
print('new connection. Request: ' + str(self.request))
async def on_message(self, message):
# Emulates the subscription to an external object
# that returns a queue to listen
producer = Producer()
q = producer.q
while True:
rta = await tornado.ioloop.IOLoop.current().run_in_executor(None, self.loop_on_q, q)
if rta != None:
await self.write_message(str(rta))
else:
break
def on_close(self):
print('connection closed. Request: ' + str(self.request) +
'. close_reason: ' + str(self.close_reason) +
'. close_code: ' + str(self.close_code) +
'. get_status: ' + str(self.get_status()))
def loop_on_q(self, q):
rta = q.get()
return rta
class Producer:
def __init__(self):
self.q = queue.Queue()
t = threading.Thread(target=self.start)
t.daemon = True
t.start()
def start(self):
count = 1
while True:
# time.sleep(randint(1,5))
if count < 100:
self.q.put(count)
else:
self.q.put(None)
break
time.sleep(50)
count += 1
application = tornado.web.Application([
(r'/ws', WSHandler),
])
if __name__ == "__main__":
asyncio.set_event_loop(asyncio.new_event_loop())
http_server = tornado.httpserver.HTTPServer(application)
http_server.listen(8888)
print('SRV START')
    tornado.ioloop.IOLoop.instance().start()
#########################
## CLIENT CODE ##
#########################
# If you run it more than 20 times in less than 50 seconds ==> Block
# (number of processors x 5), I have 4 cores
from websocket import create_connection
def connect():
    url = 'ws://localhost:8888/ws'
    ws = create_connection(url)
    print('Connecting')
    return ws
print('Connecting to srv')
con_ws = connect()
print('Established connection. Sending msg ...')
msj = '{"type":"Socket"}'
con_ws.send(msj)
print('Package sent. Waiting answer...')
while True:
result = con_ws.recv()
print('Answer: ' + str(result))
Is it safe to increase the maximum number of workers? Yes, up to a certain fixed amount, which can be determined with load testing.
Or is another solution recommended? If you reach the worker limit, you can move the workers to multiple separate servers (this approach is called horizontal scaling) and pass jobs to them with a message queue. See Celery as a batteries-included solution, or RabbitMQ, Kafka, etc. if you prefer to write everything yourself.
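As a sketch of the first option: run_in_executor accepts an explicit executor, so you can size the thread pool yourself instead of relying on the default of processors x 5. The worker count below is only an example to be tuned by load testing:
# Sketch: give the blocking queue reads their own, larger thread pool.
from concurrent.futures import ThreadPoolExecutor
import tornado.ioloop

blocking_executor = ThreadPoolExecutor(max_workers=50)  # tune via load testing

# Drop-in replacement for WSHandler.on_message from the question:
async def on_message(self, message):
    producer = Producer()
    q = producer.q
    while True:
        rta = await tornado.ioloop.IOLoop.current().run_in_executor(
            blocking_executor, self.loop_on_q, q)
        if rta is None:
            break
        await self.write_message(str(rta))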
I run an evaluation at the end of each epoch and need to show an image computed from the features and labels arguments of the model function model_fn. Including a tf.summary.image(name, image) in the evaluation part of the model function does not help, and it looks to me like the only way to do so is to pass the correct eval_metric_ops to construct the EstimatorSpec for mode EVAL. So I first sub-class Estimator so that it considers images. The following code is mostly from estimator.py; the only change is the few lines marked by "my change" inside _write_dict_to_summary:
import logging
import io
import numpy as np
import matplotlib.pyplot as plt
import six
from google.protobuf import message
import tensorflow as tf
from tensorflow.python.training import evaluation
from tensorflow.python import ops
from tensorflow.python.estimator.estimator import _dict_to_str, _write_checkpoint_path_to_summary
from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.python.summary.writer import writer_cache
def dump_as_image(a):
vmin = np.min(a)
vmax = np.max(a)
    img = np.squeeze((a - vmin) / (vmax - vmin) * 255).astype(np.uint8)
s = io.BytesIO()
plt.imsave(s, img, format='png', vmin=0, vmax=255, cmap='gray')
return s.getvalue()
# see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/estimator/estimator.py
def _write_dict_to_summary(output_dir, dictionary, current_global_step):
logging.info('Saving dict for global step %d: %s', current_global_step, _dict_to_str(dictionary))
summary_writer = writer_cache.FileWriterCache.get(output_dir)
summary_proto = summary_pb2.Summary()
for key in dictionary:
if dictionary[key] is None:
continue
if key == 'global_step':
continue
if (isinstance(dictionary[key], np.float32) or
isinstance(dictionary[key], float)):
summary_proto.value.add(tag=key, simple_value=float(dictionary[key]))
elif (isinstance(dictionary[key], np.int64) or
isinstance(dictionary[key], np.int32) or
isinstance(dictionary[key], int)):
summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
elif isinstance(dictionary[key], six.binary_type):
try:
summ = summary_pb2.Summary.FromString(dictionary[key])
for i, img_bytes in enumerate(summ.value):
summ.value[i].tag = '%s/%d' % (key, i)
summary_proto.value.extend(summ.value)
except message.DecodeError:
logging.warn('Skipping summary for %s, cannot parse string to Summary.', key)
continue
elif isinstance(dictionary[key], np.ndarray):
value = summary_proto.value.add()
value.tag = key
value.node_name = key
array = dictionary[key]
# my change begins
if array.ndim == 2:
buffer = dump_as_image(array)
value.image.encoded_image_string = buffer
# my change ends
else:
tensor_proto = tensor_util.make_tensor_proto(array)
value.tensor.CopyFrom(tensor_proto)
logging.info(
'Summary for np.ndarray is not visible in Tensorboard by default. '
'Consider using a Tensorboard plugin for visualization (see '
'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
' for more information).')
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, '
'np.int32 or int or np.ndarray or a serialized string of Summary.',
key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()
class ImageMonitoringEstimator(tf.estimator.Estimator):
def __init__(self, *args, **kwargs):
tf.estimator.Estimator._assert_members_are_not_overridden = lambda self: None
super(ImageMonitoringEstimator, self).__init__(*args, **kwargs)
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict, all_hooks, output_dir):
eval_results = evaluation._evaluate_once(
checkpoint_path=checkpoint_path,
master=self._config.evaluation_master,
scaffold=scaffold,
eval_ops=update_op,
final_ops=eval_dict,
hooks=all_hooks,
config=self._session_config)
current_global_step = eval_results[ops.GraphKeys.GLOBAL_STEP]
_write_dict_to_summary(
output_dir=output_dir,
dictionary=eval_results,
current_global_step=current_global_step)
if checkpoint_path:
_write_checkpoint_path_to_summary(
output_dir=output_dir,
checkpoint_path=checkpoint_path,
current_global_step=current_global_step)
return eval_results
the model function is like --
def model_func(features, labels, mode):
# calculate network_output
if mode == tf.estimator.ModeKeys.TRAIN:
# training
elif mode == tf.estimator.ModeKeys.EVAL:
# make_image consists of slicing and concatenations
images = tf.map_fn(make_image, (features, network_output, labels), dtype=features.dtype)
        eval_metric_ops = images, tf.no_op()  # not working
        return tf.estimator.EstimatorSpec(mode, loss=loss,
                                          eval_metric_ops={'images': eval_metric_ops})
else:
# prediction
And the main part --
# mon_features and mon_labels are np.ndarray
estimator = ImageMonitoringEstimator(model_fn=model_func,...)
mon_input_func = tf.estimator.inputs.numpy_input_fn(mon_features,
mon_labels,
shuffle=False,
num_epochs=num_epochs,
batch_size=len(mon_features))
for _ in range(num_epochs):
estimator.train(...)
estimator.evaluate(input_fn=mon_input_func)
The code above will give a warning (later an error):
WARNING:tensorflow:An OutOfRangeError or StopIteration exception is
raised by the code in FinalOpsHook. This typically means the Ops
running by the FinalOpsHook have a dependency back to some input
source, which should not happen. For example, for metrics in
tf.estimator.Estimator, all metrics functions return two Ops:
value_op and update_op. Estimator.evaluate calls the update_op
for each batch of the data in input source and, once it is exhausted,
it call the value_op to get the metric values. The value_op here
should have dependency back to variables reading only, rather than
reading another batch from input. Otherwise, the value_op, executed
by FinalOpsHook, triggers another data reading, which ends
OutOfRangeError/StopIteration. Please fix that.
Looks like I didn't set the eval_metric_ops correctly. I guess tf.map_fn touches another batch, as the warning message hints; maybe I need some stacking operation as the update_op to build the images used for monitoring incrementally? But I am not sure how to do that. So how do I add an image to the summary during evaluation when using Estimator?
The way I make it work is by passing a tf.train.SummarySaverHook under the evaluation mode and declaring it in the tf.estimator.EstimatorSpec via evaluation_hooks=.
images is a list of the desired tf.summary.image summaries you want to write during evaluation.
example:
eval_summary_hook = tf.train.SummarySaverHook(output_dir=params['eval_save_path'], summary_op=images, save_secs=120)
spec = tf.estimator.EstimatorSpec(mode=mode, predictions=y_pred, loss=loss, eval_metric_ops=eval_metric_ops,
evaluation_hooks=[eval_summary_hook])
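Tying that back to the question's model_fn, the EVAL branch could look roughly like this; make_image, network_output and loss come from the question, the output_dir follows the answer's example, and it is assumed that make_image yields one [height, width, channels] image per example so the mapped result is the 4-D batch that tf.summary.image expects:
# Sketch: build the image summary op inside model_fn and attach it via the hook.
if mode == tf.estimator.ModeKeys.EVAL:
    image_batch = tf.map_fn(make_image, (features, network_output, labels),
                            dtype=features.dtype)
    images = [tf.summary.image('eval/images', image_batch, max_outputs=4)]
    eval_summary_hook = tf.train.SummarySaverHook(
        output_dir=params['eval_save_path'], summary_op=images, save_secs=120)
    return tf.estimator.EstimatorSpec(mode, loss=loss,
                                      evaluation_hooks=[eval_summary_hook])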