杰瑞科技汇

Java如何实现Apriori算法?

Apriori 算法是一种用于挖掘频繁项集和生成强关联规则的算法,它基于一个核心思想:一个项集如果是频繁的,那么它的所有子集也必须是频繁的,这个性质被称为 Apriori 性质,是算法剪枝(Pruning)的基础。

Java如何实现Apriori算法?-图1
(图片来源网络,侵删)

实现 Apriori 算法主要分为两个步骤:

  1. 找出所有频繁项集:使用迭代的方法,先生成候选 1-项集,然后筛选出频繁 1-项集;接着基于频繁 1-项集生成候选 2-项集,再筛选出频繁 2-项集;以此类推,直到无法生成新的频繁项集为止。
  2. 从频繁项集中生成强关联规则:对于每一个频繁项集 L,如果它包含至少两个项,就尝试从中生成所有可能的非空子集 Xsupport(L) / support(X) >= min_confidence,则 X -> (L - X) 就是一条强关联规则。

准备工作:数据模型

我们需要定义一些基本的数据结构来表示我们的数据。

Item.java - 项

import java.util.Objects;
public class Item implements Comparable<Item> {
    private final String name;
    public Item(String name) {
        this.name = name;
    }
    public String getName() {
        return name;
    }
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Item item = (Item) o;
        return Objects.equals(name, item.name);
    }
    @Override
    public int hashCode() {
        return Objects.hash(name);
    }
    @Override
    public String toString() {
        return name;
    }
    @Override
    public int compareTo(Item other) {
        return this.name.compareTo(other.name);
    }
}

Transaction.java - 事务

Java如何实现Apriori算法?-图2
(图片来源网络,侵删)
import java.util.HashSet;
import java.util.Set;
public class Transaction {
    private final Set<Item> items;
    public Transaction(Set<Item> items) {
        this.items = new HashSet<>(items); // 使用 HashSet 保证唯一性
    }
    public Set<Item> getItems() {
        return items;
    }
    @Override
    public String toString() {
        return items.toString();
    }
}

FrequentItemSet.java - 频繁项集 这个类用来存储一个项集及其支持度计数。

import java.util.Set;
public class FrequentItemSet {
    private final Set<Item> itemSet;
    private final int supportCount;
    public FrequentItemSet(Set<Item> itemSet, int supportCount) {
        this.itemSet = itemSet;
        this.supportCount = supportCount;
    }
    public Set<Item> getItemSet() {
        return itemSet;
    }
    public int getSupportCount() {
        return supportCount;
    }
    @Override
    public String toString() {
        return itemSet + " (Support: " + supportCount + ")";
    }
}

核心算法实现

这是整个实现的核心,包含了 Apriori 算法的逻辑。

AprioriAlgorithm.java

import java.util.*;
import java.util.stream.Collectors;
public class AprioriAlgorithm {
    // 最小支持度阈值
    private final double minSupport;
    // 最小置信度阈值
    private final double minConfidence;
    public AprioriAlgorithm(double minSupport, double minConfidence) {
        this.minSupport = minSupport;
        this.minConfidence = minConfidence;
    }
    /**
     * 主执行函数,启动 Apriori 算法
     * @param transactions 所有事务的列表
     * @return 生成的强关联规则列表
     */
    public List<String> run(List<Transaction> transactions) {
        // 1. 找出所有频繁项集
        List<FrequentItemSet> allFrequentItemSets = findFrequentItemSets(transactions);
        System.out.println("\n=== 所有频繁项集 ===");
        allFrequentItemSets.forEach(System.out::println);
        // 2. 从频繁项集中生成强关联规则
        List<String> strongRules = generateStrongRules(allFrequentItemSets, transactions.size());
        return strongRules;
    }
    /**
     * 步骤1:找出所有频繁项集
     * @param transactions 事务列表
     * @return 所有频繁项集的列表
     */
    private List<FrequentItemSet> findFrequentItemSets(List<Transaction> transactions) {
        // L_k: 频繁 k-项集的列表
        List<FrequentItemSet> L = new ArrayList<>();
        // C_k: 候选 k-项集的列表
        List<Set<Item>> C = new ArrayList<>();
        // --- 第一次迭代:生成频繁 1-项集 ---
        // C1: 生成候选 1-项集
        C = generateC1(transactions);
        // L1: 从 C1 中筛选出频繁 1-项集
        L = findFrequentItemSetsFromC(C, transactions);
        // --- 迭代生成更高阶的频繁项集 ---
        int k = 2;
        // 只要能生成新的频繁项集,就继续迭代
        while (!L.isEmpty()) {
            System.out.println("\n--- Iteration k=" + k + " ---");
            // C_k: 生成候选 k-项集
            C = generateCk(L, k);
            System.out.println("候选 " + k + "-项集 C" + k + ": " + C);
            // L_k: 从 C_k 中筛选出频繁 k-项集
            List<FrequentItemSet> newL = findFrequentItemSetsFromC(C, transactions);
            L.addAll(newL); // 将新发现的频繁项集加入总列表
            System.out.println("频繁 " + k + "-项集 L" + k + ": " + newL);
            k++;
        }
        return L;
    }
    /**
     * 生成候选 1-项集 C1
     */
    private List<Set<Item>> generateC1(List<Transaction> transactions) {
        Set<Item> allItems = new HashSet<>();
        for (Transaction t : transactions) {
            allItems.addAll(t.getItems());
        }
        return allItems.stream().map(item -> {
            Set<Item> set = new HashSet<>();
            set.add(item);
            return set;
        }).collect(Collectors.toList());
    }
    /**
     * 根据前一个频繁项集 L_{k-1} 生成候选 k-项集 C_k
     * 使用 Apriori 性质进行剪枝
     */
    private List<Set<Item>> generateCk(List<FrequentItemSet> L_prev, int k) {
        List<Set<Item>> C = new ArrayList<>();
        List<Set<Item>> prevItemSets = L_prev.stream()
                .map(FrequentItemSet::getItemSet)
                .collect(Collectors.toList());
        // 连接步:将前一个频繁项集中的项集两两连接
        for (int i = 0; i < prevItemSets.size(); i++) {
            for (int j = i + 1; j < prevItemSets.size(); j++) {
                Set<Item> set1 = new HashSet<>(prevItemSets.get(i));
                Set<Item> set2 = new HashSet<>(prevItemSets.get(j));
                // 确保前 k-2 个项是相同的,用于连接
                Set<Item> union = new HashSet<>(set1);
                union.addAll(set2);
                if (union.size() == k) {
                    // 剪枝步:检查新项集的所有 (k-1)-子集是否都在 L_{k-1} 中
                    if (hasAllSubsetsInLprev(union, prevItemSets, k - 1)) {
                        C.add(union);
                    }
                }
            }
        }
        return C;
    }
    /**
     * 剪枝步的辅助函数:检查一个项集的所有 (k-1)-子集是否都是频繁的
     */
    private boolean hasAllSubsetsInLprev(Set<Item> itemSet, List<Set<Item>> lPrev, int subsetSize) {
        List<Item> items = new ArrayList<>(itemSet);
        // 生成所有大小为 subsetSize 的子集
        for (int i = 0; i < items.size(); i++) {
            Set<Item> subset = new HashSet<>();
            for (int j = 0; j < subsetSize; j++) {
                subset.add(items.get((i + j) % items.size()));
            }
            if (!lPrev.contains(subset)) {
                return false;
            }
        }
        return true;
    }
    /**
     * 从候选集 C 中计算支持度,并筛选出频繁项集 L
     */
    private List<FrequentItemSet> findFrequentItemSetsFromC(List<Set<Item>> C, int totalTransactions) {
        List<FrequentItemSet> L = new ArrayList<>();
        int minSupportCount = (int) Math.ceil(minSupport * totalTransactions);
        for (Set<Item> itemSet : C) {
            int supportCount = 0;
            for (Transaction t : transactions) {
                if (t.getItems().containsAll(itemSet)) {
                    supportCount++;
                }
            }
            if (supportCount >= minSupportCount) {
                L.add(new FrequentItemSet(itemSet, supportCount));
            }
        }
        return L;
    }
    /**
     * 步骤2:从频繁项集中生成强关联规则
     */
    private List<String> generateStrongRules(List<FrequentItemSet> allFrequentItemSets, int totalTransactions) {
        List<String> strongRules = new ArrayList<>();
        int minSupportCount = (int) Math.ceil(minSupport * totalTransactions);
        for (FrequentItemSet fis : allFrequentItemSets) {
            Set<Item> itemSet = fis.getItemSet();
            // 只处理长度大于等于2的项集
            if (itemSet.size() < 2) {
                continue;
            }
            // 生成该频繁项集的所有非空真子集
            List<Set<Item>> allSubsets = generateAllNonEmptySubsets(itemSet);
            for (Set<Item> antecedent : allSubsets) {
                Set<Item> consequent = new HashSet<>(itemSet);
                consequent.removeAll(antecedent);
                // 计算置信度: confidence = P(Y|X) = support(X ∪ Y) / support(X)
                double supportXY = fis.getSupportCount();
                // 找到前件的支持度
                double supportX = findSupport(antecedent, allFrequentItemSets, totalTransactions);
                double confidence = supportXY / supportX;
                if (confidence >= minConfidence) {
                    String rule = String.format("%s -> %s (Support: %.2f, Confidence: %.2f)",
                            antecedent, consequent,
                            (double) fis.getSupportCount() / totalTransactions,
                            confidence);
                    strongRules.add(rule);
                }
            }
        }
        return strongRules;
    }
    /**
     * 生成一个集合的所有非空真子集
     */
    private List<Set<Item>> generateAllNonEmptySubsets(Set<Item> set) {
        List<Set<Item>> subsets = new ArrayList<>();
        List<Item> items = new ArrayList<>(set);
        int n = items.size();
        // 从 1 到 2^n - 1,每个二进制位代表一个元素是否在子集中
        for (int i = 1; i < (1 << n); i++) {
            Set<Item> subset = new HashSet<>();
            for (int j = 0; j < n; j++) {
                // 检查第 j 位是否为 1
                if ((i & (1 << j)) != 0) {
                    subset.add(items.get(j));
                }
            }
            subsets.add(subset);
        }
        return subsets;
    }
    /**
     * 从频繁项集列表中查找某个项集的支持度
     */
    private int findSupport(Set<Item> itemSet, List<FrequentItemSet> allFrequentItemSets, int totalTransactions) {
        int minSupportCount = (int) Math.ceil(minSupport * totalTransactions);
        for (FrequentItemSet fis : allFrequentItemSets) {
            if (fis.getItemSet().equals(itemSet)) {
                return fis.getSupportCount();
            }
        }
        // 如果没找到,说明它不是频繁的,返回0
        return 0;
    }
}

测试和示例

我们创建一个主类来测试我们的 Apriori 算法实现。

Java如何实现Apriori算法?-图3
(图片来源网络,侵删)

Main.java

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public class Main {
    public static void main(String[] args) {
        // 1. 准备测试数据
        List<Transaction> transactions = Arrays.asList(
                createTransaction("牛奶", "面包", "黄油"),
                createTransaction("啤酒", "面包"),
                createTransaction("牛奶", "尿布", "面包", "鸡蛋"),
                createTransaction("面包", "黄油", "牛奶"),
                createTransaction("啤酒", "尿布"),
                createTransaction("牛奶", "尿布", "面包", "黄油"),
                createTransaction("面包", "鸡蛋", "牛奶"),
                createTransaction("啤酒", "面包"),
                createTransaction("牛奶", "尿布", "面包", "黄油"),
                createTransaction("面包", "牛奶")
        );
        // 2. 设置最小支持度和最小置信度
        // 假设有10个事务,最小支持度为0.4,意味着至少要有4个事务包含该规则
        // 最小置信度为0.6,意味着规则的正确性要有60%的把握
        double minSupport = 0.4;
        double minConfidence = 0.6;
        // 3. 创建 Apriori 算法实例并运行
        AprioriAlgorithm apriori = new AprioriAlgorithm(minSupport, minConfidence);
        List<String> strongRules = apriori.run(transactions);
        // 4. 输出结果
        System.out.println("\n=== 最终生成的强关联规则 ===");
        if (strongRules.isEmpty()) {
            System.out.println("没有找到满足最小支持度和最小置信度的强关联规则。");
        } else {
            for (String rule : strongRules) {
                System.out.println(rule);
            }
        }
    }
    /**
     * 辅助函数:根据商品名称创建一个事务
     */
    private static Transaction createTransaction(String... itemNames) {
        Set<Item> items = Arrays.stream(itemNames)
                .map(Item::new)
                .collect(Collectors.toSet());
        return new Transaction(items);
    }
}

如何运行和解读结果

  1. 将上述所有 Java 文件(Item.java, Transaction.java, FrequentItemSet.java, AprioriAlgorithm.java, Main.java)保存在同一个目录下。
  2. 使用 javac 编译所有文件:javac *.java
  3. 运行主程序:java Main

可能的输出:

--- Iteration k=2 ---
候选 2-项集 C2: [[面包, 牛奶], [面包, 黄油], [面包, 尿布], [面包, 鸡蛋], [牛奶, 黄油], [牛奶, 尿布], [牛奶, 鸡蛋], [黄油, 尿布], [啤酒, 面包], [啤酒, 尿布]]
频繁 2-项集 L2: [[面包, 牛奶] (Support: 6), [面包, 黄油] (Support: 4), [面包, 尿布] (Support: 4), [牛奶, 黄油] (Support: 4), [牛奶, 尿布] (Support: 4), [啤酒, 面包] (Support: 3)]
--- Iteration k=3 ---
候选 3-项集 C3: [[面包, 牛奶, 黄油], [面包, 牛奶, 尿布], [面包, 牛奶, 鸡蛋], [面包, 黄油, 尿布], [牛奶, 黄油, 尿布]]
频繁 3-项集 L3: [[面包, 牛奶, 黄油] (Support: 4), [面包, 牛奶, 尿布] (Support: 4)]
--- Iteration k=4 ---
候选 4-项集 C4: [[面包, 牛奶, 黄油, 尿布]]
频繁 4-项集 L4: [[面包, 牛奶, 黄油, 尿布] (Support: 4)]
--- Iteration k=5 ---
候选 5-项集 C5: []
频繁 5-项集 L5: []
=== 所有频繁项集 ===
[面包] (Support: 7)
[牛奶] (Support: 7)
[黄油] (Support: 4)
[尿布] (Support: 4)
[鸡蛋] (Support: 3)
[啤酒] (Support: 3)
[面包, 牛奶] (Support: 6)
[面包, 黄油] (Support: 4)
[面包, 尿布] (Support: 4)
[牛奶, 黄油] (Support: 4)
[牛奶, 尿布] (Support: 4)
[啤酒, 面包] (Support: 3)
[面包, 牛奶, 黄油] (Support: 4)
[面包, 牛奶, 尿布] (Support: 4)
[面包, 牛奶, 黄油, 尿布] (Support: 4)
=== 最终生成的强关联规则 ===
[面包] -> [牛奶] (Support: 0.60, Confidence: 0.86)
[牛奶] -> [面包] (Support: 0.60, Confidence: 0.86)
[黄油] -> [面包, 牛奶] (Support: 0.40, Confidence: 1.00)
[面包, 牛奶] -> [黄油] (Support: 0.40, Confidence: 0.67)
[尿布] -> [面包, 牛奶] (Support: 0.40, Confidence: 1.00)
[面包, 牛奶] -> [尿布] (Support: 0.40, Confidence: 0.67)
[牛奶, 黄油] -> [面包] (Support: 0.40, Confidence: 1.00)
[面包, 黄油] -> [牛奶] (Support: 0.40, Confidence: 1.00)
[牛奶, 尿布] -> [面包] (Support: 0.40, Confidence: 1.00)
[面包, 尿布] -> [牛奶] (Support: 0.40, Confidence: 1.00)
[面包, 牛奶, 黄油] -> [尿布] (Support: 0.40, Confidence: 1.00)
[尿布] -> [面包, 牛奶, 黄油] (Support: 0.40, Confidence: 1.00)
[面包, 牛奶, 尿布] -> [黄油] (Support: 0.40, Confidence: 1.00)
[黄油] -> [面包, 牛奶, 尿布] (Support: 0.40, Confidence: 1.00)
[面包, 牛奶, 黄油, 尿布] -> [] (Support: 0.40, Confidence: 1.00)
[] -> [面包, 牛奶, 黄油, 尿布] (Support: 0.40, Confidence: 1.00)

算法优化和注意事项

  1. 性能瓶颈:Apriori 算法的主要性能瓶颈在于需要多次扫描整个事务数据库来计算候选集的支持度,当数据集很大时,这会非常耗时。
  2. FP-Growth 算法:为了解决 Apriori 的性能问题,后续出现了 FP-Growth (Frequent Pattern Growth) 算法,它只需要扫描数据库两次,并且不使用候选集,而是通过构建一个叫 FP-Tree 的紧凑数据结构来直接挖掘频繁项集,效率更高。
  3. 剪枝优化generateCk 方法中的剪枝步骤至关重要,它能极大地减少需要计算支持度的候选集数量,是 Apriori 算法的核心。
  4. 数据结构:在实际应用中,为了提高查找和比较效率,可以使用 HashSetHashMap 来存储项集,并用位图等更高效的结构来表示项集。
  5. 并行化:计算候选集支持度的过程可以并行化,因为每个事务的计算是独立的。

这个 Java 实现完整地展示了 Apriori 算法的逻辑,从数据准备到核心算法实现,再到最终规则的生成,是一个很好的学习和实践范例。

分享:
扫描分享到社交APP
上一篇
下一篇