from dataclasses import dataclass, field
from typing import Optional, Iterable, Iterator, Self
from kappybara.pattern import Site, Agent, Component, Pattern, Embedding
from kappybara.utils import SetProperty, Property, IndexedSet
@dataclass(frozen=True)
class Edge:
    """Represents a bond between two sites.

    Note:
        Edges are unordered: ``Edge(x, y) == Edge(y, x)`` and both hash to the
        same value, so a set of edges holds each bond at most once.

    Attributes:
        site1: First site in the bond.
        site2: Second site in the bond.
    """

    site1: Site
    site2: Site

    def __eq__(self, other):
        # Defer to Python's fallback for foreign types instead of crashing on a
        # missing `site1` attribute (e.g. `edge == None`).
        if not isinstance(other, Edge):
            return NotImplemented
        return (self.site1 == other.site1 and self.site2 == other.site2) or (
            self.site1 == other.site2 and self.site2 == other.site1
        )

    def __hash__(self):
        # A frozenset ignores order, which keeps the hash consistent with the
        # symmetric __eq__ above.
        return hash(frozenset((self.site1, self.site2)))
@dataclass
class Mixture:
    """A collection of agents and their connections.

    Attributes:
        agents: Indexed set of all agents in the mixture (with a "type" index).
        _embeddings: Cache of embeddings for tracked components.
        _max_embedding_width: Maximum diameter of tracked components; used as the
            search radius around changed agents in `apply_update`.
    """

    agents: IndexedSet[Agent]
    _embeddings: dict[Component, IndexedSet[Embedding]]
    _max_embedding_width: int

    @classmethod
    def from_kappa(cls, patterns: dict[str, int]) -> Self:
        """Create a mixture from Kappa pattern strings and counts.

        Args:
            patterns: Dictionary mapping pattern strings to copy counts.

        Returns:
            New Mixture with instantiated patterns.
        """
        real_patterns = []
        for pattern, count in patterns.items():
            # Parse each string once and repeat the same Pattern object; this is
            # safe because `add` builds fresh (detached) agents for every copy.
            real_patterns.extend([Pattern.from_kappa(pattern)] * count)
        return cls(real_patterns)

    def __init__(self, patterns: Optional[Iterable[Pattern]] = None):
        """Initialize a new mixture.

        Args:
            patterns: Optional collection of patterns to instantiate.
        """
        self.agents = IndexedSet()
        self._embeddings = {}
        self._max_embedding_width = 0
        # Register a "type" index over agents, keyed by `agent.type`.
        self.agents.create_index("type", Property(lambda a: a.type))
        if patterns is not None:
            for pattern in patterns:
                self.instantiate(pattern)

    def __iter__(self) -> Iterator[Component]:
        """Iterate over the connected components of the mixture.

        Note:
            Components are computed by a ComponentMixture built from copies of
            this mixture's agents, so the yielded components do not share agent
            objects with `self.agents`.
        """
        yield from ComponentMixture([Pattern(list(self.agents))])

    @property
    def kappa_str(self) -> str:
        """The mixture representation in Kappa format.

        Returns:
            Kappa string with %init declarations for each component type.
        """
        return "\n".join(
            f"%init: {len(components)} {group.kappa_str}"
            for group, components in grouped(list(self)).items()
        )

    def instantiate(self, pattern: Pattern | str, n_copies: int = 1) -> None:
        """Add instances of a pattern to the mixture.

        Args:
            pattern: Pattern to instantiate, or Kappa string.
            n_copies: Number of copies to create.

        Raises:
            AssertionError: If pattern is underspecified.
        """
        if isinstance(pattern, str):
            pattern = Pattern.from_kappa(pattern)
        assert (
            not pattern.underspecified
        ), "Pattern isn't specific enough to instantiate."
        for _ in range(n_copies):
            for component in pattern.components:
                self.add(component)

    def add(self, component: Component) -> None:
        """Add a copy of a component to the mixture.

        Fresh detached agents are created for every agent of `component`, and
        the component's bonds are reproduced between the copies.

        Args:
            component: Component to add with its agents and connections.
        """
        component_ordered = list(component.agents)
        new_agents = [agent.detached() for agent in component_ordered]
        # Map each original agent to its position so partner lookups are O(1)
        # instead of a linear `list.index` scan for every bound site.
        position = {agent: i for i, agent in enumerate(component_ordered)}
        new_edges = set()
        for i, agent in enumerate(component_ordered):
            # Duplicate the proper link structure
            for site in agent:
                if site.coupled:
                    partner = site.partner
                    new_site = new_agents[i][site.label]
                    new_partner = new_agents[position[partner.agent]][partner.label]
                    # Each bond is seen from both ends; Edge's symmetric equality
                    # lets the set keep a single copy.
                    new_edges.add(Edge(new_site, new_partner))
        update = MixtureUpdate(agents_to_add=new_agents, edges_to_add=new_edges)
        self.apply_update(update)

    def remove(self, component: Component) -> None:
        """Remove a component from the mixture.

        Args:
            component: Component to remove.
        """
        update = MixtureUpdate()
        for agent in component:
            update.remove_agent(agent)
        self.apply_update(update)

    def embeddings(self, component: Component) -> IndexedSet[Embedding]:
        """Get embeddings of a tracked component.

        Notes:
            Returns the number of matches directly returned
            by subgraph isomorphism, i.e. not accounting for symmetries.

        Args:
            component: Component to get embeddings for.

        Returns:
            Set of embeddings for the component.

        Raises:
            KeyError: If component is not being tracked.
        """
        try:
            return self._embeddings[component]
        except KeyError as e:
            e.add_note(
                f"Undeclared component: {component}. To embed it, first use `track_component`."
            )
            raise

    def track_component(self, component: Component):
        """Start tracking embeddings of a component.

        Args:
            component: Component pattern to track.
        """
        self._max_embedding_width = max(component.diameter, self._max_embedding_width)
        embeddings = IndexedSet(component.embeddings(self))
        # Index embeddings by every agent they map onto, so `apply_update` can
        # drop all embeddings touching a changed agent.
        embeddings.create_index("agent", SetProperty(lambda e: iter(e.values())))
        self._embeddings[component] = embeddings

    def apply_update(self, update: "MixtureUpdate") -> None:
        """Apply a collection of changes to the mixture.

        Args:
            update: MixtureUpdate specifying changes to apply.
        """
        # Drop cached embeddings involving any agent this update modifies or
        # deletes; still-valid ones are re-discovered by the region search below.
        for agent in update.touched_before:
            for tracked in self._embeddings:
                self._embeddings[tracked].remove_by("agent", agent)
        # Bonds go before agents: `_remove_agent` asserts all sites are unbound.
        for edge in update.edges_to_remove:
            self._remove_edge(edge)
        for agent in update.agents_to_remove:
            self._remove_agent(agent)
        # Agents go before bonds: `_add_edge` asserts both agents are present.
        for agent in update.agents_to_add:
            self._add_agent(agent)
        for edge in update.edges_to_add:
            self._add_edge(edge)
        # NOTE: the current implementation doesn't directly mutate agent type
        # Re-search only the neighborhood of touched agents, out to the largest
        # tracked component diameter.
        update_region = neighborhood(update.touched_after, self._max_embedding_width)
        update_region = IndexedSet(update_region)
        update_region.create_index("type", Property(lambda a: a.type))
        for component_pattern in self._embeddings:
            new_embeddings = component_pattern.embeddings(update_region)
            for e in new_embeddings:
                self._embeddings[component_pattern].add(e)

    def _update_embeddings(self) -> None:
        """Recompute the embeddings of every tracked component from scratch."""
        for component_pattern in self._embeddings:
            self.track_component(component_pattern)

    def _add_agent(self, agent: Agent) -> None:
        """Add an agent to the mixture.

        Note:
            Calling these private functions isn't guaranteed to keep indexes
            up to date, which is why they shouldn't be used externally.
            The provided agent should not have any bound sites.

        Args:
            agent: Agent to add (should have empty sites).

        Raises:
            AssertionError: If agent has bound sites or isn't instantiable.
        """
        assert all(site.partner == "." for site in agent)  # Check all sites are unbound
        assert agent.instantiable
        self.agents.add(agent)

    def _remove_agent(self, agent: Agent) -> None:
        """Remove an agent from the mixture.

        Note:
            Any bonds associated with agent must be removed first.

        Args:
            agent: Agent to remove (should have empty sites).

        Raises:
            AssertionError: If agent has bound sites.
        """
        assert all(site.partner == "." for site in agent)  # Check all sites are unbound
        self.agents.remove(agent)

    def _add_edge(self, edge: Edge) -> None:
        """Add a bond between two sites.

        Args:
            edge: Edge specifying the bond to create.

        Raises:
            AssertionError: If either agent is not in the mixture.
        """
        assert edge.site1.agent in self.agents
        assert edge.site2.agent in self.agents
        edge.site1.partner = edge.site2
        edge.site2.partner = edge.site1

    def _remove_edge(self, edge: Edge) -> None:
        """Remove a bond between two sites.

        Args:
            edge: Edge specifying the bond to remove.

        Raises:
            AssertionError: If the edge doesn't exist.
        """
        assert edge.site1.partner == edge.site2
        assert edge.site2.partner == edge.site1
        # "." is the unbound marker checked by `_add_agent`/`_remove_agent`.
        edge.site1.partner = "."
        edge.site2.partner = "."
@dataclass
class ComponentMixture(Mixture):
    """A mixture that explicitly tracks connected components.

    Attributes:
        components: Indexed set of all components in the mixture, with an
            "agent" index (declared ``is_unique=True``) over each component's agents.
    """

    components: IndexedSet[Component]

    def __init__(self, patterns: Optional[Iterable[Pattern]] = None):
        """Initialize a component-tracking mixture.

        Args:
            patterns: Optional collection of patterns to instantiate.
        """
        self.components = IndexedSet()
        self.components.create_index(
            "agent", SetProperty(lambda c: c.agents, is_unique=True)
        )
        # Must run after `components` exists: the base initializer instantiates
        # `patterns`, which reaches the overridden `_add_agent`/`_add_edge` below.
        super().__init__(patterns)

    def __iter__(self) -> Iterator[Component]:
        """Iterate directly over the tracked components."""
        yield from self.components

    def embeddings_in_component(
        self, match_pattern: Component, mixture_component: Component
    ) -> list[dict[Agent, Agent]]:
        """Get embeddings of a tracked pattern within a specific component.

        Args:
            match_pattern: Pattern to find embeddings for.
            mixture_component: Component to search within.

        Returns:
            Embeddings of `match_pattern` located in `mixture_component`.

        Raises:
            KeyError: If `match_pattern` is not being tracked.
        """
        # Go through `embeddings` so an untracked pattern raises the same
        # annotated KeyError as `Mixture.embeddings`.
        return self.embeddings(match_pattern).lookup("component", mixture_component)

    def track_component(self, component: Component):
        """Start tracking embeddings of a component pattern.

        Args:
            component: Component pattern to track.
        """
        super().track_component(component)
        # Key each embedding by the mixture component containing its first image
        # agent; this assumes all agents of an embedding share one component.
        self._embeddings[component].create_index(
            "component",
            Property(lambda e: self.components.lookup("agent", next(iter(e.values())))),
        )

    def apply_update(self, update: "MixtureUpdate") -> None:
        """Apply a collection of changes to the mixture.

        Args:
            update: MixtureUpdate specifying changes to apply.
        """
        super().apply_update(update)

    def _add_agent(self, agent: Agent) -> None:
        """Add an agent as a new single-agent component.

        Args:
            agent: Agent to add.
        """
        super()._add_agent(agent)
        component = Component([agent])
        self.components.add(component)

    def _remove_agent(self, agent: Agent) -> None:
        """Remove an agent and its component.

        Args:
            agent: Agent to remove.

        Raises:
            AssertionError: If agent is part of a multi-agent component.
        """
        super()._remove_agent(agent)
        component = self.components.lookup("agent", agent)
        # Bonds are removed before agents, so a removed agent must be isolated.
        assert len(component) == 1
        self.components.remove(component)

    def _add_edge(self, edge: Edge) -> None:
        """Add an edge, merging the endpoint components if they differ.

        Embeddings of the absorbed component are pulled out of every tracked
        index and re-added after the merge, so the "component" index is
        evaluated again against the merged component.

        Args:
            edge: Edge to add between sites.
        """
        super()._add_edge(edge)
        # If the agents are in different components, merge the components
        # TODO: incremental mincut
        component1 = self.components.lookup("agent", edge.site1.agent)
        component2 = self.components.lookup("agent", edge.site2.agent)
        if component1 == component2:
            return
        # Ensure `component2` is the smaller of the 2
        if len(component2) > len(component1):
            component1, component2 = component2, component1
        relocated: dict[Component, list[Embedding]] = {}
        for tracked in self._embeddings:
            relocated[tracked] = list(
                self._embeddings[tracked].lookup("component", component2)
            )
            for e in relocated[tracked]:
                self._embeddings[tracked].remove(e)
        self.components.remove(component2)  # NOTE: invokes a redundant linear time pass
        for agent in component2:
            component1.add(agent)
            # TODO: better semantics for this type of operation
            # Operate on diffs to set property.. ?
            self.components.indices["agent"][agent] = [component1]
        for tracked in self._embeddings:
            # TODO: refactor when we can register IndexedSet item updates, including
            # cached property evaluations
            for e in relocated[tracked]:
                assert (
                    self.components.lookup("agent", next(iter(e.values())))
                    == component1
                )
                self._embeddings[tracked].add(e)

    def _remove_edge(self, edge: Edge) -> None:
        """Remove an edge, splitting its component if it becomes disconnected.

        Args:
            edge: Edge to remove.
        """
        super()._remove_edge(edge)
        agent1: Agent = edge.site1.agent
        agent2: Agent = edge.site2.agent
        old_component = self.components.lookup("agent", agent1)
        assert old_component == self.components.lookup("agent", agent2)
        # Create a new component if the old one got disconnected
        maybe_new_component = Component(agent1.depth_first_traversal)
        if agent2 in maybe_new_component:
            return  # The old component is still connected, do nothing
        new_component1 = maybe_new_component
        new_component2 = Component(agent2.depth_first_traversal)
        relocated: dict[Component, list[Embedding]] = {}
        for tracked in self._embeddings:
            relocated[tracked] = list(
                self._embeddings[tracked].lookup("component", old_component)
            )
            for e in relocated[tracked]:
                self._embeddings[tracked].remove(e)
        # TODO: need to do manual updates to the indices in `components`
        # to do this more efficiently
        self.components.remove(old_component)
        self.components.add(new_component1)
        self.components.add(new_component2)
        for tracked in self._embeddings:
            # TODO: refactor when we can register IndexedSet item updates, including
            # cached property evaluations
            for e in relocated[tracked]:
                assert self.components.lookup("agent", next(iter(e.values()))) in [
                    new_component1,
                    new_component2,
                ]
                self._embeddings[tracked].add(e)
@dataclass
class MixtureUpdate:
    """Specifies changes to be applied to a mixture.

    Attributes:
        agents_to_add: Agents to be added to the mixture.
        agents_to_remove: Agents to be removed from the mixture.
        edges_to_add: Edges to be created.
        edges_to_remove: Edges to be removed.
        agents_changed: Agents with internal state changes.
    """

    agents_to_add: list[Agent] = field(default_factory=list)
    agents_to_remove: list[Agent] = field(default_factory=list)
    edges_to_add: set[Edge] = field(default_factory=set)
    edges_to_remove: set[Edge] = field(default_factory=set)
    agents_changed: set[Agent] = field(default_factory=set)  # Agents changed internally

    def create_agent(self, agent: Agent) -> Agent:
        """Create a new agent based on a template.

        Note:
            Sites in the created agent will be emptied.

        Args:
            agent: Template agent to base the new agent on.

        Returns:
            New agent with empty sites.
        """
        new_agent = agent.detached()
        self.agents_to_add.append(new_agent)
        return new_agent

    def remove_agent(self, agent: Agent) -> None:
        """Specify to remove an agent and its edges from the mixture.

        Args:
            agent: Agent to remove.
        """
        self.agents_to_remove.append(agent)
        for site in agent:
            if site.coupled:
                self.edges_to_remove.add(Edge(site, site.partner))

    def connect_sites(self, site1: Site, site2: Site) -> None:
        """Specify to create an edge between two sites.

        If the sites are bound to other sites, indicates to remove those edges.

        Args:
            site1: First site to connect.
            site2: Second site to connect.
        """
        if site1.coupled and site1.partner != site2:
            self.disconnect_site(site1)
        if site2.coupled and site2.partner != site1:
            self.disconnect_site(site2)
        # Already bonded to each other: nothing to add.
        if not site1.partner == site2:
            self.edges_to_add.add(Edge(site1, site2))

    def disconnect_site(self, site: Site) -> None:
        """Specify that a site should be unbound.

        Args:
            site: Site to disconnect from its partner.
        """
        if site.coupled:
            self.edges_to_remove.add(Edge(site, site.partner))

    def register_changed_agent(self, agent: Agent) -> None:
        """Register an agent as having internal state changes.

        Args:
            agent: Agent that has been internally modified.
        """
        self.agents_changed.add(agent)

    @property
    def touched_after(self) -> set[Agent]:
        """The agents that will be changed or added after this update.

        Returns:
            Set of agents affected by the update.
        """
        # Build the lookup set once instead of scanning the list for every edge.
        removed = set(self.agents_to_remove)
        touched = self.agents_changed | set(self.agents_to_add)
        for edge in self.edges_to_add:
            touched.add(edge.site1.agent)
            touched.add(edge.site2.agent)
        for edge in self.edges_to_remove:
            # Endpoints of a removed bond are affected unless they are deleted.
            for agent in (edge.site1.agent, edge.site2.agent):
                if agent not in removed:
                    touched.add(agent)
        return touched

    @property
    def touched_before(self) -> set[Agent]:
        """The agents that will be changed or removed by this update.

        Returns:
            Set of agents affected before the update is applied.
        """
        # Build the lookup set once instead of scanning the list for every edge.
        added = set(self.agents_to_add)
        touched = self.agents_changed | set(self.agents_to_remove)
        for edge in self.edges_to_remove:
            touched.add(edge.site1.agent)
            touched.add(edge.site2.agent)
        for edge in self.edges_to_add:
            # Agents created by this update did not exist beforehand.
            for agent in (edge.site1.agent, edge.site2.agent):
                if agent not in added:
                    touched.add(agent)
        return touched
def neighborhood(agents: Iterable["Agent"], radius: int) -> set["Agent"]:
    """Get all agents within graph distance ``radius`` of the given agents.

    Performs a breadth-first expansion along ``agent.neighbors``, one hop per
    step, for at most ``radius`` steps.

    Args:
        agents: Starting agents. Any iterable is accepted; it is consumed once.
        radius: Maximum number of hops to include. Values <= 0 return only the
            starting agents.

    Returns:
        Set of the starting agents plus every agent within ``radius`` hops.
    """
    # Materialize once so one-shot iterators are not exhausted by `set(...)`.
    seen = set(agents)
    frontier = set(seen)  # a copy: `seen` grows while `frontier` is iterated
    for _ in range(radius):
        if not frontier:
            break  # Nothing left to expand.
        new_frontier = set()
        for cur in frontier:
            for n in cur.neighbors:
                # Test membership BEFORE recording the agent; otherwise every
                # neighbor looks already-seen and expansion stops after one hop.
                if n not in seen:
                    seen.add(n)
                    new_frontier.add(n)
        frontier = new_frontier
    return seen
def grouped(components: Iterable[Component]) -> dict[Component, list[Component]]:
    """Partition components into isomorphism classes.

    The first component encountered in each class becomes that class's key;
    later components are compared against the keys in insertion order and
    appended to the first class they are isomorphic to.

    Args:
        components: Components to group.

    Returns:
        Dictionary mapping representative components to lists of isomorphic components.
    """
    classes: dict[Component, list[Component]] = {}
    for candidate in components:
        # First existing representative isomorphic to the candidate, if any.
        representative = next(
            (rep for rep in classes if candidate.isomorphic(rep)), None
        )
        if representative is None:
            classes[candidate] = [candidate]
        else:
            classes[representative].append(candidate)
    return classes