// Snippets
//
// Frederick Vallaeys: Prediction API for Shopping Queries
//
// Created by Frederick Vallaeys (the last-modified date is missing from this copy).
// Give the model a unique name. It becomes the model id when the model is
// created, and is the model that is later queried and updated.
var MODEL_NAME = 'Query and Conversion Rate';
// To get your project ID, open the Advanced APIs dialog, click the
// "Google Developers Console" and select the project number from the
// Overview page. Every Prediction API call in this script is made
// against this project.
var PROJECT_ID = '40310033428';


// Set up a Google Spreadsheet with training data, then enter your own
// spreadsheet URL and the names of the tabs holding (1) the initial training
// data and (2) the data used to update the model.
// Row 1 of each tab must contain the column headers; the functions below
// treat the first row as headers, not data.
// We recommend around 3000 rows of data for the initial model.
// The update tab can hold as many rows as you want. Each row is deleted from
// the tab after it has been sent to the model, so if a run does not get
// through all of them, the next run picks up where the last one left off.
// NOTE(review): these openByUrl calls run as soon as the script loads, so the
// placeholder URLs / tab names presumably make every run fail until replaced.
var INITIAL_TRAINING_SHEET = SpreadsheetApp.openByUrl("enter your spreadsheet url").getSheetByName("the name of the tab on the sheet with initial training data"); 
var UPDATE_TRAINING_SHEET = SpreadsheetApp.openByUrl("enter your spreadsheet url").getSheetByName("the name of the tab on the sheet with new training data"); 
  

// Header names of training-data columns to leave out of the model inputs.
// Change these to create variations of your model for testing.
// OUTPUT_COLUMN is always excluded from the inputs, so it does not need to be
// listed here.
var COLS_TO_IGNORE = [
];
// Header name of the output column in your training data, i.e. the value
// the model is supposed to learn and predict.
var OUTPUT_COLUMN = 'Conv. rate';

/***********************************
 * Entry point. Logs the model's training status, then queries
 * the model with every data row of INITIAL_TRAINING_SHEET
 * (ignored columns and OUTPUT_COLUMN removed) as a smoke test,
 * and logs the predictions.
 *
 * To build the model from scratch, first call
 *   createTrainingModel(INITIAL_TRAINING_SHEET);
 * To feed the model more rows, call
 *   updateTrainedModelData(UPDATE_TRAINING_SHEET);
 ***********************************/
function main() {
  var trainingStatus = Prediction.Trainedmodels.get(PROJECT_ID, MODEL_NAME).trainingStatus;
  Logger.log("Training Status: " + trainingStatus);

  // We are going to test it by querying with training data.
  var testData = INITIAL_TRAINING_SHEET.getDataRange().getValues();
  var headers = testData.shift();

  // The set of input columns is the same for every row, so work it out
  // once instead of re-checking every header for every row.
  var inputColumns = [];
  headers.forEach(function(header, col) {
    if (COLS_TO_IGNORE.indexOf(header) === -1 && header !== OUTPUT_COLUMN) {
      inputColumns.push(col);
    }
  });

  var queries = testData.map(function(row) {
    return inputColumns.map(function(col) {
      return row[col];
    });
  });
  Logger.log(makePrediction(queries));
}


/***********************************
 * This function accepts a sheet full of training
 * data and creates a trained model for you to query.
 *
 * Row 1 of the sheet must hold the column headers. Every column
 * except OUTPUT_COLUMN and those listed in COLS_TO_IGNORE is a
 * model input; OUTPUT_COLUMN supplies the value to learn. The
 * model is inserted under id MODEL_NAME in project PROJECT_ID.
 *
 * @param {Sheet} sheet - sheet holding the training data.
 * @throws {Error} if the header row does not contain OUTPUT_COLUMN.
 ***********************************/


function createTrainingModel(sheet) {
  // get the spreadsheet values
  var trainingData = sheet.getDataRange().getValues();
  var headers = trainingData.shift();

  // Without this check indexOf() returns -1, row[-1] is undefined and the
  // model would silently be trained on undefined outputs.
  var outputColumn = headers.indexOf(OUTPUT_COLUMN);
  if (outputColumn === -1) {
    throw new Error('Output column "' + OUTPUT_COLUMN +
                    '" not found in training sheet headers.');
  }

  // Column positions that feed the model; the same for every row.
  var inputColumns = [];
  headers.forEach(function(header, col) {
    if (COLS_TO_IGNORE.indexOf(header) === -1 && header !== OUTPUT_COLUMN) {
      inputColumns.push(col);
    }
  });

  var trainingInstances = trainingData.map(function(row) {
    var inputs = inputColumns.map(function(col) {
      return row[col];
    });
    return createTrainingInstance(inputs, row[outputColumn]);
  });

  var insert = Prediction.newInsert();
  insert.id = MODEL_NAME;
  insert.trainingInstances = trainingInstances;

  Prediction.Trainedmodels.insert(insert, PROJECT_ID);
  Logger.log('Trained model with data.');
}

/**
 * Helper that wraps one row of training data in the object the
 * Prediction service provides for a training instance.
 *
 * @param {Array} inputs - the row's input values, in column order.
 * @param {*} output - the value the model should learn for this row.
 * @return {Object} the populated training-instance object.
 */
function createTrainingInstance(inputs, output) {
  var instance = Prediction.newInsertTrainingInstances();
  instance.csvInstance = inputs;
  instance.output = output;
  return instance;
}

/***************************
 * Sends every data row of `sheet` to the existing model as an
 * incremental update, deleting each row from the sheet right
 * after it has been sent. Because sent rows are removed, a run
 * that stops part-way resumes with the first unsent row next
 * time. If a run stops between an update and its delete, that
 * one row is sent again on the next run.
 *
 * Row 1 must hold the column headers. Every column except
 * OUTPUT_COLUMN and those in COLS_TO_IGNORE is a model input.
 *
 * @param {Sheet} sheet - sheet holding the new training rows.
 * @throws {Error} if the header row does not contain OUTPUT_COLUMN;
 *     checked before anything is sent, so no rows are deleted then.
 ***************************/
function updateTrainedModelData(sheet) {
  var updateData = sheet.getDataRange().getValues();
  var headers = updateData.shift();

  // Fail before touching the model or the sheet: otherwise every output
  // would be undefined and the source rows would still be deleted.
  var outputColumn = headers.indexOf(OUTPUT_COLUMN);
  if (outputColumn === -1) {
    throw new Error('Output column "' + OUTPUT_COLUMN +
                    '" not found in update sheet headers.');
  }

  // Column positions that feed the model; the same for every row.
  var inputColumns = [];
  headers.forEach(function(header, col) {
    if (COLS_TO_IGNORE.indexOf(header) === -1 && header !== OUTPUT_COLUMN) {
      inputColumns.push(col);
    }
  });

  updateData.forEach(function(row, r) {
    Logger.log("r: " + r);
    var inputs = inputColumns.map(function(col) {
      return row[col];
    });
    Logger.log("inputs: " + inputs);
    var output = row[outputColumn];
    Logger.log("output: " + output);
    Prediction.Trainedmodels.update(createUpdateInstance(inputs, output),
                                    PROJECT_ID, MODEL_NAME);
    Logger.log('Trained model updated with new data.');
    // Every earlier row of this run has already been deleted, so the row
    // just sent now sits at sheet row 2 (row 1 is the header).
    sheet.deleteRow(2);
  });
}

/**
 * Helper that wraps one row of new training data in the update
 * object the Prediction service provides.
 *
 * @param {Array} inputs - the row's input values, in column order.
 * @param {*} output - the value the model should learn for this row.
 * @return {Object} the populated update object.
 */
function createUpdateInstance(inputs, output) {
  var update = Prediction.newUpdate();
  update.csvInstance = inputs;
  update.output = output;
  return update;
}

/***************************
 * Accepts a 2d array of query data and returns the
 * predicted output in an array.
 *
 * Each inner array is one query: input values in the same
 * column order used for training. Queries are sent one at a
 * time, and each result is logged.
 *
 * @param {Array<Array>} data - the queries to run.
 * @return {Array} the outputValue of each prediction, in query order.
 ***************************/
function makePrediction(data) {
  // map() visits only the array's elements, in index order; for...in
  // would also pick up any other enumerable property of the array.
  return data.map(function(query) {
    var request = Prediction.newInput();
    request.input = Prediction.newInputInput();
    request.input.csvInstance = query;
    var predictionResult = Prediction.Trainedmodels.predict(
      request, PROJECT_ID, MODEL_NAME);
    Logger.log("Prediction for data: %s is %s",
               JSON.stringify(query), predictionResult.outputValue);
    return predictionResult.outputValue;
  });
}

// Comments (0)