I am trying to do OCR (Optical Character Recognition) on multiple images present in a folder. I am using Python's multiprocessing module to do it in parallel, creating new processes with the spawn start method.
When I run the script directly with the python command it works fine, but when I import the ImageRecogniser class and call process_file from another module, it keeps spawning new processes from the calling module.
I tried freeze_support but that didn't help.
Right now, to make it work, I am falling back to a subprocess call, which is not the right approach.
ImageRecogniser code:
from tesserocr import PyTessBaseAPI
import os
from multiprocessing import get_context, Value
import sys
import time
from datetime import datetime
import json
from PIL import Image
from itertools import islice

queue_counter = Value('i', 0)


class ImageRecogniser:
    def __init__(self, input_folder_name):
        self.input_folder_name = input_folder_name
        self.total_files = 0

    def get_images(self, extension=(".png",)):
        files_in_dir = os.listdir(self.input_folder_name)
        image_list = [os.path.join(self.input_folder_name, file_) for file_ in files_in_dir
                      if os.path.splitext(file_)[-1] in extension]
        self.total_files = len(image_list)
        return image_list

    def process_images(self, images):
        string_list = []
        api = PyTessBaseAPI()
        for count, image in enumerate(images):
            img = Image.open(image)
            api.SetImage(img)
            output_text = api.GetUTF8Text()
            temp_filename = os.path.splitext(os.path.split(image)[-1])[0]
            page_number = temp_filename.split("_")[-1]
            print("Processed page : {}".format(page_number))
            output_dict = {"page_number": page_number, "output": output_text}
            string_list.append(output_dict)
        api.End()
        return string_list

    def process_file(self, parallel=False):
        print("Getting images from pdf : {}".format(datetime.now()))
        image_list = self.get_images()
        total_no_of_pages = len(image_list)
        print("Initiating extraction : {}".format(datetime.now()))
        begin_time = datetime.now()
        if parallel:
            available_cpus = len(os.sched_getaffinity(0))
            pool_workers = available_cpus // 1
            if total_no_of_pages > pool_workers:
                size_of_queue = total_no_of_pages // pool_workers
                split_queue = [size_of_queue] * pool_workers
                sum_of_split_queue = sum(split_queue)
                if sum_of_split_queue != total_no_of_pages:
                    pages_left = total_no_of_pages - sum_of_split_queue
                    for i in range(pages_left):
                        split_queue[i] = split_queue[i] + 1
            else:
                split_queue = [0] * pool_workers
                for i in range(total_no_of_pages):
                    split_queue[i] = 1
            # page_ids_list_iterator = iter(page_ids_list)
            print('Split queue : {}'.format(split_queue))
            to_be_processed = []
            start = 1
            for chunksize in split_queue:
                end = chunksize + start
                to_be_processed.append(image_list[start:end])
                start = end
            with get_context('spawn').Pool(processes=pool_workers) as pool:
                result = pool.map_async(self.process_images, to_be_processed)
                output = result.get()
            output = [actual_value for queue_list in output for actual_value in queue_list]
            # json_output = json.dumps(output)
        else:
            output = self.process_images(image_list)

        json_file_name = self.input_folder_name + ".json"
        with open(json_file_name, "w") as fp:
            json.dump(output, fp)


if __name__ == '__main__':
    source = sys.argv[1]
    t = ImageRecogniser(source)
    try:
        current = os.environ["OMP_THREAD_LIMIT"]
    except:
        current = None
    os.environ["OMP_THREAD_LIMIT"] = "1"
    t.process_file(parallel=True)
    if current:
        os.environ["OMP_THREAD_LIMIT"] = str(current)
worker.py
import os
import logging
import logging.config
import yaml
from pdf_converter.pdf_to_image import PDFtoImage
from pdf_converter.settings import DATA_FOLDER_PATH, LOGGER_FILE_PATH
from recogniser.image_recogniser import ImageRecogniser
import subprocess

with open(LOGGER_FILE_PATH, 'r') as f:
    config = yaml.safe_load(f.read())
    logging.config.dictConfig(config)

logger = logging.getLogger(__name__)
logger.info("STARTING WORKER")

file_name = "4506-T.pdf"
file_path = os.path.join(DATA_FOLDER_PATH, file_name)
logger.info("STARTING PDFTOIMAGE CONVERSION")
pdf_image = PDFtoImage(file_path)
files = pdf_image.run()
logger.info("CREATED {} IMAGE FILES FROM {}".format(len(files), file_name))

image_files_path = os.path.join(DATA_FOLDER_PATH, os.path.splitext(file_name)[0])
# subprocess.call(['python', 'recogniser/image_recogniser_cmd.py', image_files_path])
img_rec = ImageRecogniser(image_files_path)
img_rec.process_file(parallel=True)
Error:
2020-07-27 19:57:49,244 - __mp_main__ - INFO - CREATED 53 IMAGE FILES FROM 4506-T.pdf
Getting images from pdf : /home/shashank-mq/project/doOCR/data/4506-T.pdf
Initiating extraction : 2020-07-27 19:57:49.245889
Split queue : [7, 7, 7, 7, 7, 6, 6, 6]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.6/multiprocessing/spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "/usr/lib/python3.6/multiprocessing/spawn.py", line 114, in _main
    prepare(preparation_data)
  File "/usr/lib/python3.6/multiprocessing/spawn.py", line 225, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/usr/lib/python3.6/multiprocessing/spawn.py", line 277, in _fixup_main_from_path
    run_name="__mp_main__")
  File "/usr/lib/python3.6/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/usr/lib/python3.6/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/shashank-mq/project/doOCR/worker.py", line 27, in <module>
    img_rec.process_file(parallel=True)
  File "/home/shashank-mq/project/doOCR/recogniser/image_recogniser.py", line 78, in process_file
    with get_context('spawn').Pool(processes=pool_workers) as pool:
  File "/usr/lib/python3.6/multiprocessing/context.py", line 119, in Pool
    context=self.get_context())
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 174, in __init__
    self._repopulate_pool()
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 239, in _repopulate_pool
    w.start()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 105, in start
    self._popen = self._Popen(self)
  File "/usr/lib/python3.6/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/usr/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/usr/lib/python3.6/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usr/lib/python3.6/multiprocessing/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
  File "/usr/lib/python3.6/multiprocessing/spawn.py", line 143, in get_preparation_data
    _check_not_importing_main()
  File "/usr/lib/python3.6/multiprocessing/spawn.py", line 136, in _check_not_importing_main
    is not going to be frozen to produce an executable.''')
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
Please suggest what I am missing and how to fix it.
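For reference, my understanding of the idiom the error message refers to is that everything at module level in worker.py would need to move under an if __name__ == "__main__": guard, because the spawn start method re-imports the main module in every child. Below is a minimal sketch of worker.py restructured that way (illustration only, using the same names as above):

# Sketch only: worker.py with all module-level work moved under the __main__
# guard, so a 'spawn'-started child can import this file without re-running
# the PDF conversion and pool setup.
import os
import logging
import logging.config

import yaml

from pdf_converter.pdf_to_image import PDFtoImage
from pdf_converter.settings import DATA_FOLDER_PATH, LOGGER_FILE_PATH
from recogniser.image_recogniser import ImageRecogniser


def main():
    with open(LOGGER_FILE_PATH, "r") as f:
        logging.config.dictConfig(yaml.safe_load(f.read()))
    logger = logging.getLogger(__name__)
    logger.info("STARTING WORKER")

    file_name = "4506-T.pdf"
    file_path = os.path.join(DATA_FOLDER_PATH, file_name)

    logger.info("STARTING PDFTOIMAGE CONVERSION")
    files = PDFtoImage(file_path).run()
    logger.info("CREATED {} IMAGE FILES FROM {}".format(len(files), file_name))

    image_files_path = os.path.join(DATA_FOLDER_PATH, os.path.splitext(file_name)[0])
    ImageRecogniser(image_files_path).process_file(parallel=True)


if __name__ == "__main__":
    main()

With that, importing worker.py (which is effectively what each spawned child does) no longer re-triggers the conversion and pool creation.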
Related
I have been trying to show a score converted from a krn file.
import music21 as m
import os

test_data = "D:/Programming/DATA - SCIENCE/deep learning/music generation/data/test"


def load_krn_files(data_path):
    # walk the whole directory tree and parse every .krn file
    songs = []
    for path, subdirs, files in os.walk(data_path):
        for file in files:
            if file[-3:] == "krn":
                song = m.converter.parse(os.path.join(path, file))
                songs.append(song)
    return songs


def preproccessing(data_path):
    pass
    # 1) load the kern files and parse them


if __name__ == "__main__":
    songs = load_krn_files(test_data)
    print(f"loaded {len(songs)} songs.")
    song = songs[0]
    song.show()
but the show method returns the following error:
loaded 12 songs.
Traceback (most recent call last):
  File "d:/Programming/DATA - SCIENCE/deep learning/music generation/scripts/preprocess.py", line 25, in <module>
    song.show()
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\site-packages\music21\stream\base.py", line 334, in show
    return super().show(fmt=fmt, app=app, **keywords)
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\site-packages\music21\base.py", line 2788, in show
    return formatWriter.show(self,
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\site-packages\music21\converter\subConverters.py", line 1114, in show
    self.launch(returnedFilePath, fmt=fmt, app=app)
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\site-packages\music21\converter\subConverters.py", line 197, in launch
    subprocess.run(cmd, check=False, shell=shell)
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\site-packages\run\__init__.py", line 145, in __new__
    process = cls.create_process(command, stdin, cwd=cwd, env=env, shell=shell)
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\site-packages\run\__init__.py", line 121, in create_process
    shlex.split(command),
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\shlex.py", line 311, in split
    return list(lex)
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\shlex.py", line 300, in __next__
    token = self.get_token()
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\shlex.py", line 109, in get_token
    raw = self.read_token()
  File "C:\Users\ae504\AppData\Local\Programs\Python\Python38\lib\shlex.py", line 140, in read_token
    nextchar = self.instream.read(1)
AttributeError: 'tuple' object has no attribute 'read'
I am trying to read CSV files into pandas DataFrames using multiprocessing, but I get a pickling error.
Python 3.8.8
pandas 1.2.4
import os
import pandas as pd
import time
from multiprocessing import Pool


def getExcelData(fn):
    data = pd.DataFrame()
    return data.append(pd.read_csv(fn), sort=False)


if __name__ == "__main__":
    dir = '.'
    fn_ls = [f'{fn}' for fn in os.listdir(dir) if fn.endswith('test.csv')]

    startTime = time.time()

    pool = Pool(2)
    pool_data_list = []
    data = pd.DataFrame()
    for file_name in fn_ls:
        pool_data_list.append(pool.apply_async(getExcelData, (os.path.join(dir, file_name),)))
    pool.close()
    pool.join()
    for pool_data in pool_data_list:
        data = data.append(pool_data.get())

    res_ls = []
    for pool_data in pool_data_list:
        res_ls = pool_data.get()

    endTime = time.time()
    print(endTime - startTime)
    print(len(data))
Traceback (most recent call last):
  File "/Users/cxx/opt/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in
    runfile('/Users/cxx/xiaoxi/18_Mercury/raw_data/raw/5000bp/test/test.py', wdir='/Users/cxx/xiaoxi/18_Mercury/raw_data/raw/5000bp/test')
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/cxx/xiaoxi/18_Mercury/raw_data/raw/5000bp/test/test.py", line 33, in
    data = data.append(pool_data.get())
  File "/Users/cxx/opt/anaconda3/lib/python3.8/multiprocessing/pool.py", line 771, in get
    raise self._value
  File "/Users/cxx/opt/anaconda3/lib/python3.8/multiprocessing/pool.py", line 537, in _handle_tasks
    put(task)
  File "/Users/cxx/opt/anaconda3/lib/python3.8/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/Users/cxx/opt/anaconda3/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function getExcelData at 0x7f84e9ad19d0>: attribute lookup getExcelData on __main__ failed
Replace everything between startTime and endTime with a simple map call in a context manager:
with Pool(2) as pool:
    data = [df for df in pool.imap(getExcelData, fn_ls)]
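For completeness, a self-contained sketch of that approach, under my assumptions: files ending in test.csv sit in the current directory, and the script is run directly with python test.py from a terminal rather than through the PyCharm/IPython console, so that getExcelData is importable from __main__ and can be pickled:

# Sketch: read many CSVs in parallel with a Pool and combine the results.
import os
import time
from multiprocessing import Pool

import pandas as pd


def getExcelData(fn):
    # each worker reads one CSV into a DataFrame
    return pd.read_csv(fn)


if __name__ == "__main__":
    fn_ls = [fn for fn in os.listdir('.') if fn.endswith('test.csv')]

    startTime = time.time()
    with Pool(2) as pool:
        frames = list(pool.imap(getExcelData, fn_ls))
    data = pd.concat(frames, ignore_index=True, sort=False)
    endTime = time.time()

    print(endTime - startTime)
    print(len(data))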
When I call the Cloud Video Intelligence API to detect subtitles in a local video file, it always returns error 400 or 504, but using a GCS URI works fine. I have tried adjusting the timeout in the Cloud Video Intelligence config, but it still shows error 400 with an invalid argument.
This is my Python code for detecting video subtitles:
"""This application demonstrates detection subtitles in video using the Google Cloud API.
Usage Examples:
use video in google cloud storge:
python analyze.py text_gcs gs://"video path"
use video in computer:
python analyze.py text_file video.mp4
"""
import argparse
import io
from google.cloud import videointelligence
from google.cloud.videointelligence import enums
def video_detect_text_gcs(input_uri):
# [START video_detect_text_gcs]
"""Detect text in a video stored on GCS."""
from google.cloud import videointelligence
video_client = videointelligence.VideoIntelligenceServiceClient()
features = [videointelligence.enums.Feature.TEXT_DETECTION]
config = videointelligence.types.TextDetectionConfig(language_hints=["zh-TW","en-US"])
video_context = videointelligence.types.VideoContext(text_detection_config=config)
operation = video_client.annotate_video(input_uri=input_uri, features=features, video_context=video_context)
print("\nSubtitle detecting......")
result = operation.result(timeout=300)
# The first result is retrieved because a single video was processed.
annotation_result = result.annotation_results[0]
subtitle_data=[ ]
for text_annotation in annotation_result.text_annotations:
text_segment = text_annotation.segments[0]
start_time = text_segment.segment.start_time_offset
frame = text_segment.frames[0]
vertex=frame.rotated_bounding_box.vertices[0]
if text_segment.confidence > 0.95 and vertex.y >0.7:
lists=[text_annotation.text,start_time.seconds+ start_time.nanos * 1e-9,vertex.y]
subtitle_data=subtitle_data+[lists]
length=len(subtitle_data)
subtitle_sort=sorted(subtitle_data,key = lambda x: (x[1],x[2]))
i=0
subtitle=[ ]
while i<length :
subtitle=subtitle+[subtitle_sort[i][0]]
i=i+1
with open("subtitle.txt",mode="w",encoding="utf-8") as file:
for x in subtitle:
file.write(x+'\n')
def video_detect_text(path):
# [START video_detect_text]
"""Detect text in a local video."""
from google.cloud import videointelligence
video_client = videointelligence.VideoIntelligenceServiceClient()
features = [videointelligence.enums.Feature.TEXT_DETECTION]
video_context = videointelligence.types.VideoContext()
with io.open(path, "rb") as file:
input_content = file.read()
operation = video_client.annotate_video(
input_content=input_content, # the bytes of the video file
features=features,
video_context=video_context
)
print("\nSubtitle detecting......")
result = operation.result(timeout=300)
# The first result is retrieved because a single video was processed.
annotation_result = result.annotation_results[0]
subtitle_data=[ ]
for text_annotation in annotation_result.text_annotations:
text_segment = text_annotation.segments[0]
start_time = text_segment.segment.start_time_offset
frame = text_segment.frames[0]
vertex=frame.rotated_bounding_box.vertices[0]
if text_segment.confidence > 0.95 and vertex.y >0.7:
lists=[text_annotation.text,start_time.seconds+ start_time.nanos * 1e-9,vertex.y]
subtitle_data=subtitle_data+[lists]
length=len(subtitle_data)
subtitle_sort=sorted(subtitle_data,key = lambda x: (x[1],x[2]))
i=0
subtitle=[ ]
while i<length :
subtitle=subtitle+[subtitle_sort[i][0]]
i=i+1
with open("subtitle.txt",mode="w",encoding="utf-8") as file:
for x in subtitle:
file.write(x+'\n')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
subparsers = parser.add_subparsers(dest="command")
detect_text_parser = subparsers.add_parser(
"text_gcs", help=video_detect_text_gcs.__doc__
)
detect_text_parser.add_argument("path")
detect_text_file_parser = subparsers.add_parser(
"text_file", help=video_detect_text.__doc__
)
detect_text_file_parser.add_argument("path")
args = parser.parse_args()
if args.command == "text_gcs":
video_detect_text_gcs(args.path)
if args.command == "text_file":
video_detect_text(args.path)
This is the error report:
Ghuang#/Users/Ghuang/Documents/GitHub/Video-subtitles-detection$ python3 analyze.py text_file video.mp4
Traceback (most recent call last):
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/google/api_core/grpc_helpers.py", line 57, in error_remapped_callable
    return callable_(*args, **kwargs)
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/grpc/_channel.py", line 826, in __call__
    return _end_unary_response_blocking(state, call, False, None)
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/grpc/_channel.py", line 729, in _end_unary_response_blocking
    raise _InactiveRpcError(state)
grpc._channel._InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
    status = StatusCode.DEADLINE_EXCEEDED
    details = "Deadline Exceeded"
    debug_error_string = "{"created":"@1587691109.677447000","description":"Error received from peer ipv4:172.217.24.10:443","file":"src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Deadline Exceeded","grpc_status":4}"
>

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "analyze.py", line 144, in <module>
    video_detect_text(args.path)
  File "analyze.py", line 90, in video_detect_text
    video_context=video_context
  File "/Library/Python/3.7/site-packages/google/cloud/videointelligence_v1/gapic/video_intelligence_service_client.py", line 303, in annotate_video
    request, retry=retry, timeout=timeout, metadata=metadata
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/google/api_core/gapic_v1/method.py", line 143, in __call__
    return wrapped_func(*args, **kwargs)
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/google/api_core/retry.py", line 286, in retry_wrapped_func
    on_error=on_error,
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/google/api_core/retry.py", line 184, in retry_target
    return target()
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/google/api_core/timeout.py", line 214, in func_with_timeout
    return func(*args, **kwargs)
  File "/Users/Ghuang/Library/Python/3.7/lib/python/site-packages/google/api_core/grpc_helpers.py", line 59, in error_remapped_callable
    six.raise_from(exceptions.from_grpc_error(exc), exc)
  File "<string>", line 3, in raise_from
google.api_core.exceptions.DeadlineExceeded: 504 Deadline Exceeded
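One detail visible in the second traceback: the DeadlineExceeded is raised while the annotate_video request itself is being sent (the local file is uploaded inline as input_content), not while waiting on operation.result(). Below is a hedged sketch of giving both steps more time; the 600-second values are placeholders I chose, and the timeout keyword on annotate_video is the generic gapic request timeout visible in the traceback rather than a Video Intelligence-specific setting. For large files, the gs:// path handled by video_detect_text_gcs avoids the inline upload entirely.

# Sketch only: same call as in video_detect_text above, with an explicit
# (placeholder) timeout on the initial request and a longer wait on the result.
operation = video_client.annotate_video(
    input_content=input_content,
    features=features,
    video_context=video_context,
    timeout=600,  # assumed value; gapic methods accept retry/timeout kwargs
)
result = operation.result(timeout=600)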
I'm a newbie and trying to use google-cloud speech-to-text with Python and multiprocessing. Here is a simple example to reproduce my issue.
I'm running the code on Windows.
When I run the code without multiprocessing, it works fine.
import io
from tqdm import tqdm
from multiprocessing import Pool, freeze_support, cpu_count
from google.cloud import speech
from google.cloud.speech import enums
from google.cloud.speech import types

# Instantiates a client
CLIENT = speech.SpeechClient()


def speech_to_text(file_name, language="en-US"):
    with io.open(file_name, 'rb') as audio_file:
        content = audio_file.read()
        audio = types.RecognitionAudio(content=content)

    config = types.RecognitionConfig(
        encoding=enums.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
        sample_rate_hertz=16000,
        language_code=language)

    # Detects speech in the audio file
    response = CLIENT.recognize(config, audio)

    transcript = ""
    if len(response.results):
        transcript = response.results[0].alternatives[0].transcript
    return transcript


def worker(ix):
    audio_file_name = "audio.mp3"
    transcript = speech_to_text(audio_file_name)


if __name__ == "__main__":
    n_cores = cpu_count() - 1
    freeze_support()  # for Windows support

    with Pool(n_cores) as p:
        max_ = len(range(2))
        with tqdm(total=max_) as pbar:
            for i, result in enumerate(tqdm(p.imap_unordered(worker, range(2)))):
                pbar.update()
Here is the error message that I get:
Traceback (most recent call last):
  File "C:\Users\me\Anaconda3\lib\multiprocessing\spawn.py", line 114, in _main
    prepare(preparation_data)
  File "C:\Users\me\Anaconda3\lib\multiprocessing\spawn.py", line 225, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "C:\Users\me\Anaconda3\lib\multiprocessing\spawn.py", line 277, in _fixup_main_from_path
    run_name="__mp_main__")
  File "C:\Users\me\Anaconda3\lib\runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "C:\Users\me\Anaconda3\lib\runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "C:\Users\me\Anaconda3\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\me\Desktop\outCaptcha\multiproc.py", line 10, in <module>
    from google.cloud import speech
  File "C:\Users\me\.virtualenvs\outCaptcha\lib\site-packages\google\cloud\speech.py", line 20, in <module>
    from google.cloud.speech_v1 import SpeechClient
  File "C:\Users\me\.virtualenvs\outCaptcha\lib\site-packages\google\cloud\speech_v1\__init__.py", line 17, in <module>
    from google.cloud.speech_v1.gapic import speech_client
  File "C:\Users\me\.virtualenvs\outCaptcha\lib\site-packages\google\cloud\speech_v1\gapic\speech_client.py", line 23, in <module>
    import google.api_core.client_options
  File "C:\Users\me\.virtualenvs\outCaptcha\lib\site-packages\google\api_core\__init__.py", line 23, in <module>
    __version__ = get_distribution("google-api-core").version
  File "C:\Users\me\AppData\Roaming\Python\Python37\site-packages\pkg_resources\__init__.py", line 481, in get_distribution
    dist = get_provider(dist)
  File "C:\Users\me\AppData\Roaming\Python\Python37\site-packages\pkg_resources\__init__.py", line 357, in get_provider
    return working_set.find(moduleOrReq) or require(str(moduleOrReq))[0]
  File "C:\Users\me\AppData\Roaming\Python\Python37\site-packages\pkg_resources\__init__.py", line 900, in require
    needed = self.resolve(parse_requirements(requirements))
  File "C:\Users\me\AppData\Roaming\Python\Python37\site-packages\pkg_resources\__init__.py", line 786, in resolve
    raise DistributionNotFound(req, requirers)
pkg_resources.DistributionNotFound: The 'google-api-core' distribution was not found and is required by the application
Thanks a lot for your help. Please let me know if you need any more details about the issue.
In my case this solved the problem:

easy_install --upgrade google-api-core
easy_install --upgrade google-cloud-speech

I hope this helps.
I am using the following code for machine learning purposes (I am also quite new to Python and PyTorch). Basically, I think the problem is that the parallel processing is failing for some reason.
I am using code from here: https://raw.githubusercontent.com/harryhan618/LaneNet/master/demo_test.py
The purpose of the code is to draw lane markings on an image.
import cv2
import torch
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from lane_files.model import LaneNet
from lane_files.utils.transforms import *
from lane_files.utils.postprocess import embedding_post_process


if __name__ == '__main__':
    net = LaneNet(pretrained=False, embed_dim=7, delta_v=.5, delta_d=3.)
    transform = Compose(Resize((800, 288)), ToTensor(),
                        Normalize(mean=(0.3598, 0.3653, 0.3662), std=(0.2573, 0.2663, 0.2756)))

    img = cv2.imread('data/train_images/frame0.png')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB for net model input
    x = transform(img)[0]
    x.unsqueeze_(0)

    save_dict = torch.load('lane_files/experiments/exp0/exp0_best.pth', map_location='cpu')
    net.load_state_dict(save_dict['net'])
    net.eval()

    output = net(x)

    embedding = output['embedding']
    embedding = embedding.detach().cpu().numpy()
    embedding = np.transpose(embedding[0], (1, 2, 0))
    binary_seg = output['binary_seg']
    bin_seg_prob = binary_seg.detach().cpu().numpy()
    bin_seg_pred = np.argmax(bin_seg_prob, axis=1)[0]

    seg_img = np.zeros_like(img)
    lane_seg_img = embedding_post_process(embedding, bin_seg_pred, 0.5)
    color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8')

    for i, lane_idx in enumerate(np.unique(lane_seg_img)):
        seg_img[lane_seg_img == lane_idx] = color[i]

    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.resize(img, (800, 288))
    img = cv2.addWeighted(src1=seg_img, alpha=0.8, src2=img, beta=1., gamma=0.)

    cv2.imshow("", img)
    cv2.waitKey(5000)
    cv2.destroyAllWindows()
Expected result: An image displayed with lane markings on it
Actual result:
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "C:/Users/sarim/PycharmProjects/thesis/pytorch_learning.py", line 36, in <module>
    lane_seg_img = embedding_post_process(embedding, bin_seg_pred, 0.5)
  File "C:\Users\sarim\PycharmProjects\thesis\lane_files\utils\postprocess.py", line 29, in embedding_post_process
    mean_shift.fit(embedding_reshaped)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\cluster\mean_shift_.py", line 424, in fit
    cluster_all=self.cluster_all, n_jobs=self.n_jobs)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\cluster\mean_shift_.py", line 204, in mean_shift
    (seed, X, nbrs, max_iter) for seed in seeds)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 934, in __call__
    self.retrieve()
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 833, in retrieve
    self._output.extend(job.get(timeout=self.timeout))
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\_parallel_backends.py", line 521, in wrap_future_result
    return future.result(timeout=timeout)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\concurrent\futures\_base.py", line 435, in result
    return self.__get_result()
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\concurrent\futures\_base.py", line 384, in __get_result
    raise self._exception
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\externals\loky\_base.py", line 625, in _invoke_callbacks
    callback(self)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 309, in __call__
    self.parallel.dispatch_next()
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 731, in dispatch_next
    if not self.dispatch_one_batch(self._original_iterator):
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 759, in dispatch_one_batch
    self._dispatch(tasks)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\parallel.py", line 716, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\_parallel_backends.py", line 510, in apply_async
    future = self._workers.submit(SafeFunction(func))
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\externals\loky\reusable_executor.py", line 151, in submit
    fn, *args, **kwargs)
  File "C:\Users\sarim\AppData\Local\Programs\Python\Python37\lib\site-packages\joblib\externals\loky\process_executor.py", line 1022, in submit
    raise self._flags.broken
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
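Note that the pool that breaks here is not created by this script; it is the loky worker pool that scikit-learn's MeanShift starts inside embedding_post_process (the n_jobs path visible in the traceback). Below is a heavily hedged workaround sketch, assuming embedding_post_process does not pin a joblib backend itself: force joblib to use threads for that one call so nothing needs to be pickled.

# Workaround sketch (untested assumption): with the 'threading' backend active,
# MeanShift's internal Parallel() uses threads instead of loky worker processes,
# sidestepping the un-serialize failure. embedding and bin_seg_pred are the
# same variables as in the script above.
from joblib import parallel_backend

with parallel_backend('threading'):
    lane_seg_img = embedding_post_process(embedding, bin_seg_pred, 0.5)

Alternatively, if the LaneNet helper exposes it, constructing MeanShift with n_jobs=1 would avoid the worker pool altogether.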