Deep Learning and US Equity Data
This article will explore the effectiveness of a recurrent neural network (RNN) at making predictions on US equity 1 second bar data.
Why use an RNN over other types of neural networks, for example a CNN? In short, RNNs have become increasingly popular for making predictions on sequential data. For more information on RNNs, I highly recommend giving The Unreasonable Effectiveness of Recurrent Neural Networks a read. And to give credit where credit is due, the following work expands on some of the ideas first presented in LSTM Neural Network for Time Series Prediction.
The objective of the experiment performed in this article is to predict the percentage change of a stock price 1 hour (3600 seconds) after the last time step of a 1 hour (3600 second) input sequence. In other words, if we were given 1 second bar IBM equity data from 9:30 until 10:30, we would want to predict the percentage change of the stock price at 11:30.
The Data
The financial data used in this experiment was rented from Quantgo. Quantgo provides access to years' worth of very granular historical data via their virtual private network. If you would rather purchase your financial data outright, here are some data vendors:
- Quandl (NOT FREE)
- Xignite (NOT FREE)
- Yahoo Finance (FREE)
The snippet below is a sample of the US equities 1 second bar data provided by Quantgo.
Date, Timestamp, Ticker, OpenPrice, HighPrice, LowPrice, ClosePrice, TotalVolume, TotalQuantity, BuyQuantity, SellQuantity, TotalTradeCount, BuyTradeCount, SellTradeCount
20140505, 09:30:01.000, IBM, 190.71, 190.71, 190.67, 190.71, 67503.34, 354, 0, 0, 3, 0, 0
20140505, 09:30:02.000, IBM, 190.70, 190.70, 190.67, 190.67, 57204.00, 300, 0, 0, 3, 0, 0
20140505, 09:30:03.000, IBM, 190.67, 190.67, 190.67, 190.67, 57201.00, 300, 0, 0, 3, 0, 0
20140505, 09:30:05.000, IBM, 190.65, 190.65, 190.63, 190.65, 253368.77, 1329, 0, 0, 10, 0, 0
20140505, 09:30:06.000, IBM, 190.77, 190.77, 190.65, 190.77, 44637.18, 234, 0, 0, 5, 0, 0
20140505, 09:30:11.000, IBM, 190.62, 190.62, 190.60, 190.60, 19060.12, 100, 0, 0, 3, 0, 0
You may notice that there are missing seconds in the sample above. In actuality, we are viewing tick data: the change in price of a security from trade to trade. We could go ahead and pass this data as input into our RNN; however, it will not generalize well across different stocks. Certain stocks are more volatile than others, and equal-length samples chosen randomly from different securities may span very different lengths of time. Since we intend to use fixed-duration sequences as our inputs, we must fill in the missing seconds with the last quoted price. After this transformation, the sample above looks like this (a quick forward-fill sketch follows the sample):
20140505, 09:30:01.000, IBM, 190.71
20140505, 09:30:02.000, IBM, 190.70
20140505, 09:30:03.000, IBM, 190.67
20140505, 09:30:04.000, IBM, 190.67
20140505, 09:30:05.000, IBM, 190.65
20140505, 09:30:06.000, IBM, 190.77
20140505, 09:30:07.000, IBM, 190.77
20140505, 09:30:08.000, IBM, 190.77
20140505, 09:30:09.000, IBM, 190.77
20140505, 09:30:10.000, IBM, 190.77
20140505, 09:30:11.000, IBM, 190.62
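As an aside, here is a minimal sketch of that forward-fill using pandas. This is not the exact preprocessing pipeline used in the experiment; the DataFrame construction and the "price" column name are assumptions made purely for illustration.

import pandas as pd

# Hypothetical tick data with gaps between seconds (prices taken from the sample above)
ticks = pd.DataFrame(
    {"price": [190.71, 190.70, 190.67, 190.65, 190.77, 190.62]},
    index=pd.to_datetime([
        "2014-05-05 09:30:01", "2014-05-05 09:30:02", "2014-05-05 09:30:03",
        "2014-05-05 09:30:05", "2014-05-05 09:30:06", "2014-05-05 09:30:11",
    ]),
)

# Resample onto a fixed 1 second grid and forward-fill each missing second
# with the last quoted price
filled = ticks.resample("1s").ffill()
print(filled)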
Imports
Let’s first start out by looking at imports and hyper-parameters. For a complete view of the code, see Full code section below.
import numpy as np
import time
import os
import json
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.models import Sequential
from keras.utils import np_utils
from keras import backend as K
from metrics import _precision_21, _accuracy_21, _return_11, _return_12, _return_21, _sum_11, _sum_12, _sum_21

EPOCHS = 1
SEQUENCE_LENGTH = 3600
PREDICTION_LENGTH = 3600
BATCH_SIZE = 3000
STEPS_PER_EPOCH = 1000
VALIDATION_STEPS = 100
BUCKETS = np.array([-float("inf"), -0.50, -0.45, -0.40, -0.35, -0.30, -0.25, -0.20,
                    -0.15, -0.10, -0.05, 0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30,
                    0.35, 0.40, 0.45, 0.50, float("inf")])
NB_CLASSES = 22
INTERESTING_CLASS_ID_11 = 11
INTERESTING_CLASS_ID_12 = 12
INTERESTING_CLASS_ID_21 = 21
- numpy, time, os, and json are self-explanatory
- Every package prefixed with keras provides the building blocks/abstractions used to build our neural network
- All functions imported from metrics are custom metric functions that will be used to measure success while training
- All capitalized variables are our model hyper-parameters
Generating the data
Data will be fed into our model using a generator. This technique will significantly decrease memory utilization while training.
def generate_data(data, BATCH_SIZE, SEQUENCE_LENGTH, PREDICTION_LENGTH):
    """
    Incrementally generate batches for training
    @param data: a numpy array
    @param BATCH_SIZE: int
    @param SEQUENCE_LENGTH: int
    @param PREDICTION_LENGTH: int
    @return: a tuple (inputs, targets)
    """
    x_sequence_batch = np.zeros((BATCH_SIZE, SEQUENCE_LENGTH))
    y_sequence_batch = np.zeros(BATCH_SIZE)
    while 1:
        for i in range(BATCH_SIZE):
            # The random start index must leave room for a full input sequence
            # plus the prediction horizon before the last timestamp
            idx = np.random.randint(data.size - (SEQUENCE_LENGTH + PREDICTION_LENGTH))
            x_sequence = data[idx:idx + SEQUENCE_LENGTH]
            x_sequence = normalize(x_sequence)
            # The slice end is exclusive, so the last input value sits at idx + SEQUENCE_LENGTH - 1
            last_sequence_value = data[idx + SEQUENCE_LENGTH - 1]
            # Target: the price PREDICTION_LENGTH seconds after the last input value
            y_sequence = data[idx + SEQUENCE_LENGTH + PREDICTION_LENGTH - 1]
            y_sequence = get_percentage_change(last_sequence_value, y_sequence)
            y_sequence = encode(y_sequence)
            x_sequence_batch[i] = x_sequence
            y_sequence_batch[i] = y_sequence
        x_sequence_batch_final = np.reshape(x_sequence_batch, (x_sequence_batch.shape[0], x_sequence_batch.shape[1], 1))
        y_sequence_batch_final = np_utils.to_categorical(y_sequence_batch, NB_CLASSES)
        yield x_sequence_batch_final, y_sequence_batch_final
The generator will yield both the inputs and the targets of the neural network. The input in this example (controlled by SEQUENCE_LENGTH) is a 3600 step sequence representing 1 hour (3600 seconds) of normalized prices. The prices are normalized by removing the mean and scaling to unit variance.
def normalize(prices):
    """
    Normalize a numpy array by removing the mean and scaling to unit variance
    @param prices: a numpy array
    @return: a normalized numpy array
    """
    mean = np.mean(prices)
    std = np.std(prices)
    if std == 0:
        # A forward-filled sequence can be completely flat; avoid dividing by zero
        return prices - mean
    normalized_prices = (prices - mean) / std
    return normalized_prices
The target in this example is a discrete value (22 classes/buckets) corresponding to the percentage change in price of the security 3600 seconds beyond the last time step in our 3600 step input sequence.
def get_percentage_change(previous, current):
    """
    Calculate the percentage increase or decrease of the equity's price
    @param previous: a float, the last observation/price in the sequence
    @param current: a float, the price of the equity at some designated time in the future
    @return: a float
    """
    if current == previous:
        return 0
    try:
        return ((current - previous) / previous) * 100.0
    except ZeroDivisionError:
        return 0
The above code obtains the percentage change of the security some time in the future (controlled by PREDICTION_LENGTH), and the following snippet transforms these continuous values into discrete values based on a pre-defined set of ranges broken down into 22 buckets (see the BUCKETS hyper-parameter for the bucket boundaries).
def encode(percent_change):
    """
    Encode continuous value to discrete value
    @param percent_change: a float
    @return: an int, that corresponds to a percent change range
    """
    encoded_percentage = np.digitize(percent_change, BUCKETS) - 1  # Subtract 1 to retain zero based indexing
    return encoded_percentage
Let’s walk through a concrete example. Given a starting index of 9:30:01 and a SEQUENCE_LENGTH = 60, our input would be a sequence of 60 time steps, ranging from 09:30:01 to 9:31:00. Given a PREDICTION_LENGTH = 60, our target would be a discrete value derived from the percentage change between the price at 9:31:00 and the price at 9:32:00. If the price of the security at 9:31:00 was $10.00 and the price at 9:32:00 was $10.08, we observed a 0.8% increase. We then transform this percentage change into a discrete value by looking at our bucket ranges to see where a 0.8% increase belongs.
BUCKETS = np.array([-float("inf"), -0.50, -0.45, -0.40, -0.35, -0.30, -0.25, -0.20, -.15, -.10, -.05, 0.0, .05, .10, .15, .20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, float("inf")])
Since 0.8% falls between 0.50 and float("inf"), our target will be the 21st bucket (assuming zero-based indexing). Going forward, I will frequently refer to this 21st bucket, because these are the price movements that we actually want to predict.
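As a quick sanity check, here is the worked example pushed through the helper functions defined above (this assumes BUCKETS, get_percentage_change, and encode are in scope):

change = get_percentage_change(10.00, 10.08)
print(change)          # ~0.8, i.e. a 0.8% increase
print(encode(change))  # 21 -> the [0.50, inf) bucket
print(encode(0.0))     # 11 -> the [0.0, 0.05) bucket
print(encode(-0.07))   # 9  -> the [-0.10, -0.05) bucket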
You may be wondering why I am even bothering with this bucketing system: why not just keep the target variable as a continuous percentage change in price? Categorizing the continuous price changes into classes allows us to use a categorical loss function such as categorical_crossentropy instead of a continuous-value loss function such as mean_squared_error. Using a categorical loss function will be very helpful when implementing the precision-based custom Keras metrics that will help us decipher the behavior of the RNN.
Building the model
Below is the model. I will let the model speak for itself. However, I would like to draw your attention to the custom metrics listed in the metrics array. The function definitions can be found in the Full Code section below. We will discuss these shortly.
def build_model():
    model = Sequential()
    layers = [1, 100, 200, 22]

    model.add(LSTM(
        input_dim=layers[0],
        output_dim=layers[1],
        return_sequences=True))
    model.add(Dropout(0.2))

    model.add(LSTM(
        layers[2],
        return_sequences=False))
    model.add(Dropout(0.2))

    model.add(Dense(
        output_dim=layers[3]))
    model.add(Activation("softmax"))

    start = time.time()

    # Compile the model
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy',
                           _precision_21(INTERESTING_CLASS_ID_21),
                           _accuracy_21(INTERESTING_CLASS_ID_21),
                           _return_11(INTERESTING_CLASS_ID_11),
                           _return_12(INTERESTING_CLASS_ID_12),
                           _return_21(INTERESTING_CLASS_ID_21),
                           _sum_11(INTERESTING_CLASS_ID_11),
                           _sum_12(INTERESTING_CLASS_ID_12),
                           _sum_21(INTERESTING_CLASS_ID_21)])
    print("Compilation Time : ", time.time() - start)
    return model
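Before kicking off training, it is worth sanity-checking that the generator output lines up with what the first LSTM layer expects. The snippet below is a rough sketch rather than part of the original pipeline; the synthetic random-walk series stands in for a real 1 second bar file.

# Sketch: verify generator/model shape compatibility on synthetic prices
fake_prices = 100 + np.cumsum(np.random.randn(SEQUENCE_LENGTH * 10) * 0.01)

x_batch, y_batch = next(generate_data(fake_prices, BATCH_SIZE, SEQUENCE_LENGTH, PREDICTION_LENGTH))
print(x_batch.shape)  # (BATCH_SIZE, SEQUENCE_LENGTH, 1) -- one feature per time step
print(y_batch.shape)  # (BATCH_SIZE, NB_CLASSES)         -- one-hot bucket targets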
Training the network
We will now begin training the model. For each security, we load 10 years' worth of second-by-second data into a 1D numpy array, which is approximately 58 million data points. Once loaded, we call fit_generator.
def run_network(model=None, data=None):
    global_start_time = time.time()
    print("Proceeding to train network at time (s) ", global_start_time)

    data_directory = '/data/complete_stock_data/'
    data_files = os.listdir(data_directory)

    if model is None:
        model = build_model()

    # 1. Iterate through all the equity files
    # 2. Run the fit generator on each equity
    try:
        for data_file in data_files:
            print(data_file)
            data_file_path = data_directory + data_file
            data = load_data(data_file_path)
            model.fit_generator(generate_data(data, BATCH_SIZE, SEQUENCE_LENGTH, PREDICTION_LENGTH),
                                steps_per_epoch=STEPS_PER_EPOCH,
                                nb_epoch=EPOCHS,
                                validation_data=generate_data(data, BATCH_SIZE, SEQUENCE_LENGTH, PREDICTION_LENGTH),
                                validation_steps=VALIDATION_STEPS)
    except KeyboardInterrupt:
        print('Training duration (s) : ', time.time() - global_start_time)
        return model

    print('Training duration (s) : ', time.time() - global_start_time)
    return model
Analyze results
Now for the fun part. We need to be able to answer the following question: Will this make money? The snippet below is the training output of the 999th batch.
999/1000 [============================>.] - ETA: 0s - loss: 1.7876 - acc: 0.6161 - prec_21: 0.6372 - return_21: 10.6725 - sum_21: 299.7007
We will pay attention to the following four metrics: acc, prec_21, return_21, and sum_21. Let’s go through each metric and understand what they signify.
acc: 61.61%
Accuracy measures the model's likelihood of predicting the correct label. I was very excited the first time I saw this metric, almost as excited as Brian Fantana introducing Ron Burgundy to Sex Panther.
Really? 60% of the time, it works every time?… Nope. This metric actually tells us very little. Considering the ultimate objective of this experiment is to identify and act on patterns that indicate imminent price increases, knowing that the model is accurate 61% of the time tells me nothing about its accuracy when predicting imminent price increases. In our example, a price spike of 0.5% or greater corresponds to the 21st bucket. For all we know, it may only be our 10th, 11th, or 12th buckets that are yielding correct predictions, while all of our 21st bucket predictions are incorrect.
prec_21: 0.6372
Knowing the precision of the 21st bucket would be much more valuable to us. Precision can be calculated with the following formula:
Precision = TP / (TP + FP)

where TP is the number of true positives (21st-bucket predictions that really were 21st-bucket outcomes) and FP is the number of false positives (21st-bucket predictions that were not).
On the 999th batch, we see a precision very close to our accuracy metric. Again, this looks like a very optimistic indicator: approximately 63% of our predictions for the 21st bucket were indeed correct. However, this still does not tell us whether we will be profitable! For all we know, the 37% of predictions that were false positives could actually be huge price drops!
return_21: 10.6725
What we really need to know is what our total return would be if we actually bought and sold the securities based on our predictions. This accounts for the magnitude of error in our false positive predictions. return_21 is the average true bucket index across all of the samples that the model predicted as the 21st bucket, which lets us estimate our annualized rate of return if we executed on only 21st bucket predictions. Lying roughly a third of the way between 10.5 and 11, the logged value of 10.6725 corresponds to roughly a 0.017% price increase per trade. We can now use the future value formula to estimate our annualized rate of return.
FV = PV (1 + r)^n
To keep things simple, let's set PV equal to 1000. The only thing missing is the value of n, the number of periods. Our investment horizon in this model is only 1 hour (PREDICTION_LENGTH), so theoretically we could make up to 6 trades in a single day. With 252 trading days a year, this translates to as many as 1512 trades! Plugging all these values into our formula, we are looking at an annualized rate of return of roughly 30%. HOWEVER, this is all too good to be true. Let's look at one final metric that will undermine everything.
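For reference, here is that back-of-the-envelope calculation in code. The per-trade return is the rough interpolation from return_21 above, so treat the output as illustrative rather than a measured result.

# Rough annualized-return estimate from the interpolated per-trade return
pv = 1000.0          # starting capital
r = 0.01725 / 100    # ~0.017% per trade, interpolated from return_21 of ~10.67
n = 6 * 252          # up to 6 one-hour trades a day over 252 trading days

fv = pv * (1 + r) ** n
print(n)             # 1512 trades per year
print(round(fv, 2))  # ~1298, i.e. roughly a 30% annualized return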
sum_21: 299.7007
sum_21 is a count of how many predictions (out of the total batch) were made for the 21st bucket. The batch size for this run was 300, so a sum_21 of approximately 300 means almost every prediction was for the 21st bucket. What does this mean? It means that the RNN was unable to find any usable patterns. With no signal to exploit, the easiest way for the model to reduce its loss was to simply always choose the bucket that was most likely to occur. In other words, if we ranked our true labels by count (bucket_21: 200, bucket_0: 75, bucket_20: 10, …), our model will always end up picking the top ranked bucket. To prove this, I adjusted the SEQUENCE_LENGTH and PREDICTION_LENGTH variables and took note of which buckets had the most true labels. After long enough training periods, the bucket with the most true labels always ended up with a sum equal to the batch size. If it is not already obvious, this would translate to us buying and selling a security every hour, on the hour. I think it goes without saying that this is a very bad strategy.
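To make that concrete, a classifier that always predicts the most common bucket already matches the kind of accuracy we celebrated earlier. The label distribution below is made up purely for illustration; it is not taken from the real data.

# A degenerate "always predict the most frequent bucket" baseline
rng = np.random.default_rng(0)
true_labels = rng.choice([21, 11, 10], size=300, p=[0.62, 0.20, 0.18])  # hypothetical imbalanced batch
constant_predictions = np.full(300, 21)                                 # always predict bucket 21

accuracy = (constant_predictions == true_labels).mean()
print(accuracy)  # roughly 0.62 -- in the same ballpark as the trained RNN's accuracy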
Conclusion
That's right, if you can dodge a wrench, you can dodge a ball. A quick hedge: I can't say I successfully dodged the wrench, so to speak. BUT, if dodging the wrench meant understanding an RNN's behavior in the face of a quintessential random-walk data set, I think we can go ahead and say that we can now dodge some balls.
Next steps?
- Use a custom loss function that optimizes on rate of return. I actually attempted to use return_21 as the loss function; however, there seem to be outstanding issues with implementing custom loss functions in Keras. Keras was great for prototyping, but moving forward I will be using Tensorflow. There are just too many abstractions that conceal core issues.
- Use a CNN on sector-aggregated end of day equity data. I still believe there is alpha to be found, just not on this short a time scale. The problem with end of day equity data is that there is far too little of it from just one stock or index: 10 years of end of day data on a single stock is only ~2500 data points, which is not enough to train a neural network. However, assuming stocks in the same sector trade similarly, aggregating by sector would provide us with sufficient data.
- Use a multi-featured RNN on second-by-second equity data and various leading indicators. Equity prices are obviously driven by many other factors/indicators. A multi-featured RNN could potentially find some causality between sequences of indicators. However, acquiring the data for this strategy would be expensive AF.
- Experiment with Quantopian and Numer.ai. These hedge funds provide training data and offer profit sharing incentives for successful models.
Full code
import numpy as np
import time
import os
import json
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.models import Sequential
from keras.utils import np_utils
from keras import backend as K
from metrics import _precision_21, _accuracy_21, _return_11, _return_12, _return_21, _sum_11, _sum_12, _sum_21

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

EPOCHS = 1
SEQUENCE_LENGTH = 360
PREDICTION_LENGTH = 60
BATCH_SIZE = 2000
STEPS_PER_EPOCH = 1000
VALIDATION_STEPS = 100
BUCKETS = np.array([-float("inf"), -0.50, -0.45, -0.40, -0.35, -0.30, -0.25, -0.20,
                    -0.15, -0.10, -0.05, 0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30,
                    0.35, 0.40, 0.45, 0.50, float("inf")])
NB_CLASSES = 22
INTERESTING_CLASS_ID_11 = 11
INTERESTING_CLASS_ID_12 = 12
INTERESTING_CLASS_ID_21 = 21


def generate_data(data, BATCH_SIZE, SEQUENCE_LENGTH, PREDICTION_LENGTH):
    """
    Incrementally generate batches for training
    @param data: a numpy array
    @param BATCH_SIZE: int
    @param SEQUENCE_LENGTH: int
    @param PREDICTION_LENGTH: int
    @return: a tuple (inputs, targets)
    """
    x_sequence_batch = np.zeros((BATCH_SIZE, SEQUENCE_LENGTH))
    y_sequence_batch = np.zeros(BATCH_SIZE)
    while 1:
        for i in range(BATCH_SIZE):
            # The random start index must leave room for a full input sequence
            # plus the prediction horizon before the last timestamp
            idx = np.random.randint(data.size - (SEQUENCE_LENGTH + PREDICTION_LENGTH))
            x_sequence = data[idx:idx + SEQUENCE_LENGTH]
            x_sequence = normalize(x_sequence)
            # The slice end is exclusive, so the last input value sits at idx + SEQUENCE_LENGTH - 1
            last_sequence_value = data[idx + SEQUENCE_LENGTH - 1]
            # Target: the price PREDICTION_LENGTH seconds after the last input value
            y_sequence = data[idx + SEQUENCE_LENGTH + PREDICTION_LENGTH - 1]
            y_sequence = get_percentage_change(last_sequence_value, y_sequence)
            y_sequence = encode(y_sequence)
            x_sequence_batch[i] = x_sequence
            y_sequence_batch[i] = y_sequence
        x_sequence_batch_final = np.reshape(x_sequence_batch, (x_sequence_batch.shape[0], x_sequence_batch.shape[1], 1))
        y_sequence_batch_final = np_utils.to_categorical(y_sequence_batch, NB_CLASSES)
        yield x_sequence_batch_final, y_sequence_batch_final


def build_model():
    model = Sequential()
    layers = [1, 100, 200, 22]

    model.add(LSTM(
        input_dim=layers[0],
        output_dim=layers[1],
        return_sequences=True))
    model.add(Dropout(0.2))

    model.add(LSTM(
        layers[2],
        return_sequences=False))
    model.add(Dropout(0.2))

    model.add(Dense(
        output_dim=layers[3]))
    model.add(Activation("softmax"))

    start = time.time()

    # Compile the model
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy',
                           _precision_21(INTERESTING_CLASS_ID_21),
                           _accuracy_21(INTERESTING_CLASS_ID_21),
                           _return_11(INTERESTING_CLASS_ID_11),
                           _return_12(INTERESTING_CLASS_ID_12),
                           _return_21(INTERESTING_CLASS_ID_21),
                           _sum_11(INTERESTING_CLASS_ID_11),
                           _sum_12(INTERESTING_CLASS_ID_12),
                           _sum_21(INTERESTING_CLASS_ID_21)])
    print("Compilation Time : ", time.time() - start)
    return model


def encode(percent_change):
    """
    Encode continuous value to discrete value
    @param percent_change: a float
    @return: an int, that corresponds to a percent change range
    """
    encoded_percentage = np.digitize(percent_change, BUCKETS) - 1  # Subtract 1 to retain zero based indexing
    return encoded_percentage


def get_percentage_change(previous, current):
    """
    Calculate the percentage increase or decrease of the equity's price
    @param previous: a float, the last observation/price in the sequence
    @param current: a float, the price of the equity at some designated time in the future
    @return: a float
    """
    if current == previous:
        return 0
    try:
        return ((current - previous) / previous) * 100.0
    except ZeroDivisionError:
        return 0


def normalize(prices):
    """
    Normalize a numpy array by removing the mean and scaling to unit variance
    @param prices: a numpy array
    @return: a normalized numpy array
    """
    mean = np.mean(prices)
    std = np.std(prices)
    if std == 0:
        # A forward-filled sequence can be completely flat; avoid dividing by zero
        return prices - mean
    normalized_prices = (prices - mean) / std
    return normalized_prices


def load_data(data_file):
    """
    Load 10 years of second by second equity pricing for a single security into a numpy array.
    This will yield a 1D array of approximately 58 million data points.
    @param data_file: a file for 1 security
    @return: A 1D numpy array
    """
    prices = np.fromfile(data_file, dtype=float, sep="\n")
    return prices


def run_network(model=None, data=None):
    global_start_time = time.time()
    print("Proceeding to train network at time (s) ", global_start_time)

    data_directory = '/data/complete_stock_data/'
    data_files = os.listdir(data_directory)

    if model is None:
        model = build_model()

    # 1. Iterate through all the equity files
    # 2. Run the fit generator on each equity
    try:
        for data_file in data_files:
            print(data_file)
            data_file_path = data_directory + data_file
            data = load_data(data_file_path)
            model.fit_generator(generate_data(data, BATCH_SIZE, SEQUENCE_LENGTH, PREDICTION_LENGTH),
                                steps_per_epoch=STEPS_PER_EPOCH,
                                nb_epoch=EPOCHS,
                                validation_data=generate_data(data, BATCH_SIZE, SEQUENCE_LENGTH, PREDICTION_LENGTH),
                                validation_steps=VALIDATION_STEPS)
    except KeyboardInterrupt:
        print('Training duration (s) : ', time.time() - global_start_time)
        return model

    print('Training duration (s) : ', time.time() - global_start_time)
    return model


model = run_network()

# SAVE MODEL JSON AND WEIGHTS
if not os.path.exists("/data"):
    os.makedirs("/data")

model.save_weights("/data/model.h5", True)
with open('/data/model.json', 'w') as outfile:
    json.dump(model.to_json(), outfile)
from keras import backend as K


def _precision_21(interesting_class_id):
    def prec_21(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        # Replace class_id_true with class_id_preds for recall here
        accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int32')
        class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
        class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
        return class_acc
    return prec_21


def _accuracy_21(interesting_class_id):
    def accuracy_21(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        # Replace class_id_true with class_id_preds for recall here
        accuracy_mask = K.cast(K.equal(class_id_true, interesting_class_id), 'int32')
        class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
        class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
        return class_acc
    return accuracy_21


def _return_11(interesting_class_id):
    def return_11(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        predictions_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int64')
        # Average true bucket index over the samples predicted as this class
        weighted_bucket = K.sum(class_id_true * predictions_mask) / K.maximum(K.sum(predictions_mask), 1)
        return weighted_bucket
    return return_11


def _return_12(interesting_class_id):
    def return_12(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        predictions_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int64')
        weighted_bucket = K.sum(class_id_true * predictions_mask) / K.maximum(K.sum(predictions_mask), 1)
        return weighted_bucket
    return return_12


def _return_21(interesting_class_id):
    def return_21(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        predictions_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int64')
        weighted_bucket = K.sum(class_id_true * predictions_mask) / K.maximum(K.sum(predictions_mask), 1)
        return weighted_bucket
    return return_21


def _sum_11(interesting_class_id):
    def sum_11(y_true, y_pred):
        class_id_preds = K.argmax(y_pred, axis=-1)
        # Count how many predictions in the batch were for this class
        predictions_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int64')
        return K.sum(predictions_mask)
    return sum_11


def _sum_12(interesting_class_id):
    def sum_12(y_true, y_pred):
        class_id_preds = K.argmax(y_pred, axis=-1)
        predictions_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int64')
        return K.sum(predictions_mask)
    return sum_12


def _sum_21(interesting_class_id):
    def sum_21(y_true, y_pred):
        class_id_preds = K.argmax(y_pred, axis=-1)
        predictions_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int64')
        return K.sum(predictions_mask)
    return sum_21