Keras Tutorial Notes

I’ve recently finished the first pass of CS231N Convolutional Neural Networks for Visual Recognition. Now it’s time to try out a library to get my hands dirty. Keras seems to be an easy-to-use high-level library that wraps over 3 different backend engines: TensorFlow, CNTK and Theano. Just perfect for a beginner in Deep Learning.

The tutorial I picked is the one on the MNIST dataset. I’m adding some notes along the way to refresh my memory on what I have learned as well as some links so that I can find the references in CS231N quickly in the future.

Step 1: Import libraries and prepare parameters for training

'''Trains a simple convnet on the MNIST dataset.

Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

# input image dimensions
img_rows, img_cols = 28, 28

Not much to explain on the import statements, so let’s look at some of the parameters defined in this section.

What are batch_size and epochs?

A good explanation can be found in Training a Model from DL4J. An epoch means training the model on all of your data once, i.e. a single pass over the whole dataset. Why do we need to train the model with multiple epochs?

To answer this question, we need to know what happens during the training process of a Neural Network. Using the example in CS231N, training is minimizing the loss function using gradient descent. One gradient descent update most likely won’t give you the minimal loss, so we have to do multiple passes until the loss converges or we hit a pre-set limit, for example, the epoch number. Of course, not all machine learning algorithms require multiple passes like this; the K-Nearest Neighbour (K-NN) algorithm, for example, does not.

Now let’s talk about batch_size. It relates to how we train the model, specifically how to optimize the loss function. In the naive form, we compute the loss function over the whole dataset. Quoted from CS231N:

while True:
  weights_grad = evaluate_gradient(loss_fun, data, weights)
  weights += - step_size * weights_grad # perform parameter update

However if we have millions of records, it becomes wasteful and inefficient to repeatedly compute the loss function to do a simple gradient update. Therefore, a common way to solve the scalability issue is to compute the gradient over batches of training data.

while True:
  data_batch = sample_training_data(data, 256) # sample 256 examples
  weights_grad = evaluate_gradient(loss_fun, data_batch, weights)
  weights += - step_size * weights_grad # perform parameter update

So why does this work? To quote from the course note:

“….the gradient from a mini-batch is a good approximation of the gradient of the full objective. Therefore, much faster convergence can be achieved in practice by evaluating the mini-batch gradients to perform more frequent parameter updates.”

Gradient descent using mini-batches like this is called Minibatch Gradient Descent (MGD), but in practice it is usually referred to as Stochastic Gradient Descent (SGD), even though SGD strictly refers to the case where the batch size is 1.

  • One question I have: with epoch and batch_size, does this mean that we update the gradient with SGD multiple times in one epoch?
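My current understanding is yes: with n training samples, one epoch consists of roughly ceil(n / batch_size) mini-batch gradient updates. A quick sanity check with this tutorial’s numbers (MNIST has 60,000 training samples):

```python
import math

n_samples = 60000    # MNIST training set size
batch_size = 128
epochs = 12

updates_per_epoch = math.ceil(n_samples / batch_size)
total_updates = updates_per_epoch * epochs

print(updates_per_epoch)  # 469 gradient updates per epoch
print(total_updates)      # 5628 updates over the 12 epochs
```

So each epoch already performs hundreds of parameter updates; the epochs simply repeat that whole sweep.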

Set the image dimension

We specified the image dimensions in the code, which raises two questions:

  • Do all the images in the dataset have to be of the same dimensions?
  • I assume if they don’t, we will have to resize them to the same size. How? Doesn’t the resizing make the subject in the image disproportional?

Step 2: Prepare the dataset for training and testing

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

Specify the depth of the images

Since we are using a CNN, one important step is to arrange the neurons in 3D (width, height and depth). I’ll skip the details, but the depth here in the code is 1. That means our images have only 1 channel, instead of 3 (RGB channels).

Normalize the mean and standard-deviation

It seems that the code above doesn’t perform this processing except for the two lines below:

x_train /= 255
x_test /= 255

As a guideline:

“Normally we would want to preprocess the dataset so that each feature has zero mean and unit standard deviation, but in this case the features are already in a nice range from -1 to 1, so we skip this step.”
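For reference, if we did want zero mean and unit standard deviation, a minimal sketch with NumPy (using stand-in random data; the statistics are computed on the training set only and then applied to both sets):

```python
import numpy as np

# Stand-in data in place of the real MNIST arrays.
x_train = np.random.rand(100, 28, 28, 1).astype('float32')
x_test = np.random.rand(20, 28, 28, 1).astype('float32')

mean = x_train.mean()
std = x_train.std()

x_train = (x_train - mean) / std
x_test = (x_test - mean) / std   # reuse the training-set statistics

print(x_train.mean(), x_train.std())  # close to 0 and 1
```

Note that the test set is normalized with the training set’s mean and std, since the test data should not influence preprocessing.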

Preprocess the class labels

Well, we need the class label to be a 10-dimensional vector for each record. Not sure if this is related, but the scoring function of the model outputs a 10-dimensional array, with each value representing the score assigned to a particular class. If we look at the labels, we will find they are in a 1-dimensional array. Hence the conversion.

print(y_train[:10])  # before the conversion
# [5 0 4 1 9 2 1 3 1 4]
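What to_categorical does is essentially one-hot encoding; a sketch of the same conversion without Keras:

```python
import numpy as np

def to_one_hot(labels, num_classes):
    """Convert a 1-D array of class indices into an (n, num_classes) one-hot matrix."""
    out = np.zeros((len(labels), num_classes), dtype='float32')
    out[np.arange(len(labels)), labels] = 1.0
    return out

y = np.array([5, 0, 4, 1, 9])  # the first few MNIST labels
one_hot = to_one_hot(y, 10)
print(one_hot[0])  # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
```

Each row now has a 1.0 at the index of its class and 0.0 everywhere else, matching the shape of the model’s output.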

Step 3: Define the model structure

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

A Sequential model is a linear stack of layers. Here we added 8 layers. Why do we add these layers and not others? I don’t know. People spend a great deal of time trying out different network architectures. If we are just starting out, we might as well rely on architectures that have proven useful, like the examples provided by Keras.

Layer Patterns

“…A ConvNet is made up of Layers. Every Layer has a simple API: It transforms an input 3D volume to an output 3D volume with some differentiable function that may or may not have parameters… We use three main types of layers to build ConvNet architectures: Convolutional Layer, Pooling Layer, and Fully-Connected Layer (exactly as seen in regular Neural Networks). We will stack these layers to form a full ConvNet architecture…”

http://cs231n.github.io/convolutional-networks/#layers

So why do we use Convolutional layers instead of the regular ones? In short, to solve performance and scalability issues as well as to avoid overfitting when processing full images.

The links above covered the Conv2D, MaxPooling, and Dense layers. What about Dropout and Flatten here? Briefly: Dropout randomly zeroes a fraction of the activations during training to reduce overfitting, and Flatten reshapes the 3D output volume into a 1D vector so it can be fed into the Dense layers.

At this point, the model structure is defined. We then specify the loss function, the way to optimize it and the measurement metric in the compile method.

Step 4: Train the model and test it

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Not much to explain here. We train the model with the training data. However, one concern I have with this piece is that the model validates itself on the testing data after each epoch, and it also evaluates on the same testing data to get the final score. What we should do is have a dedicated validation set split from the training set (as the course note suggests, a validation set is considered already burned during training; see the last point in the summary). Therefore, using the validation_split option may be a better idea.


Root-causing the random failures in the integration tests with ElasticSearch

In our recent development, we were creating an integration test framework and some tests for manipulating data in an ElasticSearch cluster. Strangely, the tests would succeed or fail randomly, even though we hadn’t made any changes to the business logic code at the time.

What did we have in the test cases?

  • @BeforeClass: load the test data into ElasticSearch cluster through ElasticSearch TransportClient.
  • @Test: retrieve test data and check equality on some fields.
  • @AfterClass: clean up the test data through ElasticSearch TransportClient.

Really just as simple as that.

What did the error message say when the tests failed? Well, it complained about not being able to find the test data.

Strange. The @BeforeClass annotated method should always load the data into the cluster before the test cases execute, and there were no errors about failing to load the data. Feeling a bit stuck, I commented out the clean-up code in the @AfterClass method. Now the tests passed consecutively on every run, but once I added the clean-up code back, they started failing occasionally again, especially when I ran a test right after the previous one finished.

This got me thinking: “Could it be possible that the test data was cleaned up at the end of the previous test but not loaded into the cluster in the next run even though @BeforeClass method was executed? ”

My suspicion was confirmed after some reading on how ElasticSearch loads data. Why did this happen? Because loading data into an ElasticSearch cluster takes time, and so does deleting it. The test cases were executed right after the load request was issued in the @BeforeClass method, but not necessarily after the request was processed by the cluster. In other words, the load is asynchronous. We made a false assumption that the request had been processed and the data was present in the cluster immediately. This mindset may be OK in unit tests, but in integration tests it can be problematic.

Stupid solution: add a buffer before actually executing the tests, for example Thread.sleep(30000) in the @BeforeClass method. However, this does not guarantee the data is loaded if the data size is large.

Better solution: send a request to verify that the load request was actually processed, given the request id, and wait in the @BeforeClass method until it is finished.

Whatever you do, make sure that the test data are actually in the cluster before moving on.
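The better solution boils down to polling until the data is actually visible (or a timeout expires). A language-agnostic sketch in Python; is_data_ready stands for a hypothetical check you would implement against your own cluster (for example, a count query or a request-status lookup):

```python
import time

def wait_until(is_data_ready, timeout_seconds=30, poll_interval=0.5):
    """Poll is_data_ready() until it returns True or the timeout expires."""
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        if is_data_ready():
            return True
        time.sleep(poll_interval)
    raise TimeoutError('test data not ready after %s seconds' % timeout_seconds)

# Example: a fake check that succeeds on the third poll.
calls = {'count': 0}
def fake_check():
    calls['count'] += 1
    return calls['count'] >= 3

print(wait_until(fake_check, timeout_seconds=5, poll_interval=0.01))  # True
```

Unlike a fixed sleep, this waits only as long as needed and still fails loudly if the data never shows up.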

 

 


How to use JavaConfig Bean in Spring XML

Our current project is at the first stage: wire all the components together and do a simple integration test. When I took on this task, I found that all beans were defined in XML. Given the number of beans I had to create, it would be tedious to write them all in XML. Personally, I prefer JavaConfig to the XML files, as navigation is easier for me in JavaConfig. But I didn’t want to change all the XML configuration to JavaConfig at once. Can I define beans in JavaConfig and use them in the XML?

A bit of search revealed a simple way. Now assume that we have a provider class as follows:

package com.example.xyz;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class ResourceProvider {
    @Bean
    public SQSWrapper sqsWrapper() {
        return new SQSWrapper();
    }
}

Assume that we have an application.xml file and we want to use the SQSWrapper Bean in a bean definition in the file:

<bean id="SQSConsumer" class="com.example.xyz.SQSConsumerImpl">
    <constructor-arg ref="THE_ID_OF_THE_SQSWRAPPER_BEAN"/>
</bean>

To do that, we need to add two extra lines to the file, and then we reference the SQSWrapper bean by its id, which defaults to the method name sqsWrapper. The complete XML file looks like this:

<context:annotation-config/>

<!-- The following line brings in the beans defined in the ResourceProvider -->
<bean class="com.example.xyz.ResourceProvider" />

<bean id="SQSConsumer" class="com.example.xyz.SQSConsumerImpl">
    <constructor-arg ref="sqsWrapper"/>
</bean>

The first line “annotation-config” is crucial as noted in this stackoverflow answer: “while annotation-config is switched on, the container will recognize the @Configuration annotation and process the @Bean methods declared in JavaConfig properly”.

Now that saved me from creating more XML files!


Randomly Draw k unique integers out of an array of N unique integers

Given an array of n unique integers (1 to n), write an algorithm to draw a random combination of k distinct numbers (n >= k). (This problem comes from Core Java Vol I as an example.)

Unknown: A way to draw k distinct integers out of n distinct integers.

Data: An array of integers 1 to n.

Constraint: k numbers must be distinct and randomly picked.

A straightforward solution would be:

  1. Randomly pick one number out of the array.
  2. If this number has not been picked before, add it to the result. Return the result once we have k numbers.
  3. Otherwise, go back to step 1.

Q: So what is the time complexity of this solution?

A: If we are unlucky, in the worst case it is O(k^2), and if k is close to n, O(n^2).

Q: How so?

A: At some point, we will have trouble selecting a number that’s not already in the result set. The first number is easy: just one pick. The second, if unlucky, two picks. The third, if unlucky, three picks, since the first two may hit something already in the result set… So for k numbers, it can take 1+2+3+…+k picks, which is approximately O(k^2). If k is close to n, then we have an O(n^2) algorithm. Check the link at the bottom for the code.

Q: Alright, can we make it faster? Say, O(n) time, and you cannot use additional space except for the result set.

A: Hmm, the bottleneck of the previous solution is that every time we pick a number, we have to check if it exists in the result set. If it does, we have to go back and pick again. If we can skip this step it will be faster.

Q: What do you mean by skipping this step?

A: I mean that every time we pick a number, it is guaranteed not picked before.

Q: How do we do that?

A: Hmm. we need to keep track of what has not been picked instead. Since we cannot use additional space, I assume that we have to do something on the original array. I can replace the picked number with some special value like n+1, but this sounds useless since if I happened to pick this number, I would have to choose again, exactly like before. I don’t know…

Q: OK, in what situation can we safely draw an unpicked number?

A: If the array only contains unpicked numbers, we can do that safely. But again, I don’t think we can recreate a new array to exclude the picked one in every pass. That’s O(n^2) again.

Q: True. So why can’t we draw the numbers safely now? What’s the matter?

A: Because there are picked values between unpicked ones.

Q: Good. You mentioned about excluding them. Is there a way to do that without creating a new array?

A: I suppose I can re-arrange the array? For example, if I picked the number at index i, I can move the numbers from i+1 to n-1 forward by one. But then I should pick a random index within the shrinking range next time. Wait, this is still O(n^2)…

Q: Do we have to move all the elements after index i? Can we reduce this O(n) move to a O(1) move?

A: O(1)? So I should move only 1 element instead. But which one…

Q: Let’s use an example: 1,2,3,4,5. Say we picked 3. In your case, we change the array to 1,2,4,5,5 and then pick from indices 0 to 3 next time. We do the move because we want to make sure that next time we are choosing from 1,2,4,5. So is there another way to do it?

A: Yes! I can move the last element to that position to achieve the same effect! So every time after the pick, I move the last element within the current range to the picked position then reduce the range by 1.

Q: That’s right 🙂
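The O(n) approach from the dialogue can be sketched as follows (the gist linked below has the original code; this is my own Python version, essentially a partial Fisher-Yates shuffle):

```python
import random

def draw_k(n, k):
    """Draw k distinct numbers from 1..n in O(n) time."""
    numbers = list(range(1, n + 1))
    result = []
    end = n                            # numbers[0:end] holds the not-yet-picked values
    for _ in range(k):
        i = random.randrange(end)      # random index within the current range
        result.append(numbers[i])
        numbers[i] = numbers[end - 1]  # move the last in-range element into the hole
        end -= 1                       # shrink the range by one
    return result

picked = draw_k(10, 4)
print(len(picked) == len(set(picked)))  # True: all k numbers are distinct
```

Each pick is O(1): one random index, one move, one range decrement. No membership checks are ever needed.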

Link to the code: https://gist.github.com/jingz8804/53955bbaf817a6c2e179

 


Some thoughts on JUnit Test with Mockito Part II

So in the previous post we talked about some resources and concepts we should understand before writing JUnit tests with Mockito.

Well, based on my recent experience, they may not be enough, especially when our code has many dependencies or collaborators and we have to mock all of them to do the test. Things can become complicated very quickly. If you find yourself in this situation, stop writing the test and refactor your code first, for example by using the strategy pattern. In this way, you will likely have fewer collaborators to mock, but of course, you need to write tests for the code delegated to other classes. It may be some extra work, but the result is cleaner and easier to test.

Aside from this, my recent struggle also made me think about how to write unit test in general. I found another post about unit testing by Martin Fowler and read the first part of a book called “Effective Unit Testing”. There are some paragraphs that I find very helpful:

1) “We write tests assuming everything other than that unit under test is working correctly.” — Martin Fowler

Here, “everything other than that unit under test” usually means collaborators, and we want to isolate them from the code under test. By isolating them, we mean using test doubles to replace the real collaborators.

2) “A test should test just one thing and test it well while communicating its intent clearly.” We must ask ourselves:

What is the desired behavior of the code under test? Think about how each step of your method should behave with certain inputs. The desired behavior of the dependencies is something we should configure the mocks to provide.

A general process could be:

Q: What do we want to test?

A: We want to test if method A is working properly.

Q: What do you mean by working properly?

A: By properly I mean that with a certain input INPUT, A should first call a dependency with INPUT and get back a RESPONSE. Then A does some calculation by adding a number to some value stored in RESPONSE and returns the sum. I expect the sum to be XXX.

Q: OK, so here we have a dependency. How do we deal with it?

A: We can isolate it by using a mock object to stub the RESPONSE when called with INPUT.

Q: Are there any requirements on the RESPONSE object?

A: Yes, since we are using some value in the RESPONSE, we should stub that value in the RESPONSE object. Otherwise, we might hit a null pointer exception.
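The same process, sketched with Python’s unittest.mock instead of Mockito (Calculator, method_a, and the dependency’s call method are made-up names for illustration):

```python
from unittest.mock import Mock

class Calculator:
    """Method A under test: call a dependency with INPUT, then add 1 to a value in RESPONSE."""
    def __init__(self, service):
        self.service = service

    def method_a(self, inp):
        response = self.service.call(inp)
        return response['value'] + 1

# Isolate the dependency: stub the RESPONSE for a given INPUT.
service = Mock()
service.call.return_value = {'value': 41}  # stub the value we use, or we would hit an error

calc = Calculator(service)
result = calc.method_a('INPUT')

assert result == 42                            # the usual assertion
service.call.assert_called_once_with('INPUT')  # verify the dependency was called as expected
print(result)  # 42
```

The stub answers the dialogue’s last point: because method_a reads response['value'], the stubbed RESPONSE must contain that value, or the test blows up for the wrong reason.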

Anyway, the key point is that we shouldn’t have too many dependencies to mock in one unit test. Refactor the code first, then work slowly through this process. Take your time.


Some thoughts on JUnit Test with Mockito

In the past few days I’ve worked on writing unit tests for a service. It was a completely different experience. The unit tests I wrote before were very simple ones with several assertions, and that was it. However, since the service I’m writing involves calling other services, we have to mock their behaviors before testing our own part. Otherwise, if the dependent services are not working, or not even implemented, we won’t be able to test our business logic at all.

It’s an easy idea to grasp, but it took me a while to understand the code. Before jumping into Mockito, there’s an article one should read: Mocks Aren’t Stubs by Martin Fowler. It explains classical and mockist testing in some detail; it’s a bit long, but worth reading.

After reading that we should get an idea about mockist testing:

Set up -> Set Expectations -> Exercise -> Verification.

The second step is the part where we tell the mocked dependent services how to behave under certain conditions.

In the verification step, we can still do the usual assertions, but we may also need to verify that a dependent service was called as expected. Why do we need that? Say you expect a method to return -1 with argument A. In an ideal successful case, we can just use assertEquals for testing. However, if the malfunction of a dependent service also causes the return value to be -1, then an assert statement alone is insufficient. We need to make sure that the service was called and/or has the right return value.

After I got this idea, I moved on to the Mockito framework, which is different from the ones mentioned in the article, but the idea is similar. There are two (actually three) articles I found online that help you ease into using Mockito:

Occasionally we may need to mock a static method. I had a really hard time with it when I was copying and pasting others’ test code. DON’T DO THAT! We are using PowerMock, and I don’t know why I didn’t search for it earlier. Just RTFM; it’s actually very easy to use.

For my own reference:

To solve a problem, we need to locate the root cause first. We should set up some assumptions and verify them one by one.


Notes on Java Daemon Thread

I’m going to work on daemon threads in my new job, but I had no idea what they were. This post summarizes some of the key points from a stackoverflow post.


 

First, let’s look at daemons in Unix. Simply put, they are background processes that answer requests for services. You can read more about them on Wikipedia.

There are two types of Java thread:

  • Normal/user thread: generally, all threads created by the programmer are user threads (unless you specify one to be a daemon, or the parent thread spawning the new thread is a daemon thread). The main thread is by default a non-daemon thread.
  • Daemon thread: similar in spirit to a Unix daemon (I don’t know if I can say that; correct me if I’m wrong, please). Daemon threads are like service providers for other threads or objects running in the same process (in other words, they may serve the user threads). They are typically used for background supporting tasks.

Points to Note about Java Daemon Threads:

  • (needs verification) They have very low priority and only execute when no other threads of the same program are running.
  • When there are no more user threads (meaning only daemon threads are running in a program), the JVM will end the program and exit. This is reasonable: if there is no one left to serve, why keep the servants? (This is my own thought.)
  • When the JVM halts, all daemon threads are abandoned. The finally blocks are not executed and stacks are not unwound (not sure what this means).
  • Daemon threads usually have an infinite loop in their run() method that waits for service requests or performs the tasks of the thread.
  • We can set a thread to be a daemon through the setDaemon() method, but only before the thread is started.
  • We can check whether a thread is a user thread or a daemon thread using the isDaemon() method.

Examples of Java Daemon Threads:

  • Garbage collection. It runs in the background, reclaiming resources from unwanted objects.
  • A good Java code example from that post, reposted on gist
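Python’s threading module has the same concept, which makes the behavior easy to see; as in Java, the daemon flag must be set before the thread is started (the background loop here is just a stand-in service):

```python
import threading
import time

def background_service():
    # Typical daemon pattern: an infinite loop waiting for work.
    while True:
        time.sleep(0.1)

worker = threading.Thread(target=background_service)
worker.daemon = True   # must be set before start(), like setDaemon() in Java
worker.start()

print(worker.daemon)      # True
print(worker.is_alive())  # True
# When the main (user) thread exits, this daemon thread is abandoned and the process ends.
```

If the daemon flag were left at its default of False, the infinite loop would keep the process alive forever after the main thread finished.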

Things to check…

  • Non-daemon threads (default) can even live longer than the main thread.