设计模式——抽象工厂

概述

抽象工厂模式(Abstract Factory Pattern)是一种创建型设计模式。他能创建一些列相关的对象,而无需指定其具体类。通俗的来说,抽象工厂模式围绕一个超级工厂创建其他工厂。该超级工厂又称为其他工厂的工厂。

在抽象工厂模式中,接口是负责创建一个相关对象的工厂,不需要显式指定它们的类。每个生成的工厂都能按照工厂模式提供对象。

抽象工厂模式提供了一种创建一系列相关或相互依赖对象的接口,而无需指定具体实现类。通过使用抽象工厂模式,可以将客户端与具体产品的创建过程解耦,使得客户端可以通过工厂接口来创建一族产品。

抽象工厂模式结构

抽象工厂的主要角色包括:

  • 抽象工厂(Abstract Factory):声明了一组用于创建产品对象的方法,每个方法对应一种产品类型。抽象工厂可以是接口或抽象类。
  • 具体工厂(Concrete Factory):实现了抽象工厂接口,负责创建具体产品对象的实例。
  • 抽象产品(Abstract Product):定义了一组产品对象的共同接口或抽象类,描述了产品对象的公共方法。
  • 具体产品(Concrete Product):实现了抽象产品接口,定义了具体产品的特定行为和属性。

抽象工厂的优缺点

优点
  • 确保同一工厂生成的产品相互匹配
  • 避免客户端和具体产品代码的耦合
  • 复合单一职责原则,可以将产品生成的代码抽取到同一位置,使得代码易于维护
  • 复合开闭原则,向应用程序中引入新产品变体时,无需修改客户端代码
缺点
  • 由于该类模式需要向应用中引入大量接口和类,代码可能会比使用之前更加复杂
  • 拓展产品族很困难,增加一个新的产品族需要修改抽象工厂和所有具体工厂的代码

实现方式

  • 从不同的产品类型和产品变体出发,绘制维度矩阵
  • 为所有产品声明抽象产品接口,然后让所有的具体产品实现这些接口
  • 声明抽象工厂接口,并在接口中为所有抽象产品提供一组构建方法
  • 为每种产品实现一个具体工厂类
  • 在应用程序中开发初始化代码。初始化代码根据配置或当前环境,对特定具体工厂类进行初始化,然后将该工厂对象传递给所有需要创建产品的类。
  • 找出代码中所有对产品构造函数的直接调用,将其替换为对工厂对象中相应构建方法的调用

适用场景

  • 如果代码需要与多个不同系列的相关产品交互,但是由于无法提前获取相关信息,或者出于对未来扩展性的考虑, 不希望代码基于产品的具体类进行构建,在这种情况下,可以使用抽象工厂。

    抽象工厂提供了用于创建每个系列产品对象的接口。主要代码通过该接口创造对象,就不会生成与应用程序已生成的产品类型不一致的产品。

  • 如果有一个基于一组抽象方法的类,且其主要功能因此变得不明确,那么在这种情况下可以使用抽象工厂。

    在设计良好的程序中,每个类仅负责一件事情。如果一个类与多种类型产品交互,就可以考虑将工厂方法抽取到独立的工厂类或具备完整功能的抽象工厂中

代码框架

from __future__ import annotations
from abc import ABC, abstractmethod

class AbstractFactory(ABC):
    @abstractmethod
    def create_product_a(self) -> AbstractProductA:
        pass

    @abstractmethod
    def create_product_b(self) -> AbstractProductB:
        pass

class ConcreteFactory1(AbstractFactory):

    def create_product_a(self) -> AbstractProductA:
        return ConcreteProductA1()

    def create_product_b(self) -> AbstractProductB:
        return ConcreteProductB1()

class ConcreteFactory2(AbstractFactory):

    def create_product_a(self) -> AbstractProductA:
        return ConcreteProductA2()

    def create_product_b(self) -> AbstractProductB:
        return ConcreteProductB2()

class AbstractProductA(ABC):

    @abstractmethod
    def useful_function_a(self) -> str:
        pass

class ConcreteProductA1(AbstractProductA):
    def useful_function_a(self) -> str:
        return "The result of the product A1."

class ConcreteProductA2(AbstractProductA):
    def useful_function_a(self) -> str:
        return "The result of the product A2."

class AbstractProductB(ABC):

    @abstractmethod
    def useful_function_b(self) -> None:
        pass

    @abstractmethod
    def another_useful_function_b(self, collaborator: AbstractProductA) -> None:
        pass

class ConcreteProductB1(AbstractProductB):
    def useful_function_b(self) -> str:
        return "The result of the product B1."

    def another_useful_function_b(self, collaborator: AbstractProductA) -> str:
        result = collaborator.useful_function_a()
        return f"The result of the B1 collaborating with the ({result})"

class ConcreteProductB2(AbstractProductB):
    def useful_function_b(self) -> str:
        return "The result of the product B2."

    def another_useful_function_b(self, collaborator: AbstractProductA):

        result = collaborator.useful_function_a()
        return f"The result of the B2 collaborating with the ({result})"

def client_code(factory: AbstractFactory) -> None:

    product_a = factory.create_product_a()
    product_b = factory.create_product_b()

    print(f"{product_b.useful_function_b()}")
    print(f"{product_b.another_useful_function_b(product_a)}", end="")

if __name__ == "__main__":

    print("Client: Testing client code with the first factory type:")
    client_code(ConcreteFactory1())

    print("\n")

    print("Client: Testing the same client code with the second factory type:")
    client_code(ConcreteFactory2())

代码示例

FactoryProducer 类获取 AbstractFactory 对象。它将向 AbstractFactory 传递形状信息 ShapeCIRCLE / RECTANGLE / SQUARE),以便获取它所需对象的类型。同时它还向 AbstractFactory 传递颜色信息 ColorRED / GREEN / BLUE),以便获取它所需对象的类型。

from __future__ import annotations
from abc import ABC, abstractmethod

class AbstractFactory(ABC):
    @abstractmethod
    def get_shape(self, shape) -> Shape:
        pass

    @abstractmethod
    def get_color(self, color) -> Color:
        pass

class ShapeFactory(AbstractFactory):

    def get_shape(self, shape) -> Shape:
        if shape == "Circle":
            return Circle()
        elif shape == "Square":
            return Square()
        elif shape == "Rectangle":
            return Rectangle()
        else:
            raise NotImplementedError(f" The {shape} has not be implemented")

    def get_color(self, color) -> str:
        pass

class ColorFactory(AbstractFactory):

    def get_shape(self, shape) -> str:
        pass

    def get_color(self, color) -> Color:
        if color == "Red":
            return Red()
        elif color == "Green":
            return Green()
        elif color == "Blue":
            return Blue()
        else:
            raise NotImplementedError(f" The {color} has not be implemented")

class Shape(ABC):

    @abstractmethod
    def draw(self) -> str:
        pass

class Circle(Shape):
    def draw(self) -> str:
        return "画了一个圆"

class Square(Shape):
    def draw(self) -> str:
        return "画了一个正方形"

class Rectangle(Shape):
    def draw(self) -> str:
        return "画了一个长方形"

class Color(ABC):

    @abstractmethod
    def fill(self) -> None:
        pass

class Red(Color):
    def fill(self) -> str:
        return "填充了红色"

class Green(Color):
    def fill(self) -> str:
        return "填充了绿色"

class Blue(Color):
    def fill(self) -> str:
        return "填充了蓝色"

class FactoryProducer:
    @classmethod
    def get_factory(self,choice) -> AbstractFactory:
        if choice == "shape":
            return ShapeFactory()
        elif choice == "color":
            return ColorFactory()
        else:
            raise NotImplementedError(f"The {choice} has not be implemented")

if __name__ == "__main__":
    shape_list = ["Circle", "Rectangle", "Square", "Circle", "Rectangle"]
    color_list = ["Red", "Green", "Blue", "Blue", "Green"]

    shape_factory = FactoryProducer.get_factory("shape")
    color_factory = FactoryProducer.get_factory("color")

    for shp, col in zip(shape_list, color_list):
        shape = shape_factory.get_shape(shp)
        print(shape.draw())
        color = color_factory.get_color(col)
        print(color.fill())