[Biojava-dev] WeightedSet

mark.schreiber at group.novartis.com mark.schreiber at group.novartis.com
Mon Jun 14 22:29:00 EDT 2004


Hi all -

Below is a Set and a Test that are inspired by the BioJava Distribution 
objects. The WeightedSet is very much like a Distribution but contains 
Objects which have associated weights. Importantly you can sample these 
objects in the same way you can sample a Distribution. The major 
difference is that the WeightedSet can contain any object and not just 
Symbols.

I have found these useful for doing random (or weighted) sampling of 
objects (which are not symbols) for various applications. It could be 
included in the utils section of biojava if others think it will be useful, 
although it is not specifically for biological applications.

The test class shows the expected usage and behaviour.



!!!-----CODE STARTS HERE-----!!!

import java.util.*;

/**
 * <p>Inspred by the BioJava Distribution objects the WeightedSet is a map 
from
 * a Key to a Weight. Unlike Distributions the Keys do not have to be 
Symbols.
 * </p>
 *
 * <p>When Objects are added or their weights are set then the weights are 
internally
 * normalized to 1.0</p>
 *
 * @author Mark Schreiber
 * @version 1.0
 */

public class WeightedSet extends AbstractSet implements 
java.io.Serializable{
  private HashMap key2Weight;
  double totalWeight;

  public WeightedSet() {
    key2Weight = new HashMap();
  }

  /**
   * Converts the Set to a map from key <code>Objects</code> to 
<code>Double</code>
   * weights.
   * @return a Map with all the key-weight mappings. Weights are not 
normalized in this map.
   */
  public Map asMap(){
    return key2Weight;
  }

  /**
   * Randomly samples an <code>Object</code> from the <code>Set</code> 
according
   * to its weight.
   * @return the Object sampled.
   */
  public Object sample(){
    double p = Math.random();
    for (Iterator i = this.iterator(); i.hasNext(); ) {
      Object o = i.next();
      double weight = getWeight(o);

      p -= weight;
      if(p <= 0.0){
        return o;
      }
    }
    throw new org.biojava.bio.BioError("Cannot sample an object, does this 
set contain any objects?");
  }

  /**
   * Determines the normalized weight for <code>o</code>
   * @param o the <code>Object</code> you want to know the weight of
   * @return the normalized weight
   * @throws NoSuchElementException if <code>o</code> is not found in this 
set
   */
  public double getWeight(Object o) throws NoSuchElementException{
    if(!( key2Weight.containsKey(o)))
      throw new NoSuchElementException(o+" not found in this 
WeightedSet");

    Double d = (Double)key2Weight.get(o);
    if(totalWeight == 0.0)
      return 0.0;


    return d.doubleValue() / totalWeight;
  }

  /**
   * The total weight that has been added to this Set.
   * @return the total weight (the value that can be used for normalizing)
   */
  protected double getTotalWeight(){
    return totalWeight;
  }

  /**
   * Sets the weight of an <code>Object</code>. If the <code>Object</code> 
is
   * not in this <code>Set</code> then it is added.
   * @param o the <code>Object</code>
   * @param w the weight.
   * @throws IllegalArgumentException if <code>w</code> is < 0.0
   */
  public void setWeight(Object o, double w){
    if(w < 0.0){
      throw new IllegalArgumentException("Weight must be >= 0.0");
    }
    if(key2Weight.containsKey(o)){
      remove(o);
    }
    totalWeight += w;
    key2Weight.put(o, new Double(w));
  }

  public boolean contains(Object o) {
    return key2Weight.containsKey(o);
  }

  public boolean remove(Object o) {
    if(key2Weight.containsKey(o)){
      totalWeight -= ((Double)key2Weight.get(o)).doubleValue();
      key2Weight.remove(o);
      return true;
    }
    return false;
  }

  public boolean isEmpty() {
    return key2Weight.isEmpty();
  }
  public boolean retainAll(Collection c) {
    boolean b = false;
    Collection toRemove = new ArrayList();

    for (Iterator i = iterator(); i.hasNext(); ) {
      Object item = i.next();
      if(c.contains(item) == false){
        b = true;
        toRemove.add(item);
      }
    }

    removeAll(toRemove);

    return b;
  }

  /**
   * Adds a new <code>Object</code> with a weight of zero. Equivalent to
   * setWeight(o, 0.0);
   * @param o the object to add.
   * @return true if this Object has not been added before.
   */
  public boolean add(Object o) {
    boolean b = !(key2Weight.containsKey(o));
    setWeight(o, 0.0);
    return b;
  }
  public int size() {
    return key2Weight.size();
  }

  public boolean containsAll(Collection c) {
    if(size() == 0)
      return false;

    for (Iterator i = iterator(); i.hasNext(); ) {
      Object item = i.next();
      if(!(key2Weight.containsKey(item))){
        return false;
      }
    }
    return true;
  }
  public Object[] toArray() {
    Object[] o = new Object[size()];
    int j = 0;
    for (Iterator i = iterator(); i.hasNext(); ) {
      o[j++] = i.next();
    }

    return o;
  }

  public void clear() {
    key2Weight = new HashMap();
    totalWeight = 0.0;
  }
  public Iterator iterator() {
    return key2Weight.keySet().iterator();
  }

  public boolean addAll(Collection c) {
    boolean b = false;

    for (Iterator i = c.iterator(); i.hasNext(); ) {

      Object item = i.next();
      if(!(key2Weight.containsKey(item)))
         b = true;

      add(item);
    }
    return b;
  }
}

!!!!-------TEST CODE STARTS HERE -------!!!!

import junit.framework.*;
import java.util.*;

/**
 * JUnit 3 tests showing the expected usage and behaviour of
 * <code>WeightedSet</code>.
 */
public class TestWeightedSet
    extends TestCase {
  /** Tolerance used when comparing floating-point weights. */
  private static final double DELTA = 1.0e-12;

  private WeightedSet weightedSet = null;

  public TestWeightedSet(String name) {
    super(name);
  }

  /** Gives every test a fresh, empty set. */
  protected void setUp() throws Exception {
    super.setUp();

    weightedSet = new WeightedSet();
  }

  protected void tearDown() throws Exception {
    weightedSet = null;
    super.tearDown();
  }

  /** add() reports whether the element is new and gives it weight zero. */
  public void testAdd() {
    Object o = new Object();
    assertEquals("return value", true, weightedSet.add(o));
    assertEquals("weight", 0.0, weightedSet.getWeight(o), DELTA);
    assertEquals("size", 1, weightedSet.size());

    assertEquals("return value", false, weightedSet.add(o));
    assertEquals("size", 1, weightedSet.size());
  }

  /** addAll() adds each element with weight zero. */
  public void testAddAll() {
    List c = new ArrayList();
    Object o = new Object();
    String s = "";

    c.add(o);
    c.add(s);

    assertEquals("return value", true, weightedSet.addAll(c));
    assertEquals("size", 2, weightedSet.size());
    assertEquals("weight", 0.0, weightedSet.getWeight(o), DELTA);
    assertEquals("weight", 0.0, weightedSet.getWeight(s), DELTA);
  }

  /** asMap() exposes the raw key-to-weight mappings. */
  public void testAsMap() {
    weightedSet.add("one");

    Map m = weightedSet.asMap();
    assertTrue(m.containsKey("one"));
    Double expectedReturn = new Double(0.0);
    Double actualReturn = (Double) m.get("one");
    assertEquals("return value", expectedReturn, actualReturn);
  }

  /** clear() drops every member and the accumulated weight. */
  public void testClear() {
    weightedSet.setWeight("one", 0.5);
    weightedSet.setWeight("two", 0.5);
    weightedSet.clear();
    assertEquals("total weight", 0.0, weightedSet.getTotalWeight(), DELTA);
    assertFalse(weightedSet.contains("one"));
    assertFalse(weightedSet.contains("two"));
    assertTrue(weightedSet.isEmpty());
    assertEquals("size", 0, weightedSet.size());
  }

  public void testContains() {
    Object o = new Object();
    assertEquals("return value", false, weightedSet.contains(o));

    weightedSet.add(o);
    assertEquals("return value", true, weightedSet.contains(o));
  }

  public void testContainsAll() {
    List c = new ArrayList();
    Object o = new Object();
    String s = "";

    c.add(o);
    c.add(s);

    // Nothing has been added yet, so the elements of c are missing.
    assertEquals("return value", false, weightedSet.containsAll(c));

    weightedSet.addAll(c);
    assertEquals("return value", true, weightedSet.containsAll(c));
  }

  public void testGetTotalWeight() {
    weightedSet.setWeight("one", 0.5);
    weightedSet.setWeight("two", 0.5);
    weightedSet.setWeight("three", 0.5);
    assertEquals("return value", 1.5, weightedSet.getTotalWeight(), DELTA);
  }

  /** getWeight() returns the weight normalized against the total. */
  public void testGetWeight() throws NoSuchElementException {
    weightedSet.setWeight("one", 0.5);
    weightedSet.setWeight("two", 0.5);
    weightedSet.setWeight("three", 0.5);
    weightedSet.setWeight("four", 0.5);

    assertEquals("return value", 0.25, weightedSet.getWeight("one"), DELTA);
  }

  /** getWeight() refuses elements that were never added. */
  public void testGetWeightOfMissingElement() {
    try {
      weightedSet.getWeight("absent");
      fail("expected NoSuchElementException");
    } catch (NoSuchElementException expected) {
      // expected: "absent" is not a member
    }
  }

  public void testIsEmpty() {
    assertEquals("return value", true, weightedSet.isEmpty());

    weightedSet.setWeight("four", 0.5);
    assertEquals("return value", false, weightedSet.isEmpty());
  }

  /** remove() drops the member and re-normalizes the remaining weights. */
  public void testRemove() {
    weightedSet.setWeight("one", 0.5);
    weightedSet.setWeight("two", 0.5);
    weightedSet.setWeight("three", 0.5);
    weightedSet.setWeight("four", 0.5);
    weightedSet.setWeight("five", 0.5);

    assertEquals("return value", true, weightedSet.remove("one"));
    assertFalse(weightedSet.contains("one"));
    assertEquals("total weight", 2.0, weightedSet.getTotalWeight(), DELTA);
    assertEquals("weight", 0.25, weightedSet.getWeight("five"), DELTA);

    assertEquals("return value", false, weightedSet.remove("one"));
  }

  public void testRetainAll() {
    weightedSet.setWeight("one", 0.5);
    weightedSet.setWeight("two", 0.5);
    weightedSet.setWeight("three", 0.5);
    weightedSet.setWeight("four", 0.5);
    weightedSet.setWeight("five", 0.5);

    Collection c = new ArrayList();
    c.add("one");
    c.add("two");

    assertEquals("return value", true, weightedSet.retainAll(c));

    assertTrue(weightedSet.contains("one"));
    assertTrue(weightedSet.containsAll(c));
    assertTrue(!weightedSet.contains("three"));
    assertEquals("size", 2, weightedSet.size());
    assertEquals("total weight", 1.0, weightedSet.getTotalWeight(), DELTA);
  }

  /** sample() can only return members that carry weight. */
  public void testSample() {
    weightedSet.setWeight("one", 0.5);
    assertEquals("return value", "one", weightedSet.sample());

    // With all of the weight on "two", "one" must never be drawn.
    weightedSet.setWeight("one", 0.0);
    weightedSet.setWeight("two", 4.5);
    assertEquals("return value", "two", weightedSet.sample());
  }

  public void testSetWeight() {
    Object o = "one";
    double w = 0.3;
    weightedSet.setWeight(o, w);

    assertEquals("total weight", 0.3, weightedSet.getTotalWeight(), DELTA);
    assertEquals("weight", 1.0, weightedSet.getWeight(o), DELTA);
  }

  public void testSetWeight2() {
    Object o = "one";
    double w = 2.5;
    weightedSet.setWeight(o, w);

    assertEquals("total weight", 2.5, weightedSet.getTotalWeight(), DELTA);
    assertEquals("weight", 1.0, weightedSet.getWeight(o), DELTA);
  }

  public void testSetWeight3() {
    Object o = "one";
    double w = 2.5;
    Object p = "two";
    double x = 2.5;

    weightedSet.setWeight(o, w);
    assertEquals("total weight", 2.5, weightedSet.getTotalWeight(), DELTA);
    assertEquals("weight", 1.0, weightedSet.getWeight(o), DELTA);
    weightedSet.setWeight(p, x);
    assertEquals("total weight", 5.0, weightedSet.getTotalWeight(), DELTA);
    assertEquals("weight", 0.5, weightedSet.getWeight(o), DELTA);
    assertEquals("weight", 0.5, weightedSet.getWeight(p), DELTA);
  }

  /** Setting a weight twice replaces the old value instead of adding. */
  public void testSetWeightReplacesExisting() {
    weightedSet.setWeight("one", 0.5);
    weightedSet.setWeight("one", 2.0);

    assertEquals("size", 1, weightedSet.size());
    assertEquals("total weight", 2.0, weightedSet.getTotalWeight(), DELTA);
  }

  /** Negative weights are rejected and leave the set untouched. */
  public void testSetWeightRejectsNegative() {
    try {
      weightedSet.setWeight("one", -0.1);
      fail("expected IllegalArgumentException");
    } catch (IllegalArgumentException expected) {
      // expected: weights must be >= 0.0
    }
    assertTrue(weightedSet.isEmpty());
  }

  public void testSize() {
    assertEquals("return value", 0, weightedSet.size());

    weightedSet.setWeight("one", 0.5);
    weightedSet.setWeight("two", 0.5);
    weightedSet.setWeight("three", 0.5);
    weightedSet.setWeight("four", 0.5);
    weightedSet.setWeight("five", 0.5);

    assertEquals("return value", 5, weightedSet.size());
  }
}


More information about the biojava-dev mailing list