# Created by erdisayar (last-modified date not given in this copy).
import os
import time

import cv2
import numpy as np
import tensorflow.compat.v1 as tf
from nrp_core.engines.python_json import EngineScript
# Switch TensorFlow to TF1 graph mode: the engine below relies on tf.GraphDef,
# tf.placeholder and tf.Session, which require v2 (eager) behaviour to be off.
tf.disable_v2_behavior()
class Script(EngineScript):
    """NRP Python JSON engine that turns camera frames into saliency maps.

    Every loop step it reads the ``input_image`` datapack, runs a frozen
    TensorFlow saliency graph on it, writes the input frame and the resulting
    map to disk as JPEG files, and publishes the map on the ``image_saliency``
    datapack as a flattened ``uint8`` list together with its dimensions.
    """

    # Frozen graph loaded in initialize(); path is relative to the working directory.
    MODEL_PATH = "tensorflow_model/model_salicon_cpu.pb"
    # Directories receiving the per-step debug JPEGs.
    GAZEBO_IMAGE_DIR = "images/From_Gazebo"
    SALIENCY_IMAGE_DIR = "images/From_Tensorflow_Engine"

    @staticmethod
    def _empty_saliency():
        """Return a fresh placeholder payload for when no input image is available."""
        return {"image_height": None, "image_width": None, "channel": None,
                "image": None}

    def initialize(self):
        """Register datapacks, load the frozen saliency graph and open a TF session."""
        self._registerDataPack("image_saliency")
        self._registerDataPack("input_image")
        self._setDataPack("image_saliency", self._empty_saliency())

        # cv2.imwrite returns False instead of raising when the target directory
        # does not exist, so create the debug output directories up front.
        os.makedirs(self.GAZEBO_IMAGE_DIR, exist_ok=True)
        os.makedirs(self.SALIENCY_IMAGE_DIR, exist_ok=True)

        # Load the trained model (a serialized GraphDef).
        graph_def = tf.GraphDef()
        with tf.gfile.Open(self.MODEL_PATH, "rb") as model_file:
            graph_def.ParseFromString(model_file.read())

        # Batch of 3-channel float images of any size, wired to the graph's "input" node.
        self.input_plhd = tf.placeholder(tf.float32, (None, None, None, 3))
        [self.predicted_maps] = tf.import_graph_def(
            graph_def,
            input_map={"input": self.input_plhd},
            return_elements=["output:0"],
        )
        self.sess = tf.Session()

    def runLoop(self, timestep_ns):
        """Compute and publish the saliency map for the latest input image.

        Args:
            timestep_ns: Engine timestep in nanoseconds (not used by this engine).
        """
        received_data = self._getDataPack("input_image")
        if not received_data:
            self._setDataPack("image_saliency", self._empty_saliency())
            return

        height = received_data['imageHeight']
        width = received_data['imageWidth']
        channel = received_data['channel']
        image = np.array(received_data['image']).reshape(height, width, channel).astype(
            np.uint8)

        time_stamp = time.time_ns()
        # image = cv2.resize(image, (320, 240))  # saliency model requires (320, 240) images
        # The original author notes Gazebo uses BGR images; convert before feeding the model.
        # NOTE(review): cv2.imwrite interprets its input as BGR, so the "From_Gazebo"
        # copy is saved with the channel order produced by this conversion. Confirm
        # which order Gazebo actually publishes.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        cv2.imwrite(os.path.join(self.GAZEBO_IMAGE_DIR, f"gazebo_{time_stamp}.jpg"), image)

        batch = image[np.newaxis, :, :, :]
        saliency = self.sess.run(self.predicted_maps, feed_dict={self.input_plhd: batch})
        saliency = cv2.cvtColor(saliency.squeeze(), cv2.COLOR_GRAY2BGR)
        # Clip before the uint8 cast: out-of-range floats would otherwise wrap around.
        saliency = np.clip(saliency * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(
            os.path.join(self.SALIENCY_IMAGE_DIR, f"tensorflow_engine_{time_stamp}.jpg"),
            saliency)

        # Describe the array actually being sent. The map is always 3-channel after
        # GRAY2BGR, and its size comes from the model output, not the input frame.
        out_height, out_width, out_channel = saliency.shape
        self._setDataPack("image_saliency",
                          {"image_height": out_height, "image_width": out_width,
                           "channel": out_channel,
                           "image": saliency.flatten().tolist()})

    def shutdown(self):
        """Close the TensorFlow session (if one was opened) and report shutdown."""
        # initialize() may have failed before the session was created.
        sess = getattr(self, "sess", None)
        if sess is not None:
            sess.close()
        print("Tensorflow Python JSON Engine is shutting down")