整体思路:

  1. 基于Modsecurity设计智能WAF
  2. Nginx启动阶段将AI-Security-WAF引擎启动
  3. 将连接器传递过来的所有报文栏位传入引擎
  4. 引擎分析报文状态
  5. 放行正常报文,阻断攻击报文

性能数据:

写了一个简单的demo。感兴趣的话可以测试一下不含攻击栏位的报文,毕竟互联网中正常报文远远多于攻击报文。我这里循环跑1000次,这样得到的总耗时(秒)在数值上就等于跑1次的耗时(毫秒):大约20个样本花费80ms。

time:78.788424 ms

isatk:1 样本中有攻击栏位 哈哈~

顺便说一下,这个模型的准确率也不够高,因为网上的样本太少了。我又顺手写了一些正常报文栏位来测试,20个样本同样花费约80ms,可见AI-Security-WAF引擎的耗时与报文是正常还是异常无关。这个速度还是非常慢的,工程上要求比这快100倍才算达标。实际上影响性能的主要原因是分发器性能不行,这个后期要考虑替代方案。

完整代码:

// waf.cpp
#include "Python.h"
#include <iostream>
#include <malloc.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
using namespace std;

// START / END bracket a region of code and print how long it took.
//
// START declares `time1` and `time2` in the enclosing scope and records the
// start time. END records the end time, declares `delta` (elapsed wall-clock
// time in SECONDS) and prints it as "time:<delta>". Because both macros
// declare locals, each may be used at most once per scope, START first.
//
// The line continuations and the "\n" escape had been lost, which left these
// statements at file scope and split the printf literal across two lines.
#define START \
    struct timeval time1, time2; \
    gettimeofday(&time1, NULL);

#define END \
    gettimeofday(&time2, NULL); \
    double delta = (time2.tv_sec - time1.tv_sec) * 1. + \
                   (time2.tv_usec - time1.tv_usec) / 1000000.; \
    printf("time:%f\n", delta);

// Build the demo request: 20 payload fields made of 10 distinct strings
// listed twice.
//
// Returns a heap-allocated array of 20 C-string pointers. Entries 0..9 are
// individually heap-allocated copies; entries 10..19 alias entries 0..9 (same
// pointers, as before), so anyone releasing the data must free only the first
// ten strings and then the array itself.
//
// The previous version copied a fixed 100 bytes out of each string literal,
// reading past the end of every literal shorter than 100 bytes; each copy is
// now sized from the literal itself. Allocation failure aborts the process.
const char** test_data()
{
    static const char* const fields[] = {
        "1234567890",
        "hello/world/waf.php",
        "<script>alert(1)</script>",
        "-2 union select group_concat(Username),2,3 from Person#",
        "hi welecome to learn web security",
        "ex1.php?dir=| cat /etc/passwd",
        "<script>alert( and 2=2#)</script>",
        "<script>alert(2)</script>",
        "<script>alert(3)</script>",
        " and 1=1#",
    };
    const size_t unique_num = sizeof(fields) / sizeof(fields[0]);

    char** pkt_addr = (char**)malloc(2 * unique_num * sizeof(char*));
    if (pkt_addr == NULL)
    {
        fprintf(stderr, "[waf] out of memory while building test data\n");
        exit(EXIT_FAILURE);
    }

    for (size_t i = 0; i < unique_num; i++)
    {
        const size_t bytes = strlen(fields[i]) + 1;  // include the terminator
        char* copy = (char*)malloc(bytes);
        if (copy == NULL)
        {
            fprintf(stderr, "[waf] out of memory while building test data\n");
            exit(EXIT_FAILURE);
        }
        memcpy(copy, fields[i], bytes);
        pkt_addr[i] = copy;
        pkt_addr[i + unique_num] = copy;  // second half repeats the first
    }
    return (const char**)pkt_addr;
}

// Start the embedded Python interpreter and import the `waf` module from the
// current working directory.
//
// Returns a new reference to the imported module; the caller owns it.
// On import failure the Python traceback (if any) is printed and the process
// exits with a failure status (it previously exited with status 0).
PyObject* aisec_waf_init()
{
    Py_Initialize();
    PyRun_SimpleString("import sys");
    // The path must be a quoted Python string literal; the quotes had been lost.
    PyRun_SimpleString("sys.path.append('./')");

    // PyImport_ImportModule takes a C string directly, so no temporary
    // module-name object has to be created (and released).
    PyObject* pModule = PyImport_ImportModule("waf");
    if (!pModule)
    {
        if (PyErr_Occurred())
        {
            PyErr_Print();
        }
        cout << "[waf] Python get module failed." << endl;
        exit(EXIT_FAILURE);
    }
    return pModule;
}

// Call the Python function `waf_load()` from `pModule` and return its result.
//
// pModule: the imported `waf` module (borrowed).
// Returns a new reference to whatever `waf_load()` returned; the caller owns
// it. If the attribute is missing, is not callable, or the call raises, the
// Python traceback (if any) is printed and the process exits with a failure
// status. Previously a raised exception produced a NULL return that callers
// used unchecked.
PyObject* aisec_waf_load(PyObject* pModule)
{
    PyObject* waf_load = PyObject_GetAttrString(pModule, "waf_load");
    if (!waf_load || !PyCallable_Check(waf_load))
    {
        if (PyErr_Occurred())
        {
            PyErr_Print();
        }
        cout << "[waf] Cant find funftion (waf_load)" << endl;
        Py_XDECREF(waf_load);
        exit(EXIT_FAILURE);
    }

    PyObject* waf_engine = PyObject_CallObject(waf_load, NULL);
    Py_DECREF(waf_load);  // the callable is no longer needed
    if (!waf_engine)
    {
        if (PyErr_Occurred())
        {
            PyErr_Print();
        }
        cout << "[waf] waf_load() failed." << endl;
        exit(EXIT_FAILURE);
    }
    return waf_engine;
}

// Run the Python function `waf_predict(waf_engine, fields)` on one request.
//
// pModule:    the imported `waf` module (borrowed).
// waf_engine: the object returned by aisec_waf_load (borrowed; this function
//             no longer hands the caller's reference to a tuple).
// pkt_loads:  array of at least 20 NUL-terminated payload strings.
//
// Returns the integer unpacked from the 1-tuple that `waf_predict` returns,
// or -1 if the call or the unpacking fails (the traceback is printed).
// Exits the process if `waf_predict` cannot be found.
//
// Every temporary Python object is released before returning; the previous
// version leaked the callable, the argument tuple, the list and the result on
// each call.
int aisec_waf_predict(PyObject* pModule, PyObject* waf_engine, const char** pkt_loads)
{
    int isAtk = -1;
    const int fields_num = 20;  // field count of the demo request (see test_data)

    PyObject* waf_predict = PyObject_GetAttrString(pModule, "waf_predict");
    if (!waf_predict || !PyCallable_Check(waf_predict))
    {
        if (PyErr_Occurred())
        {
            PyErr_Print();
        }
        cout << "[waf] Cant find funftion (waf_predict)" << endl;
        Py_XDECREF(waf_predict);
        exit(EXIT_FAILURE);
    }

    PyObject* fields = PyList_New(fields_num);
    if (!fields)
    {
        PyErr_Print();
        Py_DECREF(waf_predict);
        return isAtk;
    }
    for (int i = 0; i < fields_num; i++)
    {
        PyObject* field = Py_BuildValue("s", pkt_loads[i]);
        if (!field)
        {
            PyErr_Print();
            Py_DECREF(fields);
            Py_DECREF(waf_predict);
            return isAtk;
        }
        PyList_SetItem(fields, i, field);  // steals the reference to `field`
    }

    // Arguments are borrowed by this call, so no argument tuple is needed.
    PyObject* pRet = PyObject_CallFunctionObjArgs(waf_predict, waf_engine, fields, NULL);
    Py_DECREF(fields);
    Py_DECREF(waf_predict);

    if (pRet == NULL)
    {
        PyErr_Print();
        return isAtk;
    }
    if (!PyArg_ParseTuple(pRet, "i", &isAtk))
    {
        PyErr_Print();
        isAtk = -1;
    }
    Py_DECREF(pRet);
    return isAtk;
}

int main()
{
const char** pkt_loads = test_data();
PyObject* pModule = aisec_waf_init();
PyObject* waf_engine = aisec_waf_load(pModule);
int num = 1000;
int isAtk = -1;
START
while(num--){
isAtk = aisec_waf_predict(pModule,waf_engine,pkt_loads);
}
END
cout << "isatk:" << isAtk << endl;
return 0;
}


## waf.py
import sys
import urllib
import numpy as np
import tensorflow as tf
import tflearn
from tflearn.data_utils import to_categorical, pad_sequences
from sklearn.model_selection import train_test_split
import time
import re
from sklearn.externals import joblib

def get_feature(line):
    """Encode a string as the code points of its lower-cased characters.

    Args:
        line: text to encode.

    Returns:
        A list with one integer per character of ``line``: ``ord`` of that
        character after lower-casing it.
    """
    return [ord(ch.lower()) for ch in line]

def load_file(filename, label, ms=None, ns=None):
    """Append the encoded lines of a sample file, and their label, to two lists.

    Each line has its trailing newline removed and is URL-decoded. Lines of
    at most 100 characters after decoding are encoded with ``get_feature``
    and appended to ``ms``, with ``label`` appended to ``ns``; longer lines
    are skipped. The running size of ``ms`` is printed once the file is read.

    Args:
        filename: path of a text file with one sample per line.
        label: class label recorded for every accepted line.
        ms: list receiving the encoded samples; a new list when omitted.
        ns: list receiving the labels; a new list when omitted.

    Returns:
        The tuple ``(ms, ns)``.
    """
    # Function-scope import: ``urllib.unquote`` only exists on Python 2; on
    # Python 3 the same function lives in ``urllib.parse``.
    try:
        from urllib.parse import unquote
    except ImportError:
        from urllib import unquote

    # Fresh lists per call; ``[]`` defaults would be shared across calls.
    ms = [] if ms is None else ms
    ns = [] if ns is None else ns

    with open(filename) as f:
        for line in f:
            line = unquote(line.strip("\n"))
            if len(line) <= 100:
                ms.append(get_feature(line))
                ns.append(label)
    print(len(ms))
    return ms, ns

def load_files(files):
    """Load several sample files, labelling each sample with its file's position.

    Args:
        files: sequence of file paths; the file at position ``i`` provides
            the samples of class ``i``.

    Returns:
        ``(xs, ys)``: the encoded samples of all files and their labels,
        as collected by ``load_file``.
    """
    xs, ys = [], []
    for label, path in enumerate(files):
        load_file(path, label, xs, ys)
    return xs, ys

class XSS_FEATURE(object):
    """Hand-crafted numeric features of one payload string for XSS detection."""

    # Patterns are compiled once; all are case-insensitive.
    _URL_RE = re.compile(r"(http://)|(https://)", re.IGNORECASE)
    # NOTE(review): the published listing lost its quote characters; a single
    # quote may also have belonged to this set. Confirm against the feature
    # extractor used to train the persisted model.
    _EVIL_CHAR_RE = re.compile(r'[<>,"/]', re.IGNORECASE)
    # NOTE(review): "(script=)(%3c)" has no "|" between the two groups, so
    # neither matches on its own. Kept as published because changing it would
    # shift the feature away from what the persisted model may have been
    # trained on - confirm before fixing.
    _EVIL_WORD_RE = re.compile(
        r"(alert)|(script=)(%3c)|(%3e)|(%20)|(onerror)|(onload)|(eval)|(src=)|(prompt)",
        re.IGNORECASE,
    )

    def get_len(self, url):
        """Return the length of ``url``."""
        return len(url)

    def get_url_count(self, url):
        """Return 1 if ``url`` contains ``http://`` or ``https://``, else 0."""
        return 1 if self._URL_RE.search(url) else 0

    def get_evil_char(self, url):
        """Return how many characters of ``url`` are one of ``< > , " /``."""
        return len(self._EVIL_CHAR_RE.findall(url))

    def get_evil_word(self, url):
        """Return the number of suspicious-keyword matches in ``url``."""
        return len(self._EVIL_WORD_RE.findall(url))

class SQL_FEATURE(object):
    """Hand-crafted numeric features of one payload string for SQL-injection detection."""

    # Patterns are compiled once; all are case-insensitive.
    _URL_RE = re.compile(r"(http://)|(https://)", re.IGNORECASE)
    # NOTE(review): the published listing lost its quote characters; a single
    # quote may also have belonged to this set. Confirm against the feature
    # extractor used to train the persisted model.
    _EVIL_CHAR_RE = re.compile(r'[-,"*/]', re.IGNORECASE)
    _EVIL_WORD_RE = re.compile(
        r"(SELECT)|(CASE)|(WHEN)|(ORDER)|(GROUP)|(count)|(%2C%20)|(char)|(NULL)|(AND)",
        re.IGNORECASE,
    )

    def get_len(self, url):
        """Return the length of ``url``."""
        return len(url)

    def get_url_count(self, url):
        """Return 1 if ``url`` contains ``http://`` or ``https://``, else 0."""
        return 1 if self._URL_RE.search(url) else 0

    def get_evil_char(self, url):
        """Return how many characters of ``url`` are one of ``- , " * /``."""
        return len(self._EVIL_CHAR_RE.findall(url))

    def get_evil_word(self, url):
        """Return the number of SQL-keyword matches in ``url``."""
        return len(self._EVIL_WORD_RE.findall(url))

class PHP_FEATURE(object):
    """Hand-crafted numeric features of one payload string for PHP-injection detection."""

    # Patterns are compiled once; all are case-insensitive.
    _URL_RE = re.compile(r"(http://)|(https://)", re.IGNORECASE)
    # NOTE(review): the published listing lost its quote characters; a single
    # quote may also have belonged to this set. Confirm against the feature
    # extractor used to train the persisted model.
    _EVIL_CHAR_RE = re.compile(r'[${"=}.|]', re.IGNORECASE)
    _EVIL_FUNC_RE = re.compile(
        r"(print)|(assert)|(system)|(preg_replace)|(create_function)"
        r"|(call_user_func)|(call_user_func_array)|(eval)|(array_map)"
        r"|(ob_start)|(shell_exec)|(passthru)|(escapeshellcmd)|(proc_popen)"
        r"|(pcntl_exec)|(phpinfo)|(exit)",
        re.IGNORECASE,
    )
    _EVIL_WORD_RE = re.compile(
        r"(chr)|(%27)|(passwd)|(whoami)|(base64_decode)", re.IGNORECASE
    )

    def get_len(self, url):
        """Return the length of ``url``."""
        return len(url)

    def get_url_count(self, url):
        """Return 1 if ``url`` contains ``http://`` or ``https://``, else 0."""
        return 1 if self._URL_RE.search(url) else 0

    def get_evil_char(self, url):
        """Return how many characters of ``url`` are one of ``$ { " = } . |``."""
        return len(self._EVIL_CHAR_RE.findall(url))

    def get_evil_func(self, url):
        """Return the number of dangerous-function-name matches in ``url``."""
        return len(self._EVIL_FUNC_RE.findall(url))

    def get_evil_word(self, url):
        """Return the number of suspicious-keyword matches in ``url``."""
        return len(self._EVIL_WORD_RE.findall(url))

def get_xss_feature(pkt_loads):
    """Compute the XSS feature vector of every payload.

    Args:
        pkt_loads: iterable of payload strings.

    Returns:
        A list with one ``[length, url_count, evil_chars, evil_words]`` list
        per payload, in input order.
    """
    xss_f = XSS_FEATURE()
    return [
        [
            xss_f.get_len(pkt_load),
            xss_f.get_url_count(pkt_load),
            xss_f.get_evil_char(pkt_load),
            xss_f.get_evil_word(pkt_load),
        ]
        for pkt_load in pkt_loads
    ]

def get_sql_feature(pkt_loads):
    """Compute the SQL-injection feature vector of every payload.

    Args:
        pkt_loads: iterable of payload strings.

    Returns:
        A list with one ``[length, url_count, evil_chars, evil_words]`` list
        per payload, in input order.
    """
    sql_f = SQL_FEATURE()
    return [
        [
            sql_f.get_len(pkt_load),
            sql_f.get_url_count(pkt_load),
            sql_f.get_evil_char(pkt_load),
            sql_f.get_evil_word(pkt_load),
        ]
        for pkt_load in pkt_loads
    ]

def get_php_feature(pkt_loads):
    """Compute the PHP-injection feature vector of every payload.

    Args:
        pkt_loads: iterable of payload strings.

    Returns:
        A list with one
        ``[length, url_count, evil_chars, evil_funcs, evil_words]`` list per
        payload, in input order.
    """
    php_f = PHP_FEATURE()
    return [
        [
            php_f.get_len(pkt_load),
            php_f.get_url_count(pkt_load),
            php_f.get_evil_char(pkt_load),
            php_f.get_evil_func(pkt_load),
            php_f.get_evil_word(pkt_load),
        ]
        for pkt_load in pkt_loads
    ]

def get_pkt_feature(atk_index, pkt_loads):
    """Extract the features that match one attack engine.

    Args:
        atk_index: engine selector - 1 for XSS, 2 for SQL, 3 for PHP.
        pkt_loads: list of payload strings handed to that engine.

    Returns:
        ``[]`` when ``pkt_loads`` is empty, otherwise the feature list
        produced by the extractor of the selected engine.

    Raises:
        ValueError: if ``pkt_loads`` is non-empty and ``atk_index`` is not
            1, 2 or 3 (this used to surface as an ``UnboundLocalError``).
    """
    if len(pkt_loads) == 0:
        return []
    if atk_index == 1:
        return get_xss_feature(pkt_loads)
    if atk_index == 2:
        return get_sql_feature(pkt_loads)
    if atk_index == 3:
        return get_php_feature(pkt_loads)
    raise ValueError("unknown attack engine index: %r" % (atk_index,))

def atk_predict(waf_engines_index, pkt_loadss):
    """Run every attack engine on the payloads that were handed to it.

    Args:
        waf_engines_index: indexable collection of engines; the engine for
            the payload group at position ``i`` is taken from index ``i + 1``
            (index 0 holds the classifier and is not used here).
        pkt_loadss: list of payload lists, one per attack engine.

    Returns:
        ``(predictions, atk_flag)``: ``predictions`` holds one list per
        engine with that engine's ``predict`` output (``[]`` for an engine
        that received no payloads); ``atk_flag`` is 1 if any prediction
        equals 1, else 0.
    """
    atk_predict_01ss = []
    for engine_index, pkt_loads in enumerate(pkt_loadss):
        atk_index = engine_index + 1
        pkt_features = get_pkt_feature(atk_index, pkt_loads)
        if len(pkt_features) == 0:
            atk_predict_01s = []
        else:
            atk_predict_01s = waf_engines_index[atk_index].predict(pkt_features).tolist()
        atk_predict_01ss.append(atk_predict_01s)

    atk_flag = 1 if any(1 in preds for preds in atk_predict_01ss) else 0
    return (atk_predict_01ss, atk_flag)

def classifier_train(x, y):
    """Train the 4-class dispatch classifier and save it to disk.

    The samples are split 80/20 into training and validation sets, padded to
    length 100, and used to fit an embedding + LSTM + softmax network for 10
    epochs. The fitted model is written to ``./model/waf-demo-2018.tfl``.

    Args:
        x: list of encoded samples (lists of integers, see ``get_feature``).
        y: list of integer class labels, one per sample.
    """
    with tf.Graph().as_default():
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=0.2, random_state=0)
        x_train = pad_sequences(x_train, maxlen=100, value=0.)
        x_test = pad_sequences(x_test, maxlen=100, value=0.)
        y_train = to_categorical(y_train, nb_classes=4)
        y_test = to_categorical(y_test, nb_classes=4)

        # The layer and optimizer names are string arguments; their quotes
        # had been lost, leaving undefined bare names.
        net = tflearn.input_data([None, 100])
        net = tflearn.embedding(net, input_dim=256, output_dim=128)
        net = tflearn.lstm(net, 128, dropout=0.8)
        net = tflearn.fully_connected(net, 4, activation="softmax")
        net = tflearn.regression(net, optimizer="adam", learning_rate=0.1,
                                 loss="categorical_crossentropy")

        model = tflearn.DNN(net, tensorboard_verbose=3)
        model.fit(x_train, y_train, n_epoch=10,
                  validation_set=(x_test, y_test), show_metric=True,
                  batch_size=200, run_id="waf-demo-2018")
        model.save("./model/waf-demo-2018.tfl")

def waf_load():
    """Load the dispatch classifier and the three attack engines from ``./model``.

    The classifier network is rebuilt with the same layers as in
    ``classifier_train`` and its weights are restored from
    ``./model/waf-demo-2018.tfl``; the XSS, SQL and PHP engines are loaded
    with ``joblib``.

    Returns:
        The tuple ``(classifier_engine, xss_engine, sql_engine, php_engine)``.
    """
    with tf.Graph().as_default():
        # The layer and optimizer names are string arguments; their quotes
        # had been lost, leaving undefined bare names.
        net = tflearn.input_data([None, 100])
        net = tflearn.embedding(net, input_dim=256, output_dim=128)
        net = tflearn.lstm(net, 128, dropout=0.8)
        net = tflearn.fully_connected(net, 4, activation="softmax")
        net = tflearn.regression(net, optimizer="adam", learning_rate=0.1,
                                 loss="categorical_crossentropy")
        classifier_engine = tflearn.DNN(net, tensorboard_verbose=3)
        classifier_engine.load("./model/waf-demo-2018.tfl")
    print("load classifier ok!\n\n")

    # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
    # only load model files from a trusted location.
    xss_engine = joblib.load("./model/train_model_xss.m")
    sql_engine = joblib.load("./model/train_model_sql.m")
    php_engine = joblib.load("./model/train_model_php.m")
    waf_engines_index = (classifier_engine, xss_engine, sql_engine, php_engine)
    print("load atks ok!\n\n")
    return waf_engines_index

def classifier_handout(handout, payloads, top_k=2):
    """Distribute payloads to the attack engines chosen by the classifier.

    Args:
        handout: one row of candidate class ids per payload. Id 1 selects
            the XSS engine, 2 the SQL engine and 3 the PHP engine; any other
            id selects nothing.
        payloads: the payload strings, aligned with ``handout``.
        top_k: how many leading ids of each row are honoured (default 2,
            the previously hard-coded value).

    Returns:
        ``[xss_payloads, sql_payloads, php_payloads]`` - three lists holding
        the payloads assigned to each engine, in input order.
    """
    engine_slot = {1: 0, 2: 1, 3: 2}  # class id -> position in the result
    atk_engine_collect_fields = [[], [], []]
    for field_num, atk_ids in enumerate(handout):
        for atk_id in atk_ids[:top_k]:
            slot = engine_slot.get(atk_id)
            if slot is not None:
                atk_engine_collect_fields[slot].append(payloads[field_num])
    return atk_engine_collect_fields

def classifier_predict(waf_engines_index, payloads):
    """Classify every payload and group the payloads by attack engine.

    Each payload is URL-decoded, encoded with ``get_feature`` and padded to
    length 100 before being passed to the classifier. The first two labels
    returned for each payload decide which engines receive it.

    Args:
        waf_engines_index: engine collection whose element 0 is the
            classifier (see ``waf_load``).
        payloads: list of raw payload strings.

    Returns:
        ``[xss_payloads, sql_payloads, php_payloads]`` as produced by
        ``classifier_handout``.
    """
    # Function-scope import: ``urllib.unquote`` only exists on Python 2; on
    # Python 3 the same function lives in ``urllib.parse``.
    try:
        from urllib.parse import unquote
    except ImportError:
        from urllib import unquote

    classifier_engine = waf_engines_index[0]
    fields = [get_feature(unquote(payload)) for payload in payloads]
    fields = pad_sequences(fields, maxlen=100, value=0.)
    # NOTE(review): presumably predict_label orders each row's labels by
    # descending probability, making columns 0-1 the two most likely classes
    # - confirm against the tflearn documentation.
    distributions = classifier_engine.predict_label(fields)
    handout = distributions[:, 0:2].tolist()
    return classifier_handout(handout, payloads)

def data_pre_treat():
    """Return the ten demo payloads and print a confirmation message.

    Returns:
        A new list of ten payload strings mixing benign text with XSS, SQL
        and command-injection samples.
    """
    pkt_loads = [
        "1234567890",
        "hello/world/waf.php",
        "<script>alert(1)</script>",
        "-2 union select group_concat(Username),2,3 from Person#",
        "hi welecome to learn web security",
        "ex1.php?dir=| cat /etc/passwd",
        "<script>alert( and 2=2#)</script>",
        "<script>alert(2)</script>",
        "<script>alert(3)</script>",
        " and 1=1#",
    ]
    print("data ok!\n\n")
    return pkt_loads

def get_engine_id(engine_index, atk_pkt_loads_value=0):
    """Return a bit mask with the bit of ``engine_index`` set.

    Args:
        engine_index: zero-based engine position; its bit is
            ``1 << engine_index``.
        atk_pkt_loads_value: mask accumulated so far for the same payload
            (0 or ``None`` when there is none).

    Returns:
        The accumulated mask with this engine's bit set.

    The bit is merged with bitwise OR. The previous arithmetic ``+`` gave a
    wrong mask when the same engine reported the same payload twice (e.g.
    engine 0 twice produced 2, which is the bit of engine 1).
    """
    return (1 << engine_index) | (atk_pkt_loads_value or 0)

def reverse_look(atk_predict_info, atk_engine_collect_fields):
    """Map every payload flagged as an attack to the engines that flagged it.

    Args:
        atk_predict_info: the ``(predictions, atk_flag)`` tuple returned by
            ``atk_predict``; only ``predictions`` is read.
        atk_engine_collect_fields: the per-engine payload lists the
            predictions were computed from (same nesting as ``predictions``).

    Returns:
        A dict from payload string to the engine id computed by
        ``get_engine_id`` over every engine that gave a truthy prediction
        for that payload.
    """
    atk_effect_lists = atk_predict_info[0]
    atk_pkt_loads = {}
    for engine_index, atk_effect_list in enumerate(atk_effect_lists):
        for field_index, atk_effect in enumerate(atk_effect_list):
            if not atk_effect:
                continue
            key = atk_engine_collect_fields[engine_index][field_index]
            previous = atk_pkt_loads.get(key)
            if previous is None:
                atk_pkt_loads[key] = get_engine_id(engine_index)
            else:
                atk_pkt_loads[key] = get_engine_id(engine_index, previous)
    return atk_pkt_loads

def waf_predict(waf_engine, payloads):
    """Decide whether any payload of a request is an attack.

    Args:
        waf_engine: the engine tuple returned by ``waf_load``.
        payloads: list of payload strings of one request.

    Returns:
        A 1-tuple ``(atk_flag,)`` where ``atk_flag`` is 1 if any attack
        engine flagged a payload and 0 otherwise. It is a tuple because the
        C++ caller unpacks the result with ``PyArg_ParseTuple``.
    """
    atk_engine_collect_fields = classifier_predict(waf_engine, payloads)
    atk_predict_info = atk_predict(waf_engine, atk_engine_collect_fields)
    return (atk_predict_info[1],)

def test():
    """Smoke test: load the engines and classify the built-in demo payloads.

    Returns:
        The 1-tuple returned by ``waf_predict`` for the demo payloads.
    """
    payloads = data_pre_treat()

    waf_engine = waf_load()
    isAtk = waf_predict(waf_engine, payloads)
    return isAtk

if __name__ == "__main__":
    # To (re)train the dispatch classifier instead of running the smoke
    # test, pass four sample files on the command line and enable:
    #args =(sys.argv[1],sys.argv[2],sys.argv[3],sys.argv[4])
    #xs,ys = load_files(args)
    #classifier_train(xs,ys)
    test()

推荐阅读:

查看原文 >>
相关文章