Skip to content

Commit f9d6258

Browse files
authored
Merge pull request #118 from mschauer/cpdag
Performance improvements
2 parents 924838b + eefc06d commit f9d6258

File tree

5 files changed

+95
-35
lines changed

5 files changed

+95
-35
lines changed

src/ges.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ function tails_and_adj_neighbors(g, x, y)
206206
Nb[.~ a], Nb[a]
207207
end
208208
function adj_neighbors(g, x, y)
209-
intersect(inneighbors(g,y), outneighbors(g,y), all_neighbors(g,x))
209+
# a = intersect(inneighbors(g,y), outneighbors(g,y), all_neighbors(g,x))
210+
sorted_intersect_(neighbors_undirected(g,y), all_neighbors(g,x))
210211
end
211212

212213

src/meek.jl

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,35 +28,38 @@ function meek_rules!(g; rule4=false)
2828
end
2929

3030
"""
31-
meek_rule1(dg, v, w)
31+
meek_rule1(g, v, w)
3232
3333
Rule 1: Orient v-w into v->w whenever there is u->v
3434
such that u and w are not adjacent
3535
(otherwise a new v-structure is created.)
3636
"""
37-
function meek_rule1(dg, v, w)
38-
for u in inneighbors(dg, v)
39-
has_edge(dg, v => u) && continue # not directed
40-
isadjacent(dg, u, w) && continue
37+
function meek_rule1(g, v, w)
38+
for u in inneighbors(g, v)
39+
u == w && continue
40+
has_edge(g, v => u) && continue # not directed
41+
isadjacent(g, u, w) && continue
4142
return true
4243
end
4344
return false
4445
end
4546

4647
"""
47-
meek_rule2(dg, v, w)
48+
meek_rule2(g, v, w)
4849
4950
Rule 2: Orient v-w into v->w whenever there is a chain v->k->w
5051
(otherwise a directed cycle is created.)
5152
"""
52-
function meek_rule2(dg, v, w)
53+
function meek_rule2(g, v, w)
5354
outs = Int[]
54-
for k in outneighbors(dg, v)
55-
!has_edge(dg, k => v) && push!(outs, k)
55+
for k in outneighbors(g, v)
56+
k == w && continue
57+
!has_edge(g, k => v) && push!(outs, k)
5658
end
5759
ins = Int[]
58-
for k in inneighbors(dg, w)
59-
!has_edge(dg, w => k) && push!(ins, k)
60+
for k in inneighbors(g, w)
61+
k == v && continue
62+
!has_edge(g, w => k) && push!(ins, k)
6063
end
6164
if !disjoint_sorted(ins, outs)
6265
return true
@@ -65,42 +68,42 @@ function meek_rule2(dg, v, w)
6568
end
6669

6770
"""
68-
meek_rule3(dg, v, w)
71+
meek_rule3(g, v, w)
6972
7073
Rule 3 (Diagonal): Orient v-w into v->w whenever there are two chains
7174
v-k->w and v-l->w such that k and l are nonadjacent
7275
(otherwise a new v-structure or a directed cycle is created.)
7376
"""
74-
function meek_rule3(dg, v, w)
75-
fulls = [] # Find nodes k where v-k
76-
for k in outneighbors(dg, v)
77-
has_edge(dg, k => v) || continue
77+
function meek_rule3(g, v, w)
78+
fulls = Int[] # Find nodes k where v-k
79+
for k in outneighbors(g, v)
80+
has_edge(g, k => v) || continue
7881
# Skip if not k->w (or if not l->w)
79-
if has_edge(dg, w => k) || !has_edge(dg, k => w)
82+
if has_edge(g, w => k) || !has_edge(g, k => w)
8083
continue
8184
end
8285
push!(fulls, k)
8386
end
8487
for (k, l) in combinations(fulls, 2) # FIXME:
85-
isadjacent(dg, k, l) && continue
88+
isadjacent(g, k, l) && continue
8689
return true
8790
end
8891
return false
8992
end
9093

9194
"""
92-
meek_rule4(dg, v, w)
95+
meek_rule4(g, v, w)
9396
9497
Rule 4: Orient v-w into v→w if v-k→l→w where adj(v,l) and not adj(k,w) [check].
9598
"""
96-
function meek_rule4(dg, v, w)
97-
for l in inneighbors(dg, w)
98-
has_edge(dg, w => l) && continue # undirected
99-
!isadjacent(dg, v, l) && continue # not adjacent to v
100-
for k in inneighbors(dg, l)
101-
has_edge(dg, l => k) && continue # undirected
102-
!has_both(dg, v, k) && continue # not undirected to v
103-
isadjacent(dg, k, w) && continue # adjacent to w
99+
function meek_rule4(g, v, w)
100+
for l in inneighbors(g, w)
101+
has_edge(g, w => l) && continue # undirected
102+
!isadjacent(g, v, l) && continue # not adjacent to v
103+
for k in inneighbors(g, l)
104+
has_edge(g, l => k) && continue # undirected
105+
!has_both(g, v, k) && continue # not undirected to v
106+
isadjacent(g, k, w) && continue # adjacent to w
104107
return true
105108
end
106109
end

src/misc.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ using Random
44
55
Create `DiGraph` from edge-list.
66
"""
7-
function digraph(E)
8-
d = maximum(flatten(E))
7+
function digraph(E, d = maximum(flatten(E)))
98
g = DiGraph(d)
109
for (i, j) in E
1110
add_edge!(g, i, j)
@@ -28,11 +27,12 @@ function graph(E)
2827
end
2928

3029
"""
31-
vpairs(g)
30+
arrows(g)
3231
3332
Return the edge-list as `Pair`s.
3433
"""
35-
vpairs(g) = map(Pair, collect(edges(g)))
34+
arrows(g::SimpleDiGraph{T}) where T = nv(g) > 0 ? map(Pair, edges(g)) : Pair{T,T}[]
35+
const vpairs = arrows
3636

3737
"""
3838
skel_oracle(g; stable=true)

src/pdag.jl

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,71 @@ orientedge!(g, x, y) = @assert has_edge(g, x, y) && rem_edge!(g, y, x)
122122
All vertices in `g` connected to `x` by an undirected edge.
123123
Returns sorted array.
124124
"""
125-
neighbors_undirected(g, x) = inneighbors(g, x) outneighbors(g, x)
125+
function neighbors_undirected(g, x)
126+
a = outneighbors(g, x)
127+
b = inneighbors(g, x)
128+
z = sorted_intersect_(a, b)
129+
# @assert z == a ∩ b
130+
z
131+
end
126132

127133
"""
128134
neighbors_adjacent(g, x)
129135
130136
All vertices in `g` connected to `x` by an any edge.
131137
Returns sorted array.
132138
"""
133-
neighbors_adjacent(g, x) = sort(outneighbors(g, x) inneighbors(g, x))
139+
function neighbors_adjacent(g::SimpleDiGraph, x)
140+
a = outneighbors(g, x)
141+
b = inneighbors(g, x)
142+
z = sorted_union_(a, b)
143+
#@assert z == sort(a ∪ b)
144+
z
145+
end
146+
function sorted_union_(x::Vector{T}, y::Vector{T}) where T
147+
z = T[]
148+
i = j = 1
149+
while true
150+
if j > length(y)
151+
append!(z, @view x[i:end])
152+
return z
153+
end
154+
if i > length(x)
155+
append!(z, @view y[j:end])
156+
return z
157+
end
158+
if x[i] < y[j]
159+
push!(z, x[i])
160+
i += 1
161+
elseif x[i] == y[j]
162+
push!(z, x[i])
163+
i += 1
164+
j += 1
165+
else #x[i] < y[]
166+
push!(z, y[j])
167+
j += 1
168+
end
169+
end
170+
end
171+
function sorted_intersect_(x::Vector{T}, y::Vector{T}) where T
172+
z = T[]
173+
i = j = 1
174+
while true
175+
if j > length(y) || i > length(x)
176+
return z
177+
end
178+
if x[i] < y[j]
179+
i += 1
180+
elseif x[i] == y[j]
181+
push!(z, x[i])
182+
i += 1
183+
j += 1
184+
else #x[i] < y[]
185+
j += 1
186+
end
187+
end
188+
end
189+
134190

135191
"""
136192
parents(g, x)

test/cpdag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ C3 = copy(E3)
2323
append!(C3, reverse.(E3))
2424

2525
h, S = skel_oracle(digraph(E1a))
26-
@test vpairs(h) == E1a
26+
@test map(Pair, edges(h)) == E1a
2727
@test length(S) == 2
2828
@test S[Edge(2, 3)] == [1]
2929
@test S[Edge(1, 4)] == [2, 3]

0 commit comments

Comments
 (0)