• 大小: 9KB
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-06-02
  • 语言: Python
  • 标签: Python  

资源简介

基于Python的SVM模块源代码基于Python的SVM模块源代码

资源截图

代码片段和文件信息

#!/usr/bin/env python

from ctypes import *
from ctypes.util import find_library
from os import path
import sys

if sys.version_info[0] >= 3:
xrange = range

__all__ = [‘libsvm‘ ‘svm_problem‘ ‘svm_parameter‘
           ‘toPyModel‘ ‘gen_svm_nodearray‘ ‘print_null‘ ‘svm_node‘ ‘C_SVC‘
           ‘EPSILON_SVR‘ ‘LINEAR‘ ‘NU_SVC‘ ‘NU_SVR‘ ‘ONE_CLASS‘
           ‘POLY‘ ‘PRECOMPUTED‘ ‘PRINT_STRING_FUN‘ ‘RBF‘
           ‘SIGMOID‘ ‘c_double‘ ‘svm_model‘]

try:
dirname = path.dirname(path.abspath(__file__))
if sys.platform == ‘win32‘:
libsvm = CDLL(path.join(dirname r‘..\windows\libsvm.dll‘))
else:
libsvm = CDLL(path.join(dirname ‘../libsvm.so.2‘))
except:
# For unix the prefix ‘lib‘ is not considered.
if find_library(‘svm‘):
libsvm = CDLL(find_library(‘svm‘))
elif find_library(‘libsvm‘):
libsvm = CDLL(find_library(‘libsvm‘))
else:
raise Exception(‘LIBSVM library not found.‘)

C_SVC = 0
NU_SVC = 1
ONE_CLASS = 2
EPSILON_SVR = 3
NU_SVR = 4

LINEAR = 0
POLY = 1
RBF = 2
SIGMOID = 3
PRECOMPUTED = 4

PRINT_STRING_FUN = CFUNCTYPE(None c_char_p)
def print_null(s):
return

def genFields(names types):
return list(zip(names types))

def fillprototype(f restype argtypes):
f.restype = restype
f.argtypes = argtypes

class svm_node(Structure):
_names = [“index“ “value“]
_types = [c_int c_double]
_fields_ = genFields(_names _types)

def __str__(self):
return ‘%d:%g‘ % (self.index self.value)

def gen_svm_nodearray(xi feature_max=None isKernel=None):
if isinstance(xi dict):
index_range = xi.keys()
elif isinstance(xi (list tuple)):
if not isKernel:
xi = [0] + xi  # idx should start from 1
index_range = range(len(xi))
else:
raise TypeError(‘xi should be a dictionary list or tuple‘)

if feature_max:
assert(isinstance(feature_max int))
index_range = filter(lambda j: j <= feature_max index_range)
if not isKernel:
index_range = filter(lambda j:xi[j] != 0 index_range)

index_range = sorted(index_range)
ret = (svm_node * (len(index_range)+1))()
ret[-1].index = -1
for idx j in enumerate(index_range):
ret[idx].index = j
ret[idx].value = xi[j]
max_idx = 0
if index_range:
max_idx = index_range[-1]
return ret max_idx

class svm_problem(Structure):
_names = [“l“ “y“ “x“]
_types = [c_int POINTER(c_double) POINTER(POINTER(svm_node))]
_fields_ = genFields(_names _types)

def __init__(self y x isKernel=None):
if len(y) != len(x):
raise ValueError(“len(y) != len(x)“)
self.l = l = len(y)

max_idx = 0
x_space = self.x_space = []
for i xi in enumerate(x):
tmp_xi tmp_idx = gen_svm_nodearray(xiisKernel=isKernel)
x_space += [tmp_xi]
max_idx = max(max_idx tmp_idx)
self.n = max_idx

self.y = (c_double * l)()
for i yi in enumerate(y): self.y[i] = yi

self.x = (POINTER(svm_node) * l)()
for i xi in enumerate(self.x_space): self.x[i] = xi

class svm_parameter(Structure):
_names = [“svm_type“ “kernel_type“ “degree“ “gamma“ “coef0“
“cache_size“ “

评论

共有 条评论