Skip to content

Commit 75705e2

Browse files
committed Feb 7, 2014
MISC: remove reference to deprecated Ward
1 parent de38475 commit 75705e2

8 files changed

+23
-20
lines changed
 

‎benchmarks/bench_plot_ward.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from scipy.cluster import hierarchy
99
import pylab as pl
1010

11-
from sklearn.cluster import Ward
11+
from sklearn.cluster import AgglomerativeClustering
1212

13-
ward = Ward(n_clusters=3)
13+
ward = AgglomerativeClustering(n_clusters=3, linkage='ward')
1414

1515
n_samples = np.logspace(.5, 3, 9)
1616
n_features = np.logspace(1, 3.5, 7)

‎doc/modules/clustering.rst

+6-5
Original file line numberDiff line numberDiff line change
@@ -589,11 +589,12 @@ enable only merging of neighboring pixels on an image, as in the
589589

590590
.. warning:: **Connectivity constraints with average and complete linkage**
591591

592-
Connectivity constraints and complete or average linkage enhance the
593-
'rich getting richer' aspect of agglomerative clustering. In the
594-
limit of a small number of clusters, they tend to give a few
595-
macroscopically occupied clusters and almost empty ones. (see the
596-
discussion in
592+
Connectivity constraints and complete or average linkage can enhance
593+
the 'rich getting richer' aspect of agglomerative clustering,
594+
particularly so if they are built with
595+
:func:`sklearn.neighbors.kneighbors_graph`. In the limit of a small
596+
number of clusters, they tend to give a few macroscopically occupied
597+
clusters and almost empty ones. (see the discussion in
597598
:ref:`example_cluster_plot_agglomerative_clustering.py`).
598599

599600
.. image:: ../auto_examples/cluster/images/plot_agglomerative_clustering_1.png

‎doc/tutorial/statistical_inference/unsupervised_learning.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,10 @@ transposed data.
213213
>>> X = np.reshape(images, (len(images), -1))
214214
>>> connectivity = grid_to_graph(*images[0].shape)
215215

216-
>>> agglo = cluster.WardAgglomeration(connectivity=connectivity,
216+
>>> agglo = cluster.FeatureAgglomeration(connectivity=connectivity,
217217
... n_clusters=32)
218218
>>> agglo.fit(X) # doctest: +ELLIPSIS
219-
WardAgglomeration(compute_full_tree='auto',...
219+
FeatureAgglomeration(compute_full_tree='auto',...
220220
>>> X_reduced = agglo.transform(X)
221221

222222
>>> X_approx = agglo.inverse_transform(X_reduced)

‎examples/cluster/plot_cluster_comparison.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
# create clustering estimators
7373
ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
7474
two_means = cluster.MiniBatchKMeans(n_clusters=2)
75-
ward_five = cluster.AgglomerativeClustering(n_clusters=2,
75+
ward = cluster.AgglomerativeClustering(n_clusters=2,
7676
linkage='ward', connectivity=connectivity)
7777
spectral = cluster.SpectralClustering(n_clusters=2,
7878
eigen_solver='arpack',
@@ -89,7 +89,7 @@
8989
('AffinityPropagation', affinity_propagation),
9090
('MeanShift', ms),
9191
('SpectralClustering', spectral),
92-
('Ward', ward_five),
92+
('Ward', ward),
9393
('AgglomerativeClustering', average_linkage),
9494
('DBSCAN', dbscan)]:
9595
# predict cluster memberships

‎examples/cluster/plot_digits_agglomeration.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
X = np.reshape(images, (len(images), -1))
2727
connectivity = grid_to_graph(*images[0].shape)
2828

29-
agglo = cluster.WardAgglomeration(connectivity=connectivity,
30-
n_clusters=32)
29+
agglo = cluster.FeatureAgglomeration(connectivity=connectivity,
30+
n_clusters=32)
3131

3232
agglo.fit(X)
3333
X_reduced = agglo.transform(X)

‎examples/cluster/plot_feature_agglomeration_vs_univariate_selection.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from sklearn.feature_extraction.image import grid_to_graph
2929
from sklearn import feature_selection
30-
from sklearn.cluster import WardAgglomeration
30+
from sklearn.cluster import FeatureAgglomeration
3131
from sklearn.linear_model import BayesianRidge
3232
from sklearn.pipeline import Pipeline
3333
from sklearn.grid_search import GridSearchCV
@@ -66,9 +66,9 @@
6666
mem = Memory(cachedir=cachedir, verbose=1)
6767

6868
# Ward agglomeration followed by BayesianRidge
69-
A = grid_to_graph(n_x=size, n_y=size)
70-
ward = WardAgglomeration(n_clusters=10, connectivity=A, memory=mem,
71-
n_components=1)
69+
connectivity = grid_to_graph(n_x=size, n_y=size)
70+
ward = FeatureAgglomeration(n_clusters=10, connectivity=connectivity,
71+
memory=mem, n_components=1)
7272
clf = Pipeline([('ward', ward), ('ridge', ridge)])
7373
# Select the optimal number of parcels with grid search
7474
clf = GridSearchCV(clf, {'ward__n_clusters': [10, 20, 30]}, n_jobs=1, cv=cv)

‎examples/cluster/plot_ward_structured_vs_unstructured.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import numpy as np
3131
import pylab as pl
3232
import mpl_toolkits.mplot3d.axes3d as p3
33-
from sklearn.cluster import Ward
33+
from sklearn.cluster import AgglomerativeClustering
3434
from sklearn.datasets.samples_generator import make_swiss_roll
3535

3636
###############################################################################
@@ -45,7 +45,7 @@
4545
# Compute clustering
4646
print("Compute unstructured hierarchical clustering...")
4747
st = time.time()
48-
ward = Ward(n_clusters=6).fit(X)
48+
ward = AgglomerativeClustering(n_clusters=6, linkage='ward').fit(X)
4949
elapsed_time = time.time() - st
5050
label = ward.labels_
5151
print("Elapsed time: %.2fs" % elapsed_time)
@@ -71,7 +71,8 @@
7171
# Compute clustering
7272
print("Compute structured hierarchical clustering...")
7373
st = time.time()
74-
ward = Ward(n_clusters=6, connectivity=connectivity).fit(X)
74+
ward = AgglomerativeClustering(n_clusters=6, connectivity=connectivity,
75+
linkage='ward').fit(X)
7576
elapsed_time = time.time() - st
7677
label = ward.labels_
7778
print("Elapsed time: %.2fs" % elapsed_time)

‎sklearn/cluster/tests/test_hierarchical.py

+1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def test_ward_agglomeration():
204204
mask = np.ones([10, 10], dtype=np.bool)
205205
X = rnd.randn(50, 100)
206206
connectivity = grid_to_graph(*mask.shape)
207+
assert_warns(DeprecationWarning, WardAgglomeration)
207208
ward = WardAgglomeration(n_clusters=5, connectivity=connectivity)
208209
ward.fit(X)
209210
assert_true(np.size(np.unique(ward.labels_)) == 5)

0 commit comments

Comments
 (0)
Please sign in to comment.