// ParametersFloat.java 3.54 KB
package is2.data;

import is2.util.DB;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;



/**
 * Dense float weight vector with an optional parallel accumulator ({@link #total}) from
 * which averaged weights can be produced via {@link #average(double)} or
 * {@link #average2(double)}.
 */
final public class ParametersFloat {

	/** Squared-norm threshold at or below which the FV step size is forced to 0. */
	private static final double MIN_SQ_NORM_FV = 0.0000000000000001;

	/** Squared-norm threshold at or below which {@link #update(FV, double)} and {@link #hildreth} return 0. */
	private static final double MIN_SQ_NORM_HILDRETH = 0.0000000000000000001;

	/** Current weights; read by the scoring methods and modified by the update methods. */
	public float[] parameters;

	/**
	 * Accumulator passed to the FV/FVR update calls next to {@link #parameters} and divided by
	 * the averaging value in {@link #average(double)} / {@link #average2(double)}.
	 * May be {@code null}: it is not allocated by {@link #ParametersFloat(float[])} and it is
	 * released by {@link #average(double)}.
	 */
	public float[] total;

	/**
	 * Creates zero-valued weight and accumulator arrays of the given size.
	 *
	 * @param size number of weights
	 */
	public ParametersFloat(int size) {
		// New Java arrays are already zero-filled, so no explicit clearing loop is needed.
		parameters = new float[size];
		total = new float[size];
	}

	/**
	 * Wraps an existing weight array without copying it. No accumulator is allocated, so
	 * {@link #total} stays {@code null} and the averaging methods cannot be used on the result.
	 *
	 * @param p array to use directly as {@link #parameters} (not copied)
	 */
	public ParametersFloat(float[] p) {
		parameters = p;
	}

	/**
	 * Overwrites the weights in place with {@code total[j] / (float) avVal} and then releases the
	 * accumulator ({@link #total} becomes {@code null}).
	 *
	 * @param avVal divisor for the accumulated values (narrowed to float before dividing)
	 * @throws IllegalStateException if {@link #total} is {@code null}
	 */
	public void average(double avVal) {
		float[] acc = requireTotal("average");
		float divisor = (float) avVal; // loop-invariant; same float division as before
		for (int j = 0; j < acc.length; j++) {
			parameters[j] = acc[j] / divisor;
		}
		total = null;
	}

	/**
	 * Returns a new instance whose weights are {@code total[j] / (float) avVal}, leaving this
	 * instance unchanged. The returned instance has no accumulator.
	 *
	 * @param avVal divisor for the accumulated values (narrowed to float before dividing)
	 * @return new instance wrapping the averaged weights
	 * @throws IllegalStateException if {@link #total} is {@code null}
	 */
	public ParametersFloat average2(double avVal) {
		float[] acc = requireTotal("average2");
		float divisor = (float) avVal;
		float[] px = new float[parameters.length];
		for (int j = 0; j < acc.length; j++) {
			px[j] = acc[j] / divisor;
		}
		return new ParametersFloat(px);
	}

	/** Returns {@link #total}, or fails with a descriptive message if it has been released or was never allocated. */
	private float[] requireTotal(String op) {
		if (total == null) {
			throw new IllegalStateException(op + "(): accumulator 'total' is null "
					+ "(already averaged, or instance was built from a float[])");
		}
		return total;
	}

	/**
	 * Moves the weights by the difference vector {@code act - pred} using an unclipped step.
	 * Identical to {@link #update(FV, FV, float, float, float)} with {@code C = +infinity}.
	 *
	 * @param pred feature vector the weights are moved away from
	 * @param act  feature vector the weights are moved towards
	 * @param upd  forwarded unchanged to {@code dist.update}
	 * @param err  target score difference between {@code act} and {@code pred}
	 */
	public void update(FV pred, FV act, float upd, float err) {
		update(pred, act, upd, err, Float.POSITIVE_INFINITY);
	}

	/**
	 * Moves the weights by the difference vector {@code dist = act - pred} with step
	 * {@code alpha = min(C, (err - (score(act) - score(pred))) / (dist . dist))}.
	 * The step is 0 when {@code dist . dist} is at or below a tiny threshold.
	 *
	 * @param pred feature vector the weights are moved away from
	 * @param act  feature vector the weights are moved towards
	 * @param upd  forwarded unchanged to {@code dist.update}
	 * @param err  target score difference between {@code act} and {@code pred}
	 * @param C    upper bound on the step size
	 */
	public void update(FV pred, FV act, float upd, float err, float C) {
		// Score margin of act over pred under the current weights.
		float lam_dist = act.getScore(parameters, false) - pred.getScore(parameters, false);
		float loss = err - lam_dist;

		FV dist = act.getDistVector(pred);

		float sqNorm = dist.dotProduct(dist);
		float alpha = sqNorm <= MIN_SQ_NORM_FV ? 0.0f : loss / sqNorm;
		alpha = Math.min(alpha, C);

		dist.update(parameters, total, alpha, upd, false);
	}

	/**
	 * Computes the unclipped step {@code b / (a . a)}, or 0 when {@code a . a} is at or below
	 * a tiny threshold. Does not modify the weights.
	 *
	 * @param a vector whose squared norm is the denominator
	 * @param b numerator
	 * @return the step size
	 */
	public double update(FV a, double b) {
		double sqNorm = a.dotProduct(a);
		return sqNorm <= MIN_SQ_NORM_HILDRETH ? 0.0 : b / sqNorm;
	}

	/**
	 * Scores a feature vector with the current weights.
	 *
	 * @param fv vector to score; may be {@code null}
	 * @return the score, or 0 when {@code fv} is {@code null}
	 */
	public double getScore(FV fv) {
		return fv == null ? 0.0 : fv.getScore(parameters, false);
	}

	/**
	 * Writes the weights as an int length followed by that many floats. The accumulator is
	 * not written.
	 *
	 * @param dos destination stream
	 * @throws IOException on write failure
	 */
	final public void write(DataOutputStream dos) throws IOException {
		dos.writeInt(parameters.length);
		for (float d : parameters) {
			dos.writeFloat(d);
		}
	}

	/**
	 * Replaces the weights with an array read in the format produced by
	 * {@link #write(DataOutputStream)} and logs how many entries are non-zero.
	 * {@link #total} is left as it was. NOTE(review): its length may no longer match
	 * {@link #parameters} after this call.
	 *
	 * @param dis source stream
	 * @throws IOException on read failure
	 */
	public void read(DataInputStream dis) throws IOException {
		parameters = new float[dis.readInt()];
		for (int i = 0; i < parameters.length; i++) {
			parameters[i] = dis.readFloat();
		}
		DB.println("read parameters " + parameters.length + " not zero " + countNZ());
	}

	/**
	 * Counts weights that compare unequal to {@code 0.0F}.
	 *
	 * @return number of non-zero weights
	 */
	public int countNZ() {
		int notZero = 0;
		for (float w : parameters) {
			if (w != 0.0F) {
				notZero++;
			}
		}
		return notZero;
	}

	/** @return a new {@code F2SF} built over the current weight array (not copied here) */
	public F2SF getFV() {
		return new F2SF(parameters);
	}

	/** @return number of weights */
	public int size() {
		return parameters.length;
	}

	/**
	 * Moves the weights by {@code dist = act - pred} using step {@code hildreth(dist, (e + 1) - lam_dist)}.
	 * {@code isd}, {@code instc} and {@code dx} are not used by this method.
	 *
	 * @param act      feature vector the weights are moved towards
	 * @param pred     feature vector the weights are moved away from
	 * @param isd      unused
	 * @param instc    unused
	 * @param dx       unused
	 * @param upd      forwarded unchanged to {@code dist.update}
	 * @param e        target value; 1 is added to it before use
	 * @param lam_dist precomputed score difference, subtracted from the target
	 */
	public void update(FVR act, FVR pred, Instances isd, int instc, Parse dx, double upd, double e, float lam_dist) {
		float b = (float) (e + 1) - lam_dist;
		FVR dist = act.getDistVector(pred);
		dist.update(parameters, total, hildreth(dist, b), upd, false);
	}

	/**
	 * Moves the weights by {@code dist = act - pred} using step
	 * {@code hildreth(dist, (e + 1) - (score(act) - score(pred)))}.
	 *
	 * @param pred feature vector the weights are moved away from
	 * @param act  feature vector the weights are moved towards
	 * @param upd  forwarded unchanged to {@code dist.update}
	 * @param e    target value; 1 is added to it before use
	 */
	public void update(FVR pred, FVR act, float upd, float e) {
		float target = e + 1;
		float lam_dist = act.getScore(parameters, false) - pred.getScore(parameters, false);
		float b = target - lam_dist;
		FVR dist = act.getDistVector(pred);
		dist.update(parameters, total, hildreth(dist, b), upd, false);
	}

	/**
	 * Computes the unclipped step {@code b / (a . a)}, or 0 when {@code a . a} is at or below
	 * a tiny threshold.
	 *
	 * @param a vector whose squared norm is the denominator
	 * @param b numerator
	 * @return the step size
	 */
	protected double hildreth(FVR a, double b) {
		double sqNorm = a.dotProduct(a);
		return sqNorm <= MIN_SQ_NORM_HILDRETH ? 0.0 : b / sqNorm;
	}

	/**
	 * Scores a feature vector with the current weights.
	 *
	 * @param fv vector to score; may be {@code null}
	 * @return the score, or 0 when {@code fv} is {@code null}
	 */
	public float getScore(FVR fv) {
		return fv == null ? 0.0F : fv.getScore(parameters, false);
	}
}