Commits

iorodeo committed c78348f

Worked on writing training data.

  • Participants
  • Parent commits ad4aebd

Comments (0)

Files changed (6)

src/demo/fly_sorter/fly_sorter_window.cpp

 #include <QNetworkReply>
 #include <QNetworkRequest>
 #include <QDateTime>
+#include <QDir>
 #include <iostream>
 #include <list>
 #include <random>
     if (!running_)
     {
 
+        // Setup sorting and tracking
         blobFinder_ = BlobFinder(param_.blobFinder);
         identityTracker_ = IdentityTracker(param_.identityTracker);
         flySegmenter_ = FlySegmenter(param_.flySegmenter);
         hogPositionFitter_ = HogPositionFitter(param_.hogPositionFitter);
         genderSorter_ = GenderSorter(param_.genderSorter);
+
+        // Create training data
+        setupTrainingDataWrite(param_.imageGrabber.captureInputFile);
+
+
         startImageCapture();
 
+        // DEBUG - open debug output data file
+        // --------------------------------------------------------------------
         //debugStream.open("debug_data.txt");
     }
     else
     {
         stopImageCapture();
 
+        // DEBUG - close debug output data file
+        // --------------------------------------------------------------------
         //debugStream.close();
     }
 }
         threadPoolPtr_ -> start(imageGrabberPtr_);
         startPushButtonPtr_ -> setText("Stop");
         reloadPushButtonPtr_ -> setEnabled(false);
+        trainingDataCheckBoxPtr_ -> setEnabled(false);
     }
 }
 
     }
     loadParamFromFile();
     updateParamText();
+    updateWidgetsOnLoad();
     if (isRunning)
     {
         startImageCapture();
 }
 
 
+void FlySorterWindow::trainingDataCheckBoxChanged(int state)
+{
+    //std::cout << "write training data ";
+    //if (state == Qt::Checked)
+    //{
+    //    std::cout << "checked";
+    //}
+    //else
+    //{
+    //    std::cout << "unchecked";
+    //}
+    //std::cout << std::endl;
+}
+
+
 void FlySorterWindow::newImage(ImageData imageData)
 {
     if (running_)
     {
+        imageData_.copy(imageData);
 
-        imageData_.copy(imageData);
+        // Find flies and sort by gender
         blobFinderData_ = blobFinder_.findBlobs(imageData_.mat);
         identityTracker_.update(blobFinderData_);
         flySegmenterData_ = flySegmenter_.segment(blobFinderData_);
-        hogPositionFitterData_ = hogPositionFitter_.fit(flySegmenterData_,imageData.frameCount,imageData.mat);
+
+        hogPositionFitterData_ = hogPositionFitter_.fit(
+                flySegmenterData_,
+                imageData.frameCount,
+                imageData.mat
+                );
+
         genderSorterData_ = genderSorter_.sort(hogPositionFitterData_);
+
+        // Send position and gender data via http.
         if ((httpOutputCheckBoxPtr_ -> checkState()) == Qt::Checked)
         {
             sendDataViaHttpRequest();
         }
 
-        // DEBUG - display gender data
-        // -------------------------------------------------------------------------
-        if (0)
-        {
-            std::cout << "Frame Count: " << imageData.frameCount << std::endl;
-            GenderDataList genderDataList = genderSorterData_.genderDataList;
-            GenderDataList::iterator it;
+        //// DEBUG - display gender data
+        //// -------------------------------------------------------------------------
+        //if (0)
+        //{
+        //    GenderDataList genderDataList = genderSorterData_.genderDataList;
+        //    GenderDataList::iterator it;
 
-            for (it=genderDataList.begin(); it!=genderDataList.end(); it++)
-            {
-                GenderData data = *it;
-                std::cout << data.toStdString(1) << std::endl;
-            }
-        }
-        // -------------------------------------------------------------------------
+        //    for (it=genderDataList.begin(); it!=genderDataList.end(); it++)
+        //    {
+        //        GenderData data = *it;
+        //        std::cout << data.toStdString(1) << std::endl;
+        //    }
+        //}
+        //// -------------------------------------------------------------------------
 
-
-        // DEBUG - write images to file
-        // -------------------------------------------------------------------------
-        if (0)
-        {
-            QString imgFileName = QString("image_%1.bmp").arg(imageData.frameCount);
-            cv::imwrite(imgFileName.toStdString(),imageData.mat);
-        }
-        // -------------------------------------------------------------------------
+        //// DEBUG - write images to file
+        //// -------------------------------------------------------------------------
+        //if (0)
+        //{
+        //    QString imgFileName = QString("image_%1.bmp").arg(imageData.frameCount);
+        //    cv::imwrite(imgFileName.toStdString(),imageData.mat);
+        //}
+        //// -------------------------------------------------------------------------
 
     }
 }
         running_ = false;
         startPushButtonPtr_ -> setText("Start");
         reloadPushButtonPtr_ -> setEnabled(true);
+        if (param_.imageGrabber.captureMode == QString("file"))
+        {
+            trainingDataCheckBoxPtr_ -> setEnabled(true);
+        }
 }
 
 
             this,
             SLOT(httpOutputCheckBoxChanged(int))
            );
+
+    connect(
+            trainingDataCheckBoxPtr_,
+            SIGNAL(stateChanged(int)),
+            this,
+            SLOT(trainingDataCheckBoxChanged(int))
+           );
 }
 
 
     setupNetworkAccessManager();
     loadParamFromFile();
     updateParamText();
+    updateWidgetsOnLoad();
 
     // Temporary
-    // --------------------------------------------------------------------------
+    // ----------------------------------------------------------------------------
     //distribution_ = std::uniform_int_distribution<unsigned int>(0,1);
-
-    QString appDirPath = QCoreApplication::applicationDirPath();
-    std::cout << "applicationDirPath = " << appDirPath.toStdString() << std::endl;
+    //QString appDirPath = QCoreApplication::applicationDirPath();
+    //std::cout << "applicationDirPath = " << appDirPath.toStdString() << std::endl;
 }
 
 
         QString genderString = QString::fromStdString( 
                 GenderSorter::GenderToString(genderData.gender)
                 );
+
         QVariant id = QVariant::fromValue<long>(
                 genderData.positionData.segmentData.blobData.id
                 );
     QByteArray prettyParamJson = prettyIndentJson(paramJson);
     paramsTextEditPtr_ -> setPlainText(QString(prettyParamJson));
 }
+
+
+void FlySorterWindow::updateWidgetsOnLoad()
+{
+    if (param_.imageGrabber.captureMode == QString("file"))
+    {
+        trainingDataCheckBoxPtr_ -> setEnabled(true);
+    }
+    else
+    {
+        trainingDataCheckBoxPtr_ -> setCheckState(Qt::Unchecked);
+        trainingDataCheckBoxPtr_ -> setEnabled(false);
+    }
+}
+
+
+void FlySorterWindow::setupTrainingDataWrite(QString videoFileName)
+{
+    if (trainingDataCheckBoxPtr_ -> checkState() == Qt::Checked)
+    {
+        // Get application directory
+        QString appPathString = QCoreApplication::applicationDirPath();
+        QDir appDir = QDir(appPathString);
+
+        // Create training data base directory if it doesn't exist
+        QString baseDirString = QString("training_data");
+        QDir baseDir = QDir(appDir.absolutePath() + "/" + baseDirString);
+        if (!baseDir.exists())
+        {
+            appDir.mkdir(baseDirString);
+        }
+
+        // Create trianing data directory if it doesn't exist
+        QString videoPrefix = videoFileName.split(".", QString::SkipEmptyParts).at(0);
+        QDir dataDir = QDir(baseDir.absolutePath() + "/" + videoPrefix);
+        if (!dataDir.exists())
+        {
+            baseDir.mkdir(videoPrefix);
+        }
+
+        // Create training data file prefix.
+        QString dataPrefix = dataDir.absoluteFilePath("data");
+        hogPositionFitter_.trainingDataWriteEnable(dataPrefix.toStdString());
+    }
+    else
+    {
+        hogPositionFitter_.trainingDataWriteDisable();
+    }
+}

src/demo/fly_sorter/fly_sorter_window.hpp

         void startPushButtonClicked();
         void reloadPushButtonClicked();
         void httpOutputCheckBoxChanged(int state);
+        void trainingDataCheckBoxChanged(int state);
         void newImage(ImageData imageData);
         void updateDisplayOnTimer(); 
         void networkAccessManagerFinished(QNetworkReply *reply);
         GenderSorterData genderSorterData_;
 
 
-
         void connectWidgets();
         void initialize();
         void startImageCapture();
         QByteArray dataToJson();
         void loadParamFromFile();
         void updateParamText();
+        void updateWidgetsOnLoad();
+        void setupTrainingDataWrite(QString videoFileName);
+
 
 
         // Devel 

src/demo/fly_sorter/fly_sorter_window.ui

         </widget>
        </item>
        <item>
+        <spacer name="horizontalSpacer_4">
+         <property name="orientation">
+          <enum>Qt::Horizontal</enum>
+         </property>
+         <property name="sizeType">
+          <enum>QSizePolicy::Fixed</enum>
+         </property>
+         <property name="sizeHint" stdset="0">
+          <size>
+           <width>20</width>
+           <height>20</height>
+          </size>
+         </property>
+        </spacer>
+       </item>
+       <item>
+        <widget class="QCheckBox" name="trainingDataCheckBoxPtr_">
+         <property name="text">
+          <string>Create Training Data</string>
+         </property>
+        </widget>
+       </item>
+       <item>
         <spacer name="horizontalSpacer">
          <property name="orientation">
           <enum>Qt::Horizontal</enum>

src/demo/fly_sorter/hog_position_fitter.cpp

 //  HogPositionFitter
 //  ---------------------------------------------------------------------------
 
-HogPositionFitter::HogPositionFitter() { };
+HogPositionFitter::HogPositionFitter() 
+{ 
+    showDebugWindow_ = false; 
+    writeTrainingData_ = false;
+    trainingFileNamePrefix_ = std::string("none");
+};
 
-HogPositionFitter::HogPositionFitter(HogPositionFitterParam param)
+
+HogPositionFitter::HogPositionFitter(HogPositionFitterParam param) : HogPositionFitter()
 {
     setParam(param);
-    showDebugWindow_ = false; 
     if (showDebugWindow_)
     {
         //cv::namedWindow(
     }
 };
 
+
 void HogPositionFitter::setParam(HogPositionFitterParam param)
 {
     param_ = param;
 }
 
 
+void HogPositionFitter::trainingDataWriteEnable(std::string fileNamePrefix)
+{
+    writeTrainingData_ = true;
+    trainingFileNamePrefix_ = fileNamePrefix;
+}
+
+
+void HogPositionFitter::trainingDataWriteDisable()
+{
+    writeTrainingData_ = false;
+}
+
+
 HogPositionFitterData HogPositionFitter::fit(
         FlySegmenterData flySegmenterData, 
         unsigned long frameCount,
 
             posData.success = true;
             fitterData.positionDataList.push_back(posData); 
+
+            if (writeTrainingData_)
+            {
+                createTrainingData(rotBoundingImageLUV);
+            }
            
             // DEBUG - Write pixel feature vector to file
             // ------------------------------------------------------------------------------------
-            if (0) 
-            {
-                std::ofstream pVecStream;
-                QString pVecFileName = QString("pVec_frm_%1_cnt_%2.txt").arg(frameCount).arg(cnt);
-                pVecStream.open(pVecFileName.toStdString());
-                for (int i=0; i<posData.pixelFeatureVector.size();i++)
-                {
-                    pVecStream << posData.pixelFeatureVector[i] << std::endl;
-                }
-                pVecStream.close();
-            }
-            if (showDebugWindow_)
-            {
-                if (cnt==0)
-                {
-                    //cv::imshow("hogPosMaxComp", maxCompMat);
-                    //cv::imshow("boundingImageLUV", posData.segmentData.boundingImageLUV);
-                    cv::imshow("rotBoundingImageLUV", posData.rotBoundingImageLUV);
-                }
-            }
+            //if (0) 
+            //{
+            //    std::ofstream pVecStream;
+            //    QString pVecFileName = QString("pVec_frm_%1_cnt_%2.txt").arg(frameCount).arg(cnt);
+            //    pVecStream.open(pVecFileName.toStdString());
+            //    for (int i=0; i<posData.pixelFeatureVector.size();i++)
+            //    {
+            //        pVecStream << posData.pixelFeatureVector[i] << std::endl;
+            //    }
+            //    pVecStream.close();
+            //}
+            //if (showDebugWindow_)
+            //{
+            //    if (cnt==0)
+            //    {
+            //        //cv::imshow("hogPosMaxComp", maxCompMat);
+            //        //cv::imshow("boundingImageLUV", posData.segmentData.boundingImageLUV);
+            //        cv::imshow("rotBoundingImageLUV", posData.rotBoundingImageLUV);
+            //    }
+            //}
             // ------------------------------------------------------------------------------------
         }
        
     return gradData;
 }
 
+
+void HogPositionFitter::createTrainingData(cv::Mat img)
+{
+    cv::Mat imgFlipX;
+    cv::Mat imgFlipY;
+    cv::Mat imgFlipXY;
+
+    cv::flip(img, imgFlipX,   0);
+    cv::flip(img, imgFlipY,   1);
+    cv::flip(img, imgFlipXY, -1);
+
+    std::vector<double> vector = getPixelFeatureVector(img);
+    std::vector<double> vectorFlipX = getPixelFeatureVector(imgFlipX);
+    std::vector<double> vectorFlipY = getPixelFeatureVector(imgFlipY);
+    std::vector<double> vectorFlipXY = getPixelFeatureVector(imgFlipXY);
+
+}
+
+
 cv::Mat getTriangleFilter1D(unsigned int normRadius)
 { 
     float normConst = std::pow(float(normRadius)+1,2);

src/demo/fly_sorter/hog_position_fitter.hpp

         HogPositionFitter();
         HogPositionFitter(HogPositionFitterParam param);
         void setParam(HogPositionFitterParam param);
-        HogPositionFitterData fit(FlySegmenterData flySegmenterData, unsigned long frameCount, cv::Mat img);
+        void trainingDataWriteEnable(std::string fileNamePrefix);
+        void trainingDataWriteDisable();
+
+        HogPositionFitterData fit(
+                FlySegmenterData flySegmenterData, 
+                unsigned long frameCount, 
+                cv::Mat img
+                );
 
     private:
         bool showDebugWindow_;
+        bool writeTrainingData_;
+        std::string trainingFileNamePrefix_;
         HogPositionFitterParam param_;
+
         cv::Mat getFillMask(cv::Mat image);
         std::vector<double> getPixelFeatureVector(cv::Mat image);
-        std::vector<double> getHistGradMag(cv::Mat normGradMag, cv::Mat mask);
-        std::vector<double> getHistGradOri(cv::Mat gradOri, cv::Mat normGradMag, cv::Mat mask);
-        std::vector<double> getHistColor(cv::Mat subImage, cv::Mat mask);
+
+        std::vector<double> getHistGradMag(
+                cv::Mat normGradMag, 
+                cv::Mat mask
+                );
+
+        std::vector<double> getHistGradOri(
+                cv::Mat gradOri, 
+                cv::Mat normGradMag, 
+                cv::Mat mask
+                );
+
+        std::vector<double> getHistColor(
+                cv::Mat subImage, 
+                cv::Mat mask
+                );
+
+        void createTrainingData(cv::Mat img);
+
 };
 
 

src/demo/fly_sorter/image_grabber.cpp

     // Read frame from input file at frameRate.
     unsigned long frameCount = 0;
     //float sleepDt = 1.0e3/param_.frameRate;
-    float sleepDt = 0.25*1.0e3/param_.frameRate;
+    //float sleepDt = 0.25*1.0e3/param_.frameRate;
+    float sleepDt = 0.1*1.0e3/param_.frameRate;
     
 
     //std::cout << "begin play back" << std::endl;
     ImageData imageData;
+    std::cout << param_.captureInputFile.toStdString() << std::endl;
 
     while ((!stopped_) && (frameCount < numFrames))
     {
+        std::cout << (frameCount+1) << "/" << numFrames << std::endl;
+
         cv::Mat mat;
-        //std::cout << param_.captureInputFile.toStdString() << ", frame = " << frameCount;
 
         try
         {
 
         ThreadHelper::msleep(sleepDt);
     }
-    //std::cout << "play back done" << std::endl;
+    std::cout << "play back done" << std::endl;
 
     // Release file
     try