Inference in Google Earth Engine + Colab

Here we demonstrate how to take a trained model and apply it to imagery with Google Earth Engine + Colab + TensorFlow. This is adapted from an Earth Engine <> TensorFlow demonstration notebook. We'll be taking the trained model from the Deep Learning Crop Type Segmentation Model Example.

Set up software libraries

Authenticate and import as necessary.

# Import, authenticate and initialize the Earth Engine library.
# (Authenticate prompts for credentials interactively in Colab.)
import ee
ee.Authenticate()
ee.Initialize()
# Mount our Google Drive at /content/drive; the Drive paths used later in the
# notebook (ROOT_DIR, MODEL_DIR) live under this mount point.
from google.colab import drive
drive.mount('/content/drive')
# Add necessary libraries.
# IPython shell magic (Colab only): installs the focal-loss package.
!pip install -q focal-loss
import os
from os import path as op
import tensorflow as tf
import folium
# NOTE(review): SparseCategoricalFocalLoss is not referenced directly below;
# presumably imported so the custom loss used during training can be resolved
# when the saved model is loaded — confirm before removing.
from focal_loss import SparseCategoricalFocalLoss

Variables

Declare the variables that will be in use throughout the notebook.

# Google Drive locations: exported patches and predictions are written to
# ROOT_DIR/FOLDER.
ROOT_DIR = '/content/drive/My Drive/'
FOLDER = 'servir-inference-demo'

# Sentinel-2 vegetation indexes used as model inputs. doPrediction stacks the
# channels in exactly this order.
BANDS = ['NDVI', 'WDRVI', 'SAVI']

# Square patch size (in pixels) expected by the model.
KERNEL_SIZE = 224
KERNEL_SHAPE = [KERNEL_SIZE] * 2

Imagery

Gather and set up the imagery to use as model inputs. It's important that the vegetation indexes computed here match the index inputs used in the earlier training analysis. This is a three-month (January–March 2018) median Sentinel-2 composite. Display it in the notebook as a sanity check.

# Use Sentinel-2 data.

def add_indexes(img):
    """Derive the model's three vegetation-index bands from a Sentinel-2 image.

    Uses band B8 as NIR and B4 as red. Returns a new ee.Image containing only
    the bands NDVI, WDRVI and SAVI, in that order; the source bands are not
    carried over.
    """
    nir = img.select('B8')
    red = img.select('B4')
    nir_red = {'nir': nir, 'red': red}

    # NDVI; the tiny constant `a` keeps the denominator away from zero.
    ndvi = img.expression(
        '(nir - red) / (nir  + red + a)', dict(nir_red, a=1e-5)
    ).rename('NDVI')

    # Wide Dynamic Range VI: NIR is down-weighted by `a` = 0.2.
    wdrvi = img.expression(
        '(a * nir - red) / (a * nir + red)', dict(nir_red, a=0.2)
    ).rename('WDRVI')

    # SAVI with a 0.5 soil term.
    # NOTE(review): if band values are scaled reflectance (COPERNICUS/S2 is
    # documented as reflectance x 10000 — confirm), the 0.5 term is negligible
    # and SAVI ~= 1.5 * NDVI. Kept as-is so inputs match the training features.
    savi = img.expression(
        '1.5 * (nir - red) / (nir + red + 0.5)', dict(nir_red)
    ).rename('SAVI')

    return ee.Image.cat([ndvi, wdrvi, savi])

# Median composite of the vegetation indexes over Jan-Mar 2018 (the end date is
# exclusive), keeping only scenes with under 20% cloudy pixels.
# NOTE(review): the Earth Engine catalog lists COPERNICUS/S2 as superseded by
# COPERNICUS/S2_HARMONIZED — confirm availability before reusing this notebook.
image = (
    ee.ImageCollection('COPERNICUS/S2')
    .filterDate('2018-01-01', '2018-04-01')
    .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20))
    .map(add_indexes)
    .median()
)

# Use folium to visualize the imagery. With three bands selected, Earth Engine
# renders them as red/green/blue (NDVI, WDRVI, SAVI), each stretched over [-1, 1].
mapid = image.getMapId({'bands': BANDS, 'min': -1, 'max': 1})

# Bound to `composite_map` rather than `map` so the builtin map() is not
# shadowed for the rest of the notebook session.
composite_map = folium.Map(location=[
    -29.177943749121233,  # latitude of the prediction region's south-west corner
    30.55984497070313,    # longitude of the same corner
])
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
    overlay=True,
    name='median composite',
).add_to(composite_map)

composite_map.add_child(folium.LayerControl())
composite_map

Load our saved model

# Directory on the shared drive holding the model produced by the training notebook.
MODEL_DIR = '/content/drive/Shared drives/servir-sat-ml/data/model_out/10062020/'

# Restore the trained Keras model, then print its layer summary as a sanity check.
model = tf.keras.models.load_model(MODEL_DIR)
model.summary()

Prediction

The prediction pipeline is:

  1. Export imagery on which to do predictions from Earth Engine in TFRecord format to Google Drive.
  2. Use the trained model to make the predictions.
  3. Write the predictions to a TFRecord file in Google Drive.
  4. Manually upload the predictions TFRecord file, together with the JSON mixer file written by the export, to Earth Engine.

The following functions handle this process. It's useful to separate the export from the predictions so that you can experiment with different models without running the export every time.

def doExport(out_image_base, shape, region, scale=30):
  """Export the composite's model bands to Google Drive as TFRecord patches.

  Starts an Earth Engine batch export of ``image.select(BANDS)`` (the
  module-level composite) over ``region`` into the Drive folder ``FOLDER``. The
  export is cut into GZIP-compressed TFRecord patches of ``shape`` pixels. The
  function then blocks, polling every 30 seconds, until the task is no longer
  active.

  Args:
    out_image_base: Task description and file-name prefix for the exported files.
    shape: Patch size in pixels, passed as ``patchDimensions``.
    region: ee.Geometry whose coordinates bound the export.
    scale: Export resolution in metres per pixel (default 30, as before).

  Returns:
    The final status dict reported by ``task.status()``.
  """
  import time

  task = ee.batch.Export.image.toDrive(
    image=image.select(BANDS),
    description=out_image_base,
    fileNamePrefix=out_image_base,
    folder=FOLDER,
    region=region.getInfo()['coordinates'],
    scale=scale,
    fileFormat='TFRecord',
    maxPixels=1e10,
    formatOptions={
      'patchDimensions': shape,
      'compressed': True,
      'maxFileSize': 104857600,  # 100 MiB per output file
    },
  )
  task.start()

  # Block until the task completes.
  print('Running image export to Google Drive...')
  while task.active():
    time.sleep(30)

  status = task.status()
  if status['state'] != 'COMPLETED':
    print('Error with image export.')
    # Surface Earth Engine's own failure reason instead of only a generic message.
    if status.get('error_message'):
      print(status['error_message'])
  else:
    print('Image export completed.')
  return status
def doPrediction(out_image_base, kernel_shape, region):
  """Run the loaded model over exported patches and write per-pixel classes.

  Reads the GZIP TFRecord patches and the JSON mixer file whose names contain
  ``out_image_base`` from ``ROOT_DIR/FOLDER``. Each patch is predicted with the
  module-level ``model``. The argmax class of every pixel is written, one
  tf.train.Example per patch under the feature ``'class'``, to
  ``<out_image_base>pred.TFRecord`` in the same folder.

  Args:
    out_image_base: File-name prefix used when the imagery was exported.
    kernel_shape: Patch shape [height, width] used to parse each band.
    region: Unused here; kept so the signature mirrors doExport.

  Returns:
    Path of the predictions TFRecord file that was written.

  Raises:
    FileNotFoundError: If no exported ``.tfrecord.gz`` patches or no JSON
      mixer file matching ``out_image_base`` exist in the export folder.
  """
  import json
  from pprint import pprint

  print('Looking for TFRecord files...')

  export_dir = op.join(ROOT_DIR, FOLDER)

  # Only consider files produced by this export (their names share the prefix).
  exportFilesList = [s for s in os.listdir(export_dir) if out_image_base in s]

  # Split into the image patch files and the JSON mixer file.
  imageFilesList = []
  jsonFile = None
  for f in exportFilesList:
    if f.endswith('.tfrecord.gz'):
      imageFilesList.append(op.join(export_dir, f))
    elif f.endswith('.json'):
      jsonFile = f

  # Make sure the files are read in the order they were written.
  imageFilesList.sort()

  pprint(imageFilesList)
  print(jsonFile)

  # Fail with a clear message rather than an opaque TypeError / empty dataset.
  if not imageFilesList:
    raise FileNotFoundError(
        f'No exported .tfrecord.gz files matching {out_image_base!r} in {export_dir}')
  if jsonFile is None:
    raise FileNotFoundError(
        f'No JSON mixer file matching {out_image_base!r} in {export_dir}')

  # Load the contents of the mixer file to a JSON object.
  with open(op.join(export_dir, jsonFile), 'r') as f:
    mixer = json.load(f)

  pprint(mixer)
  patches = mixer['totalPatches']

  # Every band is parsed as a float32 patch of kernel_shape.
  imageFeaturesDict = {
    band: tf.io.FixedLenFeature(shape=kernel_shape, dtype=tf.float32)
    for band in BANDS
  }

  def parse_image(example_proto):
    return tf.io.parse_single_example(example_proto, imageFeaturesDict)

  def toTupleImage(inputs):
    # Stack bands on a new leading axis, then move it last: (H, W, len(BANDS)).
    stacked = tf.stack([inputs.get(key) for key in BANDS], axis=0)
    return tf.transpose(stacked, [1, 2, 0])

  # Create a dataset from the exported TFRecord file(s) in Google Drive.
  imageDataset = tf.data.TFRecordDataset(imageFilesList, compression_type='GZIP')
  imageDataset = imageDataset.map(parse_image, num_parallel_calls=5)
  imageDataset = imageDataset.map(toTupleImage).batch(1)

  # Perform inference.
  print('Running predictions...')
  predictions = model.predict(imageDataset, steps=patches, verbose=1)

  print('Writing predictions...')
  out_image_file = op.join(export_dir, f'{out_image_base}pred.TFRecord')
  # The context manager closes the writer even if writing a patch fails.
  with tf.io.TFRecordWriter(out_image_file) as writer:
    for patch_index, predictionPatch in enumerate(predictions):
      print('Writing patch ' + str(patch_index) + '...')
      # Assumes per-pixel class scores on the last axis (H, W, classes) —
      # argmax over axis 2 picks the predicted class for each pixel.
      predictionPatch = tf.argmax(predictionPatch, axis=2)

      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            'class': tf.train.Feature(
                float_list=tf.train.FloatList(
                    value=predictionPatch.numpy().flatten()))
          }
        )
      )
      writer.write(example.SerializeToString())

  return out_image_file

Now that all the code needed to run the prediction pipeline is in place, all that remains is to specify the region in which to make predictions, the base name of the output files, where to put them, and the shape of the patches.

# Prefix shared by the exported TFRecords, the mixer file and the predictions file.
image_base = 'servir_inference_demo_'

# Prediction area in South Africa, near the training data: a lon/lat box built
# as a polygon with no explicit projection (None) and non-geodesic edges (False).
region = ee.Geometry.Polygon(
    [[
        [30.55984497070313, -29.177943749121233],   # south-west
        [30.843429565429684, -29.177943749121233],  # south-east
        [30.843429565429684, -28.994928377910732],  # north-east
        [30.55984497070313, -28.994928377910732],   # north-west
    ]],
    None,
    False,
)

# Export the imagery to Drive (blocks until the task finishes), then predict.
doExport(image_base, KERNEL_SHAPE, region)
doPrediction(image_base, KERNEL_SHAPE, region)

Display the output

Once the data has been exported, the model has made predictions, and the predictions have been written to a file, we need to manually upload the predictions TFRecord (together with its JSON mixer file) to Earth Engine. Then we can display our crop type predictions as an image asset.

# Load the predictions asset that was uploaded to Earth Engine.
# NOTE(review): this asset ID belongs to one user's account; replace it with the
# ID you chose when uploading the predictions TFRecord.
out_image = ee.Image('users/drew/servir_inference_demo_-mixer')
# NOTE(review): a 7-colour palette is stretched over 0-10; confirm these bounds
# against the model's class count so each class gets a distinct colour.
mapid = out_image.getMapId({'min': 0, 'max': 10, 'palette': ['00A600','63C600','E6E600','E9BD3A','ECB176','EFC2B3','F2F2F2']})

# Bound to `prediction_map` rather than `map` so the builtin map() is not shadowed.
prediction_map = folium.Map(location=[
    -29.177943749121233,  # latitude of the prediction region's south-west corner
    30.55984497070313,    # longitude of the same corner
])
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
    overlay=True,
    name='predicted crop type',
).add_to(prediction_map)
prediction_map.add_child(folium.LayerControl())
prediction_map