00001 #ifndef snap_cascdynetinf_h
00002 #define snap_cascdynetinf_h
00004 #include "Snap.h"
// Pairwise transmission models.
typedef enum {
  EXP, // exponential
  POW, // power-law
  RAY, // Rayleigh
  WEI  // Weibull
} TModel;
// Trends followed by the transmission rates over time in synthetic experiments.
typedef enum {
  CONSTANT,    // constant
  LINEAR,      // linear trend up/down
  EXPONENTIAL, // exponential trend
  RAYLEIGH,    // Rayleigh trend
  SLAB,        // slab
  SQUARE,      // square
  CHAINSAW,    // chainsaw
  RANDOM       // random noise around alpha value
} TVarying;
// Optimization methods.
typedef enum {
  OSG,   // stochastic gradient
  OWSG,  // windowed stochastic gradient
  OESG,  // exponential decay stochastic gradient
  OWESG, // windowed exponential decay stochastic gradient
  ORSG,  // Rayleigh decay stochastic gradient
  OBSG,  // no decay batch stochastic gradient
  OWBSG, // windowed batch stochastic gradient
  OEBSG, // exponential decay batch stochastic gradient
  ORBSG, // Rayleigh decay batch stochastic gradient
  OFG    // NOTE(review): undocumented in the original; presumably the method run by TNIBs::FG - confirm
} TOptMethod;
// Type of the Sampling argument taken by TNIBs::SG and TNIBs::BSG.
// NOTE(review): no enumerators are visible in this copy of the header, and the
// listing this file was taken from skips several lines at this point. Confirm
// the enumerator list against the upstream header before relying on this type.
typedef enum {
} TSampling;
// Switch for the L2 regularizer (on/off).
typedef enum {
  NONE, // no regularizer
  L2REG // L2 regularizer
} TRegularizer;
// How often inference is run.
typedef enum {
  TIME_STEP,      // run inference every time step
  INFECTION_STEP, // run inference every # number of infections
  CASCADE_STEP    // run inference every time a cascade "finishes"
} TRunningMode;
00061 typedef TNodeEDatNet<TStr, TFltFltH> TStrFltFltHNEDNet;
00062 typedef TPt<TStrFltFltHNEDNet> PStrFltFltHNEDNet;
00064 typedef TNodeEDatNet<TStr, TFlt> TStrFltNEDNet;
00065 typedef TPt<TStrFltNEDNet> PStrFltNEDNet;
00067 // Hit info (node id, timestamp) about a node in a cascade
00068 class THitInfo {
00069 public:
00070   TInt NId;
00071   TFlt Tm;
00072   TIntV Keywords;
00073 public:
00074   THitInfo(const int& NodeId=-1, const double& HitTime=0) : NId(NodeId), Tm(HitTime) { }
00075   THitInfo(TSIn& SIn) : NId(SIn), Tm(SIn), Keywords(SIn) { }
00076   void AddKeyword(const int& KId) { Keywords.AddUnique(KId); }
00077   void DelKeywords() { Keywords.Clr(); }
00078   void Save(TSOut& SOut) const { NId.Save(SOut); Tm.Save(SOut); Keywords.Save(SOut); }
00079   bool operator < (const THitInfo& Hit) const {
00080     return Tm < Hit.Tm; }
00081 };
00083 // Cascade
00084 class TCascade {
00085 public:
00086   TInt CId; // cascade id
00087   THash<TInt, THitInfo> NIdHitH; // infected nodes
00088   TInt Model; // pairwise transmission model
00089 public:
00090   TCascade() : CId(0), NIdHitH(), Model(0) { }
00091   TCascade(const int &model) : NIdHitH() { Model = model; }
00092   TCascade(const int &cid, const int& model) : NIdHitH() { CId = cid; Model = model; }
00093   TCascade(TSIn& SIn) : CId(SIn), NIdHitH(SIn), Model(SIn) { }
00094   void Save(TSOut& SOut) const  { CId.Save(SOut); NIdHitH.Save(SOut); Model.Save(SOut); }
00095   void Clr() { NIdHitH.Clr(); }
00096   int GetId() { return CId; }
00097   int Len() const { return NIdHitH.Len(); }
00098   int LenBeforeT(const double& T) { int len = 0; while (len < NIdHitH.Len() && NIdHitH[len].Tm <= T) { len++; } return len; }
00099   int LenAfterT(const double& T) { int len = 0; while (len < NIdHitH.Len() && NIdHitH[NIdHitH.Len()-1-len].Tm >= T) { len++; } return len; }
00100   int GetNode(const int& i) const { return NIdHitH.GetKey(i); }
00101   THash<TInt, THitInfo>::TIter BegI() const { return NIdHitH.BegI(); }
00102   THash<TInt, THitInfo>::TIter EndI() const { return NIdHitH.EndI(); }
00103   int GetModel() const { return Model; }
00104   double GetTm(const int& NId) const { return NIdHitH.GetDat(NId).Tm; }
00105   double GetMaxTm() const { return NIdHitH[NIdHitH.Len()-1].Tm; } // we assume the cascade is sorted
00106   double GetMinTm() const { return NIdHitH[0].Tm; } // we assume the cascade is sorted
00107   void Add(const int& NId, const double& HitTm) { NIdHitH.AddDat(NId, THitInfo(NId, HitTm)); }
00108   void Del(const int& NId) { NIdHitH.DelKey(NId); }
00109   bool IsNode(const int& NId) const { return NIdHitH.IsKey(NId); }
00110   void Sort() { NIdHitH.SortByDat(true); }
00111   bool operator < (const TCascade& Cascade) const {
00112       return Len() < Cascade.Len(); }
00113 };
00115 // Node info (name and number of cascades)
00116 class TNodeInfo {
00117 public:
00118   TStr Name;
00119   TInt Vol;
00120 public:
00121   TNodeInfo() { }
00122   TNodeInfo(const TStr& NodeNm, const int& Volume) : Name(NodeNm), Vol(Volume) { }
00123   TNodeInfo(TSIn& SIn) : Name(SIn), Vol(SIn) { }
00124   void Save(TSOut& SOut) const { Name.Save(SOut); Vol.Save(SOut); }
00125   bool operator < (const TNodeInfo& NodeInfo) const {
00126       return Vol < NodeInfo.Vol; }
00127 };
00129 // Stochastic gradient network inference class
00130 class TNIBs {
00131 public:
00132   THash<TInt, TCascade> CascH; // cascades, indexed by id
00133   THash<TInt, TNodeInfo> NodeNmH; // node info (name, volume), indexed by node id
00134   TStrIntH DomainsIdH; // domain, DomainId hash table
00135   TStrIntH CascadeIdH; // quote, CascadeId hash table, QuoteId is equivalent to cascadeId
00137   // cascades per edge
00138   THash<TIntPr, TIntV> CascPerEdge;
00140   // network
00141   TStrFltFltHNEDNet Network;
00143   // pairwise transmission model
00144   TModel Model;
00146   // time horizon per cascade (if it is fixed), and totaltime
00147   TFlt Window, TotalTime;
00149   // delta for power-law and k for weibull
00150   TFlt Delta, K;
00152   // step (gamma), regularizer (mu), tolerance, and min/max alpha for stochastic gradient descend
00153   TFlt Gamma, Mu, Aging;
00154   TRegularizer Regularizer;
00155   TFlt Tol, MaxAlpha, MinAlpha, InitAlpha;
00157   // inferred network
00158   TStrFltFltHNEDNet InferredNetwork;
00159   TIntFltH TotalCascadesAlpha;
00161   // gradients (per alpha & cascade)
00162   TIntFltH AveDiffAlphas;
00163   THash<TInt, TIntFltH> DiffAlphas;
00165   // sampled cascades
00166   TIntIntPrH SampledCascadesH;
00168   // performance measures
00169   TFltPrV PrecisionRecall;
00170   TFltPrV Accuracy, MAE, MSE;
00172 public:
00173   TNIBs( ) { }
00174   TNIBs(TSIn& SIn) : CascH(SIn), NodeNmH(SIn), CascPerEdge(SIn), InferredNetwork(SIn) { Model = EXP; }
00175   void Save(TSOut& SOut) const { CascH.Save(SOut); NodeNmH.Save(SOut); CascPerEdge.Save(SOut); InferredNetwork.Save(SOut); }
00177   // functions to load text cascades & network files
00178   void LoadCascadesTxt(TSIn& SIn);
00179   void LoadGroundTruthTxt(TSIn& SIn);
00180   void LoadGroundTruthNodesTxt(TSIn& SIn);
00181   void LoadInferredTxt(TSIn& SIn);
00182   void LoadInferredNodesTxt(TSIn& SIn);
00184   // maximum time for synthetic generation, tx model & window per cascade (if any)
00185   void SetTotalTime(const float& tt) { TotalTime = tt; }
00186   void SetModel(const TModel& model) { Model = model; }
00187   void SetWindow(const double& window) { Window = window; }
00189   // delta for power law & k for weibull
00190   void SetDelta(const double& delta) { Delta = delta; }
00191   void SetK(const double& k) { K = k; }
00193   // optimization parameters
00194   void SetGamma(const double& gamma) { Gamma = gamma; }
00195   void SetAging(const double& aging) { Aging = aging; }
00196   void SetRegularizer(const TRegularizer& reg) { Regularizer = reg; }
00197   void SetMu(const double& mu) { Mu = mu; }
00198   void SetTolerance(const double& tol) { Tol = tol; }
00199   void SetMaxAlpha(const double& ma) { MaxAlpha = ma; }
00200   void SetMinAlpha(const double& ma) { MinAlpha = ma; }
00201   void SetInitAlpha(const double& ia) { InitAlpha = ia; }
00203   // processing cascades
00204   void AddCasc(const TStr& CascStr, const TModel& Model=EXP);
00205   void AddCasc(const TCascade& Cascade) { CascH.AddDat(Cascade.CId) = Cascade; }
00206   void AddCasc(const TIntFltH& Cascade, const int& CId=-1, const TModel& Model=EXP);
00207   void GenCascade(TCascade& C);
00208   bool IsCascade(int c) { return CascH.IsKey(c); }
00209   TCascade & GetCasc(int c) { return CascH.GetDat(c); }
00210   int GetCascs() { return CascH.Len(); }
00211   int GetCascadeId(const TStr& Cascade) { return CascadeIdH.GetDat(Cascade); }
00213   // node info
00214   int GetNodes() { return InferredNetwork.GetNodes(); }
00215   void AddNodeNm(const int& NId, const TNodeInfo& Info) { NodeNmH.AddDat(NId, Info); }
00216   TStr GetNodeNm(const int& NId) const { return NodeNmH.GetDat(NId).Name; }
00217   TNodeInfo GetNodeInfo(const int& NId) const { return NodeNmH.GetDat(NId); }
00218   bool IsNodeNm(const int& NId) const { return NodeNmH.IsKey(NId); }
00219   void SortNodeNmByVol(const bool& asc=false) { NodeNmH.SortByDat(asc); }
00221   // domains
00222   void AddDomainNm(const TStr& Domain, const int& DomainId=-1) { DomainsIdH.AddDat(Domain) = TInt(DomainId==-1? DomainsIdH.Len() : DomainId); }
00223   bool IsDomainNm(const TStr& Domain) const { return DomainsIdH.IsKey(Domain); }
00224   int GetDomainId(const TStr& Domain) { return DomainsIdH.GetDat(Domain); }
00226   // get network or graph at a given time
00227   void GetGroundTruthGraphAtT(const double& Step, PNGraph &GraphAtT);
00228   void GetGroundTruthNetworkAtT(const double& Step, PStrFltNEDNet& NetworkAtT);
00229   void GetInferredGraphAtT(const double& Step, PNGraph &GraphAtT);
00230   void GetInferredNetworkAtT(const double& Step, PStrFltNEDNet& NetworkAtT);
00232   // reset/init for optimization
00233   void Reset();
00234   void Init(const TFltV& Steps);
00236   // optimization methods
00237   void SG(const int& NId, const int& Iters, const TFltV& Steps, const TSampling& Sampling, const TStr& ParamSampling=TStr(""), const bool& PlotPerformance=false);
00238   void BSG(const int& NId, const int& Iters, const TFltV& Steps, const int& BatchLen, const TSampling& Sampling, const TStr& ParamSampling=TStr(""), const bool& PlotPerformance=false);
00239   void FG(const int& NId, const int& Iters, const TFltV& Steps);
00241   // auxiliary function for optimization
00242   void UpdateDiff(const TOptMethod& OptMethod, const int& NId, TCascade& Cascade, TIntPrV& AlphasToUpdate, const double& CurrentTime=TFlt::Mx);
00244   // functions to compute burstiness
00245   void find_C( int t, TFltV &x, TFltVV &C, const int& k, const double& s, const double& gamma, const double& T );
00246   void find_min_state( TFltVV &C, TIntV &states, const int& k, const double& s, const double& gamma, const double& T );
00247   void LabelBurstAutomaton(const int& SrcId, const int& DstId, TIntV &state_labels, TFltV &state_times, const bool& inferred=false, const int& k = 5, const double& s = 2.0, const double& gamma = 1.0, const TSecTm& MinTime=TSecTm(), const TSecTm& MaxTime=TSecTm() );
00249   // function to compute performance for a particular time step and node given groundtruth + inferred network
00250   void ComputePerformanceNId(const int& NId, const int& Step, const TFltV& Steps);
00252   // storing ground truth and inferred network in pajek and text format
00253   void SaveInferredPajek(const TStr& OutFNm, const double& Step, const TIntV& NIdV=TIntV());
00254   void SaveInferred(const TStr& OutFNm, const TIntV& NIdV=TIntV());
00255   void SaveInferred(const TStr& OutFNm, const double& Step, const TIntV& NIdV=TIntV());
00256   void SaveInferredEdges(const TStr& OutFNm);
00258   // store network
00259   void SaveGroundTruthPajek(const TStr& OutFNm, const double& Step);
00260   void SaveGroundTruth(const TStr& OutFNm);
00262   // storing NodeId, site name
00263   void SaveSites(const TStr& OutFNm, const TIntFltVH& CascadesPerNode=TIntFltVH());
00265   // storing cascades in text format
00266   void SaveCascades(const TStr& OutFNm);
00267 };
00269 #endif