Machine learning and security
Computer Cop
Although machine learning (ML) applications still demand a great deal of effort for preprocessing data, the algorithms can also detect structures automatically. Deep learning in particular has driven progress in the field of automatic feature extraction, which makes ML algorithms all the more interesting for cybersecurity tasks.
In IT security, data volumes are often huge, and interpreting them involves massive effort, because of either the sheer bulk or the complexity. Not surprisingly, then, vendors often build ML into their cybersecurity products: Splunk [1] offers a dedicated ML toolkit, and Darktrace [2] apparently relies almost entirely on machine learning.
Although machine learning has not suddenly turned the cybersecurity world completely on its head (even if some product vendors believe it has), you need to answer the following questions – if only to stay on top of the latest developments:
- Which machine learning principles apply to cybersecurity?
- What do typical scenarios for defense and attack look like?
- What trends can be expected in the area of combining machine learning and cybersecurity?
In this article, I try to answer these questions, without claiming to be exhaustive.
ML at a Glance
Every ML system (Figure 1) has an input (x) through which it (one hopes) receives relevant information and from which it typically derives a classification (y). In the field of cybersecurity, for example, this would be triggering an alarm or determining that everything is okay.
In machine learning, a basic distinction is made between unsupervised learning, supervised learning, and reinforcement learning. In supervised and reinforcement learning, a target vector (t) or a reward signal (r) is used for adaptation (learning) of the system M. This feedback does not exist in unsupervised learning, which instead relies on the statistical properties of the input signals, such as the accumulation of similar input patterns.
In unsupervised learning, the system independently identifies patterns in the input data, which allows it to detect unusual events, such as a user suddenly emailing a large volume of data to an address outside the organization. Server logs or a direct stream of network data, for example, serve as input data. In the case of anomalies, the system executes actions according to the specifications defined in playbooks (e.g., informing a cybersecurity team, which then checks to see whether a problem exists or whether an employee had to perform unusual actions for legitimate tasks).
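As a minimal sketch of this idea – my own illustration, not taken from any particular product – scikit-learn's IsolationForest can flag such outliers in per-user traffic statistics. The features and numbers below are invented:

import numpy as np
from sklearn.ensemble import IsolationForest

# Hypothetical features per user and day:
# [outbound mails, attachment volume in MB]
rng = np.random.default_rng(0)
normal = rng.normal(loc=[20, 5], scale=[5, 2], size=(200, 2))
suspicious = np.array([[25, 900]])  # one user mails 900MB externally

X = np.vstack([normal, suspicious])
model = IsolationForest(contamination=0.01, random_state=0).fit(X)

labels = model.predict(X)  # -1 marks anomalies
print("flagged rows:", np.where(labels == -1)[0])

A playbook action – informing the security team, say – would then be attached to the flagged rows.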
Supervised learning (Figure 2) requires an assignment of input to output: The system is presented with examples of log data that should not trigger an alarm and other data that should. For example, you can run some malware in a sandbox and track its actions (e.g., which registry entries it makes). The corresponding log data then provides examples for the Malware_Alert class, and normal log entries are assigned to the No_Alert class. Typically, however, the logs are not fed in directly as input; instead, the data is first cleaned up and features – such as the frequency of certain words – are extracted, which is the only way to achieve efficient classification.
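A compact sketch of this setup might look as follows; the log lines and labels are invented, and the choice of CountVectorizer (word frequencies) plus naive Bayes is just one plausible combination from scikit-learn:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Invented training examples: sandbox logs vs. normal operation
logs = [
    "registry set HKLM Software Run evil.exe",
    "registry set HKCU Software Run dropper.exe",
    "service spooler started",
    "user logon succeeded",
]
labels = ["Malware_Alert", "Malware_Alert", "No_Alert", "No_Alert"]

# Word frequencies as features, naive Bayes as classifier
clf = make_pipeline(CountVectorizer(), MultinomialNB())
clf.fit(logs, labels)
print(clf.predict(["registry set HKLM Software Run unknown.exe"]))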
The special feature of reinforcement learning is the feedback the system receives: a reward, which is often granted only after a number of actions (experts speak of delayed reinforcement). The idea is that the system interacts with its environment as an agent, much like a sentient being (Figure 3). The agent has to explore the environment and typically learns only after a certain number of actions whether it was successful and earns a reward.
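The classic textbook scheme for this is Q-learning. The following toy example – again my own illustration, not from a security product – lets an agent find its way down a five-state corridor in which only the final state pays a reward:

import numpy as np

n_states, actions = 5, [-1, 1]  # corridor; step left or right
Q = np.zeros((n_states, len(actions)))
alpha, gamma, epsilon = 0.5, 0.9, 0.2
rng = np.random.default_rng(1)

for episode in range(200):
    s = 0
    while s != n_states - 1:
        # epsilon-greedy: mostly exploit, sometimes explore
        a = rng.integers(2) if rng.random() < epsilon else int(Q[s].argmax())
        s_next = min(max(s + actions[a], 0), n_states - 1)
        r = 1.0 if s_next == n_states - 1 else 0.0  # delayed reward
        Q[s, a] += alpha * (r + gamma * Q[s_next].max() - Q[s, a])
        s = s_next

print(Q.round(2))  # values propagate back from the rewarding state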
Cybersecurity and ML
The use of machine learning in cybersecurity is based on transforming the security problem into an ML problem or, more generally, into a data science problem. For example, one approach could be to use data generated by malware and by a harmless program to distinguish between the two cases. To do this, you set up an identified malware program on a separate virtual machine and track which logfiles it generates. Similarly, you would collect the logfiles from harmless programs. In this way, you can teach the system to distinguish "good" from "bad" log data through supervised learning.
The principle of logfile classification sounds simple, but it requires intensive preprocessing. For classifiers like neural networks, only numerical values are suitable as input variables – and preferably values between 0 and 1. Therefore, the text information from the log first needs to be coded numerically.
In principle, this just means transforming the security problem (malware detection) into a text recognition problem. For logfile classification, you can then turn to proven algorithms from tried and trusted libraries; the Python sklearn library, for example, can transform the data from textual to numerical form. However, logfile classification is only one of many examples.
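For example, sklearn's TfidfVectorizer converts raw log lines into a numerical matrix, and because it normalizes each row, the values conveniently fall between 0 and 1, as neural networks prefer. The sample lines are made up:

from sklearn.feature_extraction.text import TfidfVectorizer

lines = [
    "connection refused from 10.0.0.5",
    "connection accepted from 10.0.0.7",
    "registry run key modified",
]
vec = TfidfVectorizer()
X = vec.fit_transform(lines)  # sparse matrix, one row per log line

print(vec.get_feature_names_out())
print(X.toarray().round(2))  # all values between 0 and 1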
Figure 4 roughly visualizes the approach of the machine learning part in cybersecurity. The breakdown is intended to help provide an overview, but it does not claim to be universally valid.
The starting point (1) is typically a manually identified threat, such as malware or phishing attacks. To detect such a threat, you must acquire relevant data (2) (e.g., from operating system or application logfiles). The acquired data is then prepared (3) (e.g., by data science algorithms) for the respective threat scenarios.
In the next step, the data is further processed by ML algorithms. For selected threat scenarios, one or more algorithms are chosen to prepare for a subsequent alert/no alert decision. Text analysis (4a) helps by combining logfile entries and creating clusters; this technique is used to tag the data as belonging to a pattern class. Feature analysis (4b) can detect further patterns in the input data, such as time dependencies.
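Such a clustering step could be sketched as follows – the log entries are invented, and fixing the number of clusters at two is an assumption for this toy data:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans

entries = [
    "logon failed for user alice",
    "logon failed for user bob",
    "disk quota exceeded on /var",
    "disk quota exceeded on /home",
]
X = TfidfVectorizer().fit_transform(entries)

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
print(km.labels_)  # cluster index = pattern class per entry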
Deep learning (4c) is another useful variant. Neural networks whose inner structure contains more than one hidden layer (Figure 5) are referred to as "deep." The hidden layers consist of neurons (shown as circles) that are connected neither directly to the input data nor to the output (alert/no alert). Such networks are particularly well suited for feature extraction but are more complex to train and interpret.
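To make "more than one hidden layer" concrete, here is a minimal sketch with sklearn's MLPClassifier; the two hidden layers of 16 neurons each, and the synthetic stand-in data, are arbitrary choices for illustration:

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification

# Synthetic stand-in for preprocessed alert/no-alert features
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

# Two hidden layers make the network "deep" in the above sense
net = MLPClassifier(hidden_layer_sizes=(16, 16), max_iter=1000,
                    random_state=0).fit(X, y)
print(net.score(X, y))  # training accuracy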
Additionally, numerous other procedures can be used (4d in Figure 4), such as decision trees. The process ends with a classification (5): Should an alert be triggered?
Experiments
As an alternative to programming your own neural network, you can install trial versions of well-known tools and explore the possibilities of machine learning with the help of their toolkits.
For example, you can install and use Splunk Enterprise and its Machine Learning Toolkit (Figure 6) with a trial license for up to 60 days and index up to 500MB of data per day. The software provides a number of examples for different deployment scenarios, including IT security. The installation includes sample data for testing various scenarios.
Appropriate data is not always at hand for testing algorithms. Creating the data yourself gives you more control over the degree of randomness you want. For example, the time series in Figure 7 represents logins.
Typically you will see far more logins on workdays and significantly fewer on weekends, with more or less pronounced fluctuations. Listing 1 shows how such a time series can be generated. The advantage of simulated data is that you can influence individual parameters in a targeted manner, as shown here with the random_gain parameter (lines 25 and 28) in the noise section.
Listing 1
Creating a Time Series
01 import numpy as np
02 import plotly.graph_objects as go
03
04 step = 1 / 1000; t = np.arange(0, 1, step)  # time vector
05 periods = 30  # number of 'days'
06
07 # function to produce base sine data
08 # with a 7th of the base frequency overlap
09 def data_w_weekend(t):
10     if np.sin(periods / 7 * 2 * np.pi * t) > 0.5:
11         value = 0.001 * np.sin(periods * 2 * np.pi * t)
12         return max(value, 0.0)
13     else:
14         value = np.sin(periods * 2 * np.pi * t)
15         return max(value, 0.0)
16
17 # building the data vector
18 my_data = []
19 i = 0
20 while i < 1000:
21     my_data.append(data_w_weekend(i / 1000))
22     i += 1
23
24 # add some noise
25 random_gain = 0.1  # factor for the noise
26 i = 0
27 while i < 1000:
28     my_data[i] += np.random.rand() * random_gain
29     i += 1
30
31 my_data_max = np.amax(my_data)
32 print('max value is: ' + str(my_data_max))
33 # normalize the data to a range up to 1.0
34 my_norm_data = []
35 i = 0
36 while i < 1000:
37     my_norm_data.append(my_data[i] / my_data_max)
38     i += 1
39
40 # plot the data
41 trace0 = go.Scatter(
42     x = t,
43     y = my_norm_data,
44     name = 'Logons'
45 )
46 fig = go.Figure()
47
48 layout = go.Layout(title="Logins over time", xaxis={'title': 'time'}, yaxis={'title': 'occurrences'})
49 fig = go.Figure(data=trace0, layout=layout)
50 fig.show()
For practical use, you still have to perform tests with real data and handle special cases, such as holidays, separately. If you want to get more involved in time series prediction, a paper from TensorFlow [3] is recommended reading. The idea here is that if the system encounters data it is unable to predict, you should suspect that a security problem exists.
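That idea can be sketched in a few lines without any ML library: a naive predictor expects each day's login count to match the same weekday one week earlier, and days with a large residual are flagged. The one-week lag, the injected spike, and the three-sigma threshold are all assumptions for illustration:

import numpy as np

rng = np.random.default_rng(2)
# Four weeks of daily login counts: busy weekdays, quiet weekends
week = np.array([100, 105, 98, 102, 99, 15, 12])
logins = np.tile(week, 4) + rng.normal(0, 5, 28)
logins[17] += 80  # injected incident: unexpected login spike

predicted = logins[:-7]  # naive forecast: same day last week
residual = np.abs(logins[7:] - predicted)
threshold = 3 * residual.std()

print("suspicious days:", np.where(residual > threshold)[0] + 7)

Note that the spike on day 17 also contaminates the forecast for the same weekday one week later – a known weakness of such naive lag predictors that a real deployment would have to address.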