将树规则从 Weka 转换为 SQL 查询的设计模式


fac_a < 64
|   fac_d < 71.5
|   |   fac_a < 49.5
|   |   |   fac_d < 23.5 : 19.44 (13/43.71) [13/77.47]
|   |   |   fac_d >= 23.5 : 24.25 (32/23.65) [16/49.15]
|   |   fac_a >= 49.5 : 30.8 (10/17.68) [5/22.44]
|   fac_d >= 71.5 : 33.6 (25/53.05) [15/47.35]
fac_a >= 64
|   fac_d < 83.5
|   |   fac_a < 91
|   |   |   fac_e < 93.5
|   |   |   |   fac_d < 45 : 31.9 (16/23.25) [3/64.14]
|   |   |   |   fac_d >= 45
|   |   |   |   |   fac_e < 21.5 : 44.1 (5/16.58) [2/21.39]
|   |   |   |   |   fac_e >= 21.5
|   |   |   |   |   |   fac_a < 77.5 : 33.45 (4/2.89) [1/0.03]
|   |   |   |   |   |   fac_a >= 77.5 : 39.46 (7/10.21) [1/11.69]
|   |   |   fac_e >= 93.5 : 45.97 (2/8.03) [1/107.71]
|   |   fac_a >= 91 : 42.26 (9/9.57) [4/69.03]
|   fac_d >= 83.5 : 47.1 (9/30.24) [6/40.15]

我想在我的数据集(在 MSSQL 中)中添加一列,用来保存根据这些规则对响应变量得出的预测值。把上面的树转换成一组 n 个查询(其中 n 是树的叶节点数)相对容易,每个查询的 WHERE 子句都可以根据分支信息自动生成:

-- One UPDATE per leaf of the tree. Each WHERE clause is the conjunction of the
-- branch tests on the path from the root to that leaf, copied verbatim from the
-- tree (so a factor may appear more than once). value1, value2, ... are
-- placeholders for the leaf predictions.

-- Rule 1
UPDATE table_name
SET prediction = value1
WHERE
    fac_a < 64 AND
    fac_d < 71.5 AND
    fac_a < 49.5 AND
    fac_d < 23.5;
-- Rule 2
UPDATE table_name
SET prediction = value2
WHERE
    fac_a < 64 AND
    fac_d < 71.5 AND
    fac_a < 49.5 AND
    fac_d >= 23.5;
-- etc. for each rule

但是当我有复杂的树(大约 100 个叶节点)和 100,000+ 行时,这不能很好地扩展。是否有 SQL 查询的设计模式可以应用此树分类,使我能够更有效地计算预测?

这里有一个想法:将规则放入一张层次结构表中,然后把查找逻辑封装到一个递归的用户定义标量函数里,见下文。(出于某种原因,SQL Fiddle 无法处理用户定义函数,但我在 SQL Server 2012 上测试过,它在 2008 上也应该可以运行。)

在您提供的示例数据和一张约 1000 行的事实表上,它运行得很快。至少,它可能比您现在的做法更容易维护。这种方法可能还有更好的变体,您可以先看看这个思路是否合适。

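如果您的决策树深度超过 100 级(或者您没有正确填充规则表),就会达到函数的默认递归深度限制 100。可以使用 OPTION (MAXRECURSION 0) 表示无限制,也可以使用 OPTION (MAXRECURSION 32767) 或更小的值来更改此限制。(注意:按照 SQL Server 文档,MAXRECURSION 查询提示作用于递归 CTE;用户定义函数的嵌套调用另受 32 层上限约束,使用前请先确认树的深度。)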

-- Sample fact table: five predictor columns (fac_a .. fac_e) plus val, the
-- column that receives the tree's prediction. val is inserted as NULL and is
-- filled in by the UPDATE at the end of the script.
create table facdata (
  fac_a decimal(10,4),
  fac_b decimal(10,4),
  fac_c decimal(10,4),
  fac_d decimal(10,4),
  fac_e decimal(10,4),
  val   decimal(10,4)
);
-- Populate facdata with every combination of six sample values for fac_a,
-- fac_c, fac_d and fac_e (6^4 = 1296 rows). fac_b is fixed at 30 and val
-- starts out NULL.
with v(i) as (
  select 40 union all select 50 union all select 70
  union all select 80 union all select 90 union all select 100
)
insert into facdata (fac_a, fac_b, fac_c, fac_d, fac_e, val)
  select a.i, 30, c.i, d.i, e.i, null
  from v as a
  cross join v as c   -- deliberate Cartesian product of the sample values
  cross join v as d
  cross join v as e;
-- Decision tree stored one node per row.
--   did   : path of the node; appending '0/' to a node's path gives its
--           "< split" child and '1/' its ">= split" child
--   fac   : letter of the factor tested at this node ('a' .. 'e'); NULL on a leaf
--   split : threshold the factor is compared against; NULL on a leaf
--   val   : predicted value; set only on leaves, NULL on internal nodes
create table decisions (
  did hierarchyid primary key,
  fac char(1),
  split decimal(10,4),
  val decimal(10,4)
);
-- Load the tree in depth-first order so the rows read like the Weka output.
-- Internal nodes carry (fac, split) and a NULL val; leaves carry only val.
-- Child '0' is the "< split" branch, child '1' the ">= split" branch.
insert into decisions (did, fac, split, val) values
  (cast('/0/' as hierarchyid),               'a',  64,   null),   -- fac_a < 64 ?
  (cast('/0/0/' as hierarchyid),             'd',  71.5, null),   --   fac_d < 71.5 ?
  (cast('/0/0/0/' as hierarchyid),           'a',  49.5, null),   --     fac_a < 49.5 ?
  (cast('/0/0/0/0/' as hierarchyid),         'd',  23.5, null),   --       fac_d < 23.5 ?
  (cast('/0/0/0/0/0/' as hierarchyid),       null, null, 19.44),  --         yes
  (cast('/0/0/0/0/1/' as hierarchyid),       null, null, 24.25),  --         no
  (cast('/0/0/0/1/' as hierarchyid),         null, null, 30.8),   --       fac_a >= 49.5
  (cast('/0/0/1/' as hierarchyid),           null, null, 33.6),   --     fac_d >= 71.5
  (cast('/0/1/' as hierarchyid),             'd',  83.5, null),   --   fac_d < 83.5 ?
  (cast('/0/1/0/' as hierarchyid),           'a',  91,   null),   --     fac_a < 91 ?
  (cast('/0/1/0/0/' as hierarchyid),         'e',  93.5, null),   --       fac_e < 93.5 ?
  (cast('/0/1/0/0/0/' as hierarchyid),       'd',  45,   null),   --         fac_d < 45 ?
  (cast('/0/1/0/0/0/0/' as hierarchyid),     null, null, 31.9),   --           yes
  (cast('/0/1/0/0/0/1/' as hierarchyid),     'e',  21.5, null),   --           fac_e < 21.5 ?
  (cast('/0/1/0/0/0/1/0/' as hierarchyid),   null, null, 44.1),   --             yes
  (cast('/0/1/0/0/0/1/1/' as hierarchyid),   'a',  77.5, null),   --             fac_a < 77.5 ?
  (cast('/0/1/0/0/0/1/1/0/' as hierarchyid), null, null, 33.45),  --               yes
  (cast('/0/1/0/0/0/1/1/1/' as hierarchyid), null, null, 39.46),  --               no
  (cast('/0/1/0/0/1/' as hierarchyid),       null, null, 45.97),  --         fac_e >= 93.5
  (cast('/0/1/0/1/' as hierarchyid),         null, null, 42.26),  --       fac_a >= 91
  (cast('/0/1/1/' as hierarchyid),           null, null, 47.1);   --     fac_d >= 83.5
-- CREATE FUNCTION has to be the only statement in its batch, hence the batch
-- separators (recognised by SSMS / sqlcmd) around the definition.
go
-- Walks the decision tree stored in the decisions table, starting at node @h,
-- and returns the prediction of the leaf reached for the given factor values.
--
-- Parameters:
--   @h                 node to start from (pass the root path to score a row)
--   @val_a .. @val_e   the row's values for factors a .. e
-- Returns:
--   the val of the leaf reached, or NULL when the walk arrives at a path that
--   has no usable row in decisions (missing node, or a node with neither a
--   val nor a split).
--
-- NOTE(review): the function calls itself once per tree level, so the tree
-- depth must stay within SQL Server's nesting limit for function calls
-- (documented as 32 levels) -- confirm for deeper trees.
create function dbo.findvalfrom(
  @h hierarchyid,
  @val_a decimal(10,4),
  @val_b decimal(10,4),
  @val_c decimal(10,4),
  @val_d decimal(10,4),
  @val_e decimal(10,4)
) returns decimal(10,4) as begin
    declare @c char(1);
    declare @s decimal(10,4);
    declare @v decimal(10,4);

    -- Fetch the current node. When no row matches, all three variables keep
    -- their initial NULLs.
    select
      @c = fac, @s = split, @v = val
    from decisions
    where did = @h;

    -- Leaf: the stored value is the prediction.
    if @v is not null return @v;

    -- No leaf value and no split to test: the node is missing or malformed.
    -- Stop here instead of descending forever down the '1/' branch.
    if @s is null return null;

    -- Pick the factor value this node tests.
    declare @val decimal(10,4);
    set @val = case when @c='a' then @val_a
                    when @c='b' then @val_b
                    when @c='c' then @val_c
                    when @c='d' then @val_d
                    when @c='e' then @val_e end;

    -- Descend: child '0/' when the value is below the split, child '1/'
    -- otherwise. A NULL factor value fails the '<' test and therefore follows
    -- the '1/' branch.
    set @h = cast (@h.ToString()+case when @val<@s then '0/' else '1/' end as hierarchyid);
    return dbo.findvalfrom(@h,@val_a,@val_b,@val_c,@val_d,@val_e);
end;
go

-- Score every row by walking the tree from its root node '/0/'.
-- There is intentionally no WHERE clause: all rows receive a prediction.
update facdata
set val = dbo.findvalfrom('/0/', fac_a, fac_b, fac_c, fac_d, fac_e);

-- Show the factors next to the computed prediction.
select fac_a, fac_b, fac_c, fac_d, fac_e, val
from facdata;
