I'm building my recommendation system with Flink and Java, using our own logic.
So I have a dataset:
[user] [item]
100 1
100 2
100 3
100 4
100 5
200 1
200 2
200 3
200 6
300 1
300 6
400 7
So I map it all into tuples:
DataSet<Tuple3<Long, Long, Integer>> csv = text.flatMap(new LineSplitter()).groupBy(0, 1).reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Integer>>() {
@Override
public void reduce(Iterable<Tuple2<Long, Long>> iterable, Collector<Tuple3<Long, Long, Integer>> collector) throws Exception {
Long customerId = 0L;
Long itemId = 0L;
Integer count = 0;
for (Tuple2<Long, Long> item : iterable) {
customerId = item.f0;
itemId = item.f1;
count = count + 1;
}
collector.collect(new Tuple3<>(customerId, itemId, count));
}
});
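The LineSplitter here is just a flatMap that turns each tab-separated line into a Tuple2<user, item>; a minimal sketch (essentially the LineFieldSplitter from the full solution further down, assuming tab-separated input) would be:
public static final class LineSplitter implements FlatMapFunction<String, Tuple2<Long, Long>> {
    @Override
    public void flatMap(String value, Collector<Tuple2<Long, Long>> out) {
        // split the tab-separated line into the user and item columns
        String[] tokens = value.split("\t");
        if (tokens.length > 1) {
            out.collect(new Tuple2<>(Long.valueOf(tokens[0]), Long.valueOf(tokens[1])));
        }
    }
}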
After that I get each customer with its items in an ArrayList:
DataSet<CustomerItems> customerItems = csv.groupBy(0).reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, Integer>, CustomerItems>() {
@Override
public void reduce(Iterable<Tuple3<Long, Long, Integer>> iterable, Collector<CustomerItems> collector) throws Exception {
ArrayList<Long> newItems = new ArrayList<>();
Long customerId = 0L;
for (Tuple3<Long, Long, Integer> item : iterable) {
customerId = item.f0;
newItems.add(item.f1);
}
collector.collect(new CustomerItems(customerId, newItems));
}
});
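For the sample dataset above, customerItems should then contain one record per customer, roughly:
100 -> [1, 2, 3, 4, 5]
200 -> [1, 2, 3, 6]
300 -> [1, 6]
400 -> [7]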
Now I need to get all "similar" customers. But after trying a lot of things, nothing has worked.
The logic would be:
for ci : CustomerItems
    c1 = ci.customerId
    for ci2 : CustomerItems
        c2 = ci2.customerId
        if c1 != c2
            if ci2.getItems() has any item in common with ci.getItems()
                collector.collect(new Tuple2<c1, c2>)
I tried using reduce, but I can't iterate over the iterator twice (a loop inside a loop).
Can anyone help me?
You can cross the data set with itself and basically plug the logic in 1:1 inside the cross function (minus the two loops, since the cross does that for you).
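A minimal sketch of that idea, assuming the CustomerItems POJO from the question: the CrossFunction has to return something for every pair, so non-similar pairs get a marker value and are filtered out afterwards.
DataSet<Tuple2<Long, Long>> similar = customerItems.cross(customerItems)
    .with(new CrossFunction<CustomerItems, CustomerItems, Tuple2<Long, Long>>() {
        @Override
        public Tuple2<Long, Long> cross(CustomerItems c1, CustomerItems c2) throws Exception {
            // emit the pair only if the customers differ and share at least one item
            if (!c1.customerId.equals(c2.customerId)) {
                for (Long item : c2.items) {
                    if (c1.items.contains(item)) {
                        return new Tuple2<>(c1.customerId, c2.customerId);
                    }
                }
            }
            // marker for "not similar"; removed by the filter below
            return new Tuple2<>(-1L, -1L);
        }
    })
    .filter(new FilterFunction<Tuple2<Long, Long>>() {
        @Override
        public boolean filter(Tuple2<Long, Long> pair) throws Exception {
            return pair.f0 >= 0;
        }
    });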
I solved the problem, but I needed to group and reduce after the cross. I don't know if this is the best approach. Can anyone make some suggestions?
The result is here:
package org.myorg.quickstart;
import org.apache.flink.api.common.functions.CrossFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;
import java.io.Serializable;
import java.util.ArrayList;
public class UserRecommendation {
public static void main(String[] args) throws Exception {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
// read the dataset file
DataSet<String> text = env.readTextFile("/Users/paulo/Downloads/dataset.csv");
// create a tuple with: customer | item | count
DataSet<Tuple3<Long, Long, Integer>> csv = text.flatMap(new LineFieldSplitter()).groupBy(0, 1).reduceGroup(new GroupReduceFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Integer>>() {
@Override
public void reduce(Iterable<Tuple2<Long, Long>> iterable, Collector<Tuple3<Long, Long, Integer>> collector) throws Exception {
Long customerId = 0L;
Long itemId = 0L;
Integer count = 0;
for (Tuple2<Long, Long> item : iterable) {
customerId = item.f0;
itemId = item.f1;
count = count + 1;
}
collector.collect(new Tuple3<>(customerId, itemId, count));
}
});
// group each customer's items under the customer
final DataSet<CustomerItems> customerItems = csv.groupBy(0).reduceGroup(new GroupReduceFunction<Tuple3<Long, Long, Integer>, CustomerItems>() {
@Override
public void reduce(Iterable<Tuple3<Long, Long, Integer>> iterable, Collector<CustomerItems> collector) throws Exception {
ArrayList<Long> newItems = new ArrayList<>();
Long customerId = 0L;
for (Tuple3<Long, Long, Integer> item : iterable) {
customerId = item.f0;
newItems.add(item.f1);
}
collector.collect(new CustomerItems(customerId, newItems));
}
});
// collect all items that belong to similar customers
DataSet<CustomerItems> ci = customerItems.cross(customerItems).with(new CrossFunction<CustomerItems, CustomerItems, CustomerItems>() {
@Override
public CustomerItems cross(CustomerItems customerItems, CustomerItems customerItems2) throws Exception {
if (!customerItems.customerId.equals(customerItems2.customerId)) {
boolean has = false;
for (Long item : customerItems2.items) {
if (customerItems.items.contains(item)) {
has = true;
break;
}
}
if (has) {
for (Long item : customerItems2.items) {
if (!customerItems.items.contains(item)) {
customerItems.ritems.add(item);
}
}
}
}
return customerItems;
}
}).groupBy(new KeySelector<CustomerItems, Long>() {
@Override
public Long getKey(CustomerItems customerItems) throws Exception {
return customerItems.customerId;
}
}).reduceGroup(new GroupReduceFunction<CustomerItems, CustomerItems>() {
@Override
public void reduce(Iterable<CustomerItems> iterable, Collector<CustomerItems> collector) throws Exception {
CustomerItems c = new CustomerItems();
for (CustomerItems current : iterable) {
c.customerId = current.customerId;
for (Long item : current.ritems) {
if (!c.ritems.contains(item)) {
c.ritems.add(item);
}
}
}
collector.collect(c);
}
});
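// note: print() and count() below each trigger a separate execution of the plan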
ci.first(100).print();
System.out.println(ci.count());
}
public static class CustomerItems implements Serializable {
public Long customerId;
public ArrayList<Long> items = new ArrayList<>();
public ArrayList<Long> ritems = new ArrayList<>();
public CustomerItems() {
}
public CustomerItems(Long customerId, ArrayList<Long> items) {
this.customerId = customerId;
this.items = items;
}
@Override
public String toString() {
StringBuilder itemsData = new StringBuilder();
if (items != null) {
for (Long item : items) {
if (itemsData.length() == 0) {
itemsData.append(item);
} else {
itemsData.append(", ").append(item);
}
}
}
StringBuilder ritemsData = new StringBuilder();
if (ritems != null) {
for (Long item : ritems) {
if (ritemsData.length() == 0) {
ritemsData.append(item);
} else {
ritemsData.append(", ").append(item);
}
}
}
return String.format("[ID: %d, Items: %s, RItems: %s]", customerId, itemsData, ritemsData);
}
}
public static final class LineFieldSplitter implements FlatMapFunction<String, Tuple2<Long, Long>> {
@Override
public void flatMap(String value, Collector<Tuple2<Long, Long>> out) {
// normalize and split the line
String[] tokens = value.split("\t");
if (tokens.length > 1) {
out.collect(new Tuple2<>(Long.valueOf(tokens[0]), Long.valueOf(tokens[1])));
}
}
}
}
Link to the gist: https://gist.github.com/prsolucoes/b406ae98ea24120436954967e37103f6