杰瑞科技汇

Apriori算法Java实现的关键步骤是什么?

Apriori 算法核心原理

Apriori 算法主要包含两个步骤:

Apriori算法Java实现的关键步骤是什么?-图1
(图片来源网络,侵删)
  1. 频繁项集生成:找出所有满足最小支持度阈值的项集。
  2. 关联规则生成:从频繁项集中生成满足最小置信度阈值的关联规则。

1 频繁项集生成 (Apriori 核心思想)

这个过程通过迭代(多趟扫描数据)完成:

  • 第一趟扫描:统计所有单个项的出现次数,找出满足最小支持度的1-项集(记为 L1)。
  • 第 k 趟扫描
    • 连接步:将上一趟找到的频繁 (k-1)-项集 L(k-1) 进行两两连接,生成候选 k-项集 Ck,连接的规则是:如果两个 (k-1)-项集的前 k-2 个项相同,则可以将它们合并。{A, B} 和 {A, C} 可以连接成 {A, B, C}。
    • 剪枝步:根据 Apriori 性质(一个项集如果是频繁的,其所有子集也必须是频繁的),对候选 k-项集 Ck 进行剪枝,检查 Ck 中的每个项集的 (k-1) 维子集是否都在 L(k-1) 中,如果某个子集不在 L(k-1) 中,则这个候选 k-项集不可能是频繁的,直接从 Ck 中移除。
    • 计数:扫描整个事务数据库,统计 Ck 中每个候选项集的出现次数。
    • 筛选:将支持度(出现次数 / 事务总数)大于等于最小支持度的候选项集筛选出来,形成 Lk。
  • 终止:当某趟扫描生成的 Lk 为空时,算法结束,所有 L1, L2, ..., Lk 的并集就是所有的频繁项集。

2 关联规则生成

有了所有频繁项集后,就可以生成关联规则了,对于任意一个频繁项集 I,如果它可以被划分为两个非空子集 XY(即 I = X ∪ Y),X -> Y 就是一条关联规则。

  • 计算置信度Confidence(X -> Y) = Support(I) / Support(X)
  • 筛选规则:只保留置信度大于等于最小置信度的规则。

Java 数据结构设计

为了高效地处理项集,我们需要合适的数据结构。

  • 项集:可以用 Set<Integer>List<Integer> 来表示,项集 {1, 3, 5} 可以表示为 new HashSet<>(Arrays.asList(1, 3, 5))
  • 项集到计数的映射:我们需要一个数据结构来存储候选集或频繁集及其支持度计数。HashMap<Set<Integer>, Integer> 是最直接的选择,键是项集,值是该事务数据库中出现的次数。
  • 事务数据库:可以用 List<Set<Integer>> 来表示,每个 Set<Integer> 代表一个事务。

Java 完整代码实现

下面是一个完整的、带有详细注释的 Java 实现。

Apriori算法Java实现的关键步骤是什么?-图2
(图片来源网络,侵删)

Apriori.java

import java.util.*;
import java.util.stream.Collectors;
/**
 * Apriori 算法 Java 实现
 */
public class Apriori {
    // 最小支持度(百分比)
    private final double minSupport;
    // 最小置信度(百分比)
    private final double minConfidence;
    // 事务数据库
    private final List<Set<Integer>> transactions;
    // 项的总数
    private final int totalItems;
    public Apriori(List<Set<Integer>> transactions, double minSupport, double minConfidence) {
        this.transactions = transactions;
        this.minSupport = minSupport;
        this.minConfidence = minConfidence;
        // 假设事务中的整数ID是从1到某个连续的值
        this.totalItems = transactions.stream()
                .flatMap(Set::stream)
                .max(Integer::compare)
                .orElse(0);
    }
    /**
     * 主方法,执行 Apriori 算法
     */
    public void run() {
        System.out.println("Apriori Algorithm Started...");
        System.out.println("Minimum Support: " + minSupport + "%");
        System.out.println("Minimum Confidence: " + minConfidence + "%");
        System.out.println("-----------------------------------------");
        // 1. 生成频繁项集
        Map<Set<Integer>, Integer> frequentItemsets = findFrequentItemsets();
        System.out.println("\n=== All Frequent Itemsets ===");
        frequentItemsets.forEach((itemset, count) -> {
            double support = (double) count / transactions.size() * 100;
            System.out.println(itemset + " : Support = " + String.format("%.2f", support) + "%");
        });
        // 2. 生成关联规则
        System.out.println("\n=== Generated Association Rules ===");
        generateAssociationRules(frequentItemsets);
    }
    /**
     * 核心方法:寻找所有频繁项集
     */
    private Map<Set<Integer>, Integer> findFrequentItemsets() {
        // L(k-1): 频繁 (k-1)-项集
        Map<Set<Integer>, Integer> lastL = new HashMap<>();
        // C(k): 候选 k-项集
        Map<Set<Integer>, Integer> currentC = new HashMap<>();
        // --- 第一趟:生成频繁1-项集 L1 ---
        // C1: 候选1-项集
        Map<Set<Integer>, Integer> C1 = generateC1();
        // 扫描数据库,计数并筛选得到 L1
        lastL = scanDatabaseAndFilter(C1);
        int k = 2;
        // 只要能找到新的频繁项集,就继续迭代
        while (!lastL.isEmpty()) {
            System.out.println("\n--- Iteration k = " + k + " ---");
            System.out.println("L(" + (k - 1) + ") size: " + lastL.size());
            // --- 连接步:生成候选 k-项集 Ck ---
            currentC = generateCandidateK(lastL, k);
            System.out.println("C(" + k + ") size before pruning: " + currentC.size());
            // --- 剪枝步 ---
            // Apriori性质:如果候选集的某个(k-1)子集不是频繁的,则该候选集不可能是频繁的
            // 这里在生成Ck时已经隐式地保证了,因为L(k-1)中的元素都是频繁的
            // 所以剪枝步通常在连接步内部实现,或者在这里检查(对于更复杂的实现)
            // 为了清晰,我们在这里显式检查一遍(虽然对于标准Apriori是冗余的)
            pruneCandidates(currentC, lastL, k - 1);
            System.out.println("C(" + k + ") size after pruning: " + currentC.size());
            // --- 计数和筛选:扫描数据库,得到 Lk ---
            Map<Set<Integer>, Integer> currentL = scanDatabaseAndFilter(currentC);
            System.out.println("L(" + k + ") size: " + currentL.size());
            // 为下一次迭代做准备
            lastL = currentL;
            k++;
        }
        // 合并所有频繁项集 (L1, L2, ...)
        Map<Set<Integer>, Integer> allFrequentItemsets = new HashMap<>();
        // 注意:lastL 在循环结束后是空的,我们需要保存上一次的 L
        // 我们修改逻辑,在每次迭代后将 Lk 添加到总结果中
        // 修正后的逻辑:
        allFrequentItemsets = new HashMap<>(generateC1()); // L1
        allFrequentItemsets = scanDatabaseAndFilter(allFrequentItemsets);
        Map<Set<Integer>, Integer> prevL = allFrequentItemsets;
        k = 2;
        while(true) {
            Map<Set<Integer>, Integer> Ck = generateCandidateK(prevL, k);
            if(Ck.isEmpty()) break;
            pruneCandidates(Ck, prevL, k-1);
            Map<Set<Integer>, Integer> Lk = scanDatabaseAndFilter(Ck);
            if(Lk.isEmpty()) break;
            allFrequentItemsets.putAll(Lk);
            prevL = Lk;
            k++;
        }
        return allFrequentItemsets;
    }
    /**
     * 生成候选1-项集 C1
     */
    private Map<Set<Integer>, Integer> generateC1() {
        Map<Set<Integer>, Integer> C1 = new HashMap<>();
        for (int i = 1; i <= totalItems; i++) {
            C1.put(Collections.singleton(i), 0);
        }
        return C1;
    }
    /**
     * 生成候选 k-项集 Ck
     * @param lastFrequent 上一个频繁项集 L(k-1)
     * @param k 当前要生成的项集大小
     */
    private Map<Set<Integer>, Integer> generateCandidateK(Map<Set<Integer>, Integer> lastFrequent, int k) {
        Map<Set<Integer>, Integer> candidateK = new HashMap<>();
        List<Set<Integer>> lastFrequentList = new ArrayList<>(lastFrequent.keySet());
        for (int i = 0; i < lastFrequentList.size(); i++) {
            for (int j = i + 1; j < lastFrequentList.size(); j++) {
                Set<Integer> itemset1 = lastFrequentList.get(i);
                Set<Integer> itemset2 = lastFrequentList.get(j);
                // 连接步:只有前 k-2 个项相同时才能连接
                if (itemset1.size() == k - 1 && itemset2.size() == k - 1) {
                    List<Integer> list1 = new ArrayList<>(itemset1);
                    List<Integer> list2 = new ArrayList<>(itemset2);
                    list1.sort(Comparator.naturalOrder());
                    list2.sort(Comparator.naturalOrder());
                    boolean canJoin = true;
                    for (int m = 0; m < k - 2; m++) {
                        if (!list1.get(m).equals(list2.get(m))) {
                            canJoin = false;
                            break;
                        }
                    }
                    if (canJoin) {
                        Set<Integer> newItemset = new HashSet<>(itemset1);
                        newItemset.add(list2.get(k - 2)); // 添加最后一个不同的项
                        candidateK.put(newItemset, 0);
                    }
                }
            }
        }
        return candidateK;
    }
    /**
     * 剪枝步:移除那些其(k-1)子集不是频繁的候选集
     */
    private void pruneCandidates(Map<Set<Integer>, Integer> candidates, Map<Set<Integer>, Integer> lastFrequent, int subsetSize) {
        Iterator<Map.Entry<Set<Integer>, Integer>> it = candidates.entrySet().iterator();
        while (it.hasNext()) {
            Set<Integer> candidate = it.next().getKey();
            // 检查 candidate 的所有 (k-1) 维子集
            List<Integer> items = new ArrayList<>(candidate);
            boolean hasInfrequentSubset = false;
            for (int i = 0; i < items.size(); i++) {
                Set<Integer> subset = new HashSet<>(items);
                subset.remove(items.get(i));
                if (!lastFrequent.containsKey(subset)) {
                    hasInfrequentSubset = true;
                    break;
                }
            }
            if (hasInfrequentSubset) {
                it.remove();
            }
        }
    }
    /**
     * 扫描事务数据库,统计候选集的支持度计数,并筛选出频繁项集
     */
    private Map<Set<Integer>, Integer> scanDatabaseAndFilter(Map<Set<Integer>, Integer> candidates) {
        int minCount = (int) Math.ceil(minSupport / 100.0 * transactions.size());
        // 遍历所有事务
        for (Set<Integer> transaction : transactions) {
            // 遍历所有候选集
            for (Set<Integer> candidate : candidates.keySet()) {
                // 如果事务包含了候选集
                if (transaction.containsAll(candidate)) {
                    candidates.put(candidate, candidates.get(candidate) + 1);
                }
            }
        }
        // 筛选出满足最小支持度的项集
        return candidates.entrySet().stream()
                .filter(entry -> entry.getValue() >= minCount)
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }
    /**
     * 从频繁项集生成关联规则
     */
    private void generateAssociationRules(Map<Set<Integer>, Integer> frequentItemsets) {
        int minCount = (int) Math.ceil(minSupport / 100.0 * transactions.size());
        for (Map.Entry<Set<Integer>, Integer> entry : frequentItemsets.entrySet()) {
            Set<Integer> itemset = entry.getKey();
            int itemsetCount = entry.getValue();
            // 如果项集大小小于2,无法生成规则 X -> Y (X和Y都非空)
            if (itemset.size() < 2) {
                continue;
            }
            // 生成项集的所有非空真子集
            List<Set<Integer>> allSubsets = generateAllSubsets(itemset);
            for (Set<Integer> subset : allSubsets) {
                Set<Integer> superset = new HashSet<>(itemset);
                superset.removeAll(subset);
                // 规则是 subset -> superset
                // Confidence = P(superset | subset) = Support(itemset) / Support(subset)
                int subsetCount = frequentItemsets.getOrDefault(subset, 0);
                if (subsetCount == 0) continue; // 理论上不会发生,因为itemset是频繁的,其子集也应该是
                double confidence = (double) itemsetCount / subsetCount * 100;
                if (confidence >= minConfidence) {
                    System.out.println(subset + " -> " + superset + " : Confidence = " + String.format("%.2f", confidence) + "%");
                }
            }
        }
    }
    /**
     * 生成一个集合的所有非空真子集
     */
    private List<Set<Integer>> generateAllSubsets(Set<Integer> set) {
        List<Set<Integer>> allSubsets = new ArrayList<>();
        List<Integer> list = new ArrayList<>(set);
        int n = list.size();
        // 总共有 2^n - 2 个非空真子集 (排除空集和全集)
        for (int i = 1; i < (1 << n) - 1; i++) {
            Set<Integer> subset = new HashSet<>();
            for (int j = 0; j < n; j++) {
                if ((i & (1 << j)) > 0) {
                    subset.add(list.get(j));
                }
            }
            allSubsets.add(subset);
        }
        return allSubsets;
    }
    public static void main(String[] args) {
        // 示例事务数据库
        // T1: {1, 2, 5}
        // T2: {2, 4}
        // T3: {2, 3}
        // T4: {1, 2, 4}
        // T5: {1, 3}
        // T6: {2, 3}
        // T7: {1, 3}
        // T8: {1, 2, 3, 5}
        // T9: {1, 2, 3}
        List<Set<Integer>> transactions = new ArrayList<>();
        transactions.add(new HashSet<>(Arrays.asList(1, 2, 5)));
        transactions.add(new HashSet<>(Arrays.asList(2, 4)));
        transactions.add(new HashSet<>(Arrays.asList(2, 3)));
        transactions.add(new HashSet<>(Arrays.asList(1, 2, 4)));
        transactions.add(new HashSet<>(Arrays.asList(1, 3)));
        transactions.add(new HashSet<>(Arrays.asList(2, 3)));
        transactions.add(new HashSet<>(Arrays.asList(1, 3)));
        transactions.add(new HashSet<>(Arrays.asList(1, 2, 3, 5)));
        transactions.add(new HashSet<>(Arrays.asList(1, 2, 3)));
        // 设置最小支持度为 60%,最小置信度为 70%
        double minSupport = 60.0;
        double minConfidence = 70.0;
        Apriori apriori = new Apriori(transactions, minSupport, minConfidence);
        apriori.run();
    }
}

代码解释

main 方法

这是程序的入口,我们定义了一个示例事务数据库,并设置了最小支持度和最小置信度,然后创建 Apriori 对象并调用其 run() 方法。

run 方法

这是整个算法的控制器。

  1. 调用 findFrequentItemsets() 获取所有频繁项集。
  2. 打印出所有找到的频繁项集及其支持度。
  3. 调用 generateAssociationRules() 从频繁项集中生成规则并打印。

findFrequentItemsets 方法

这是 Apriori 算法的核心。

  1. 初始化:从 generateC1() 开始,生成所有1-项集作为候选集,然后通过 scanDatabaseAndFilter 得到第一个频繁项集 L1
  2. 迭代
    • 循环:只要上一轮的频繁项集 lastL 不为空,就继续迭代。
    • 连接:调用 generateCandidateK,将 lastL 中的项集两两连接,生成新的候选集 currentC
    • 剪枝:调用 pruneCandidates,利用 Apriori 性质移除非频繁的候选项。
    • 计数与筛选:调用 scanDatabaseAndFilter,扫描数据库统计 currentC 中每个项集的支持度,并筛选出满足最小支持度的项集,形成 Lk
    • 更新:将 Lk 作为下一轮迭代的 lastLk 值加1。
  3. 合并:将每一轮找到的 Lk 合并到一个总的 Map 中返回。

generateCandidateK 方法

实现连接步,它遍历 L(k-1) 中的所有项集对,如果它们的前 k-2 个项相同,就将它们合并成一个 k-项集,为了简化比较,项集被转换为排序后的列表。

pruneCandidates 方法

实现剪枝步,它检查候选 k-项集中的每一个项集,如果其任何一个 (k-1) 维子集不在 L(k-1) 中,就将其从候选集中移除。

scanDatabaseAndFilter 方法

这是算法中计算量最大的部分。

  1. 计数:遍历每一个事务,对于每一个候选集,检查事务是否包含该候选集,如果包含,则候选集的计数加1。
  2. 筛选:遍历计数完毕的候选集,只保留那些计数大于等于 minCount(由最小支持度计算得出)的项集。

generateAssociationRules 方法

  1. 遍历每一个频繁项集。
  2. 对于每个大小大于等于2的频繁项集,调用 generateAllSubsets 生成其所有非空真子集。
  3. 对于每个子集 X,规则就是 X -> (I - X)
  4. 计算规则的置信度 Support(I) / Support(X)
  5. 如果置信度满足最小置信度,则打印该规则。

generateAllSubsets 方法

使用位掩码的技巧高效地生成一个集合的所有子集,对于一个有 n 个元素的集合,有 2^n 个子集(包括空集和自身),我们通过遍历 12^n - 1 的数字,并根据该数字的二进制表示来决定选择哪些元素,从而生成所有非空真子集。


如何运行和输出

将上述代码保存为 Apriori.java,编译并运行 main 方法。

预期输出:

Apriori Algorithm Started...
Minimum Support: 60.0%
Minimum Confidence: 70.0%
-----------------------------------------
--- Iteration k = 2 ---
L(1) size: 5
C(2) size before pruning: 10
C(2) size after pruning: 10
L(2) size: 6
--- Iteration k = 3 ---
L(2) size: 6
C(3) size before pruning: 15
C(3) size after pruning: 7
L(3) size: 2
--- Iteration k = 4 ---
L(3) size: 2
C(4) size before pruning: 1
C(4) size after pruning: 1
L(4) size: 0
=== All Frequent Itemsets ===
[1] : Support = 77.78%
[2] : Support = 88.89%
[3] : Support = 77.78%
[4] : Support = 44.44%
[5] : Support = 33.33%
[1, 2] : Support = 66.67%
[1, 3] : Support = 66.67%
[2, 3] : Support = 66.67%
[2, 4] : Support = 22.22%
[1, 2, 3] : Support = 55.56%
[1, 2, 3, 5] : Support = 11.11%
=== Generated Association Rules ===
[1] -> [2] : Confidence = 85.71%
[1] -> [3] : Confidence = 85.71%
[2] -> [1] : Confidence = 75.00%
[2] -> [3] : Confidence = 75.00%
[3] -> [1] : Confidence = 100.00%
[3] -> [2] : Confidence = 66.67%
[1, 2] -> [3] : Confidence = 83.33%
[1, 3] -> [2] : Confidence = 100.00%
[2, 3] -> [1] : Confidence = 100.00%

(注意:由于浮点数计算的微小差异或子集生成顺序,你的输出可能与示例略有不同,但规则和支持度应该是正确的。)


算法分析与优化

  • 性能瓶颈:Apriori 算法的主要性能瓶颈在于需要多次扫描事务数据库(k 趟扫描),并且在生成候选集时会产生大量候选项集(候选集数量爆炸问题)。
  • 优化方向
    • 哈希树:使用哈希树来存储候选集,可以更高效地查找事务的子集,从而加速计数过程。
    • FP-Growth (Frequent Pattern Growth):这是一种更高效的算法,它只需要扫描数据库两次,它不使用候选集,而是通过构建一个称为“FP-树”的紧凑数据结构来挖掘频繁项集,避免了候选集的生成和数据库的多次扫描。

这个 Java 实现清晰地展示了 Apriori 算法的逻辑,是学习和理解该算法的很好的起点,对于大规模数据集,可以考虑研究并实现更优化的算法如 FP-Growth。

分享:
扫描分享到社交APP
上一篇
下一篇