PRSFNN
PRSFNN (Polygenic Risk Score with Functional Neural Networks) is a Julia module for calculating polygenic risk scores by integrating GWAS summary statistics with functional annotations using neural networks.
Features
- Integration of GWAS summary statistics with functional annotations
- Linkage disequilibrium (LD) calculation and correction
- Coordinate Ascent Variational Inference (CAVI) for posterior effect size estimation
- Neural network models to learn the relationship between functional annotations and genetic effect sizes
PRSFNN runner
Note: this Julia module runs PRSFNN on individual LD blocks and trains the neural network. To run PRSFNN genome-wide in an HPC environment, see our workflow here.
Installation
# From the Julia REPL
using Pkg
Pkg.add(url="https://github.com/weinstockj/PRS.jl")
Getting Started
Steps to load this module from the root directory:
- Run julia --color=yes --project=. (requires Julia 1.9.0 or later)
- Run using Revise (helpful while developing)
- Run using PRSFNN
Now the functions have been loaded. To call an internal function, use PRSFNN.function_name.
To run Julia in debugger mode: JULIA_DEBUG=PRSFNN julia --color=yes --project=.
Usage
Basic Example
using PRSFNN
# Run PRSFNN on a genomic region
result = main(
    output_prefix = "chr3_block1",
    annot_data_path = "path/to/annotations.parquet",
    gwas_data_path = "path/to/gwas_stats.tsv",
    ld_panel_path = "path/to/ld_panel"
)
Simulating GWAS Data
For testing and development, you can simulate GWAS summary statistic data:
# Generate simulated data
raw = simulate_raw(;N = 10_000, P = 1_000, K = 100, h2 = 0.10)
# Extract sufficient statistics needed for PRS
summary_stats = estimate_sufficient_statistics(raw.X, raw.Y)
Training a PRS Model
# Train the PRS model
using Flux, StatsFuns # Dense, Chain, AdamW, and logit used below come from these packages
N = size(raw.X, 1) # Number of samples
P = size(raw.X, 2) # Number of SNPs
XtX = construct_XtX(summary_stats.R, ones(P), N)
D = construct_D(XtX)
Xty = construct_Xty(summary_stats.coef, D)
σ2, R2, yty = infer_σ2(
    summary_stats.coef,
    summary_stats.SE,
    XtX,
    Xty,
    N,
    P;
    estimate = true,
    λ = 0.50 * N
)
K = size(raw.G, 2)
H = 5 # Number of hidden units
layer_1 = Dense(K => H, Flux.softplus; init = Flux.glorot_normal(gain = 0.005))
layer_output = Dense(H => 2)
layer_output.bias .= [StatsFuns.log(0.001), StatsFuns.logit(0.1)]
model = Chain(layer_1, layer_output)
initial_lr = 0.00005
optim_type = AdamW(initial_lr)
opt = Flux.setup(optim_type, model)
PRS = train_until_convergence(
    summary_stats.coef,
    summary_stats.SE,
    summary_stats.R,
    XtX,
    Xty,
    raw.G,
    model = model,
    opt = opt,
    σ2 = σ2,
    R2 = R2,
    yty = yty,
    N = fill(N, P), # Vector of sample sizes
    train_nn = false,
    max_iter = 5
)
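The sufficient statistics built above can be related to raw data with a few identities. The following is a minimal sketch, assuming centered genotypes and assuming that the second argument of construct_XtX is a vector of per-SNP standard deviations. The helper names (xtx_sketch, d_sketch, xty_sketch) are hypothetical and are not PRSFNN's implementation.

```julia
using LinearAlgebra, Statistics, Random

# Hypothetical helpers; names and formulas are assumptions, not PRSFNN's code.
# For centered genotypes, X'X = N * cov(X) = N * R scaled by the per-SNP SDs.
xtx_sketch(R, sd, N) = N .* (R .* (sd * sd'))
# D is the diagonal of X'X: the per-SNP sum of squared genotypes.
d_sketch(XtX) = diag(XtX)
# A marginal OLS coefficient is x_j'y / x_j'x_j, so x_j'y = coef_j * D_j.
xty_sketch(coef, D) = coef .* D

function check_sufficient_stats(; N = 500, P = 4, rng = MersenneTwister(1))
    X = randn(rng, N, P)
    X .-= mean(X; dims = 1)                         # center each SNP
    y = randn(rng, N)
    sd = vec(std(X; dims = 1, corrected = false))   # population SDs
    R = cor(X)
    XtX = xtx_sketch(R, sd, N)
    D = d_sketch(XtX)
    coef = (X' * y) ./ D                            # marginal regression coefficients
    Xty = xty_sketch(coef, D)
    # Compare the reconstructions against the raw-data quantities.
    return (XtX_err = maximum(abs.(XtX .- X' * X)),
            Xty_err = maximum(abs.(Xty .- X' * y)))
end

errs = check_sufficient_stats()
```

Both errors should be at floating-point level, which is why summary statistics plus an LD matrix can stand in for individual-level data.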
Documentation
For more detailed documentation, visit the official documentation site.
Running the Tests
Run unit tests with:
julia --project=. test/runtests.jl
or interactively, with Revise loaded, using:
includet("test/runtests.jl")
Contact
Please address correspondence to:
- Josh Weinstock <josh.weinstock@emory.edu>
- April Kim <aprilkim@jhu.edu>
- Alexis Battle <ajbattle@jhu.edu>
Functions
PRSFNN.main
— Function
PRSFNN
This function defines the command line interface for the PRSFNN package.
Arguments
- output_prefix: A prefix for the output files
- annot_data_path: A path to the directory containing the annotations
- ld_panel_path: A path to the directory containing the LD reference panel
- gwas_data_path: A path to the GWAS summary statistics file
- model_file: A path to the file containing the trained model
- betas_output_file: A path to the file where the PRS betas will be saved
- interpretation_output_file: A path to the file where the interpretation of the model will be saved
PRSFNN.rss
— Function
rss(β, coef, Σ, SRSinv, to)
Calculate the summary statistic RSS likelihood.
Arguments
- β::Vector: Vector of effect sizes.
- coef::Vector: Observed coefficients.
- Σ::AbstractPDMat: Positive definite covariance matrix.
- SRSinv::Matrix: Precomputed matrix for efficiency.
- to: Timer object for benchmarking.
Returns
- Float64: Log likelihood value.
Example
rss(
    [0.0011, 0.0052, 0.0013],
    [-0.019, 0.013, -0.0199],
    PDMat(Σ),
    SRSinv,
    TimerOutput()
)
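For orientation, the standard RSS likelihood treats the observed coefficients as multivariate normal with mean S R S⁻¹ β and covariance S R S, where S is the diagonal matrix of standard errors. The sketch below assumes that form and builds Σ and SRSinv from SE and R. The function name rss_loglik_sketch is hypothetical, and PRSFNN's rss may differ in details.

```julia
using LinearAlgebra

# Hypothetical sketch of an RSS-style log likelihood (not PRSFNN's rss).
function rss_loglik_sketch(β, coef, SE, R)
    S = Diagonal(SE)
    Σ = Symmetric(Matrix(S * R * S))        # covariance of the observed coefficients
    SRSinv = S * R * inv(S)                 # maps true effects to expected marginal effects
    resid = coef .- SRSinv * β
    C = cholesky(Σ)
    P = length(coef)
    # Multivariate normal log density evaluated at the residual.
    return -0.5 * (P * log(2π) + logdet(C) + dot(resid, C \ resid))
end

R = [1.0 0.03 0.017; 0.03 1.0 -0.02; 0.017 -0.02 1.0]
ll = rss_loglik_sketch(
    [0.0011, 0.0052, 0.0013],
    [-0.019, 0.013, -0.0199],
    [0.0098, 0.0098, 0.0102],
    R
)
```

With R equal to the identity, this reduces to a sum of independent univariate normal log densities, which is a convenient sanity check.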
PRSFNN.elbo
— Function
elbo(z, q_μ, log_q_var, coef, SE, R, σ2_β, p_causal, σ2, [spike_σ2], to)
Compute the Evidence Lower Bound (ELBO) for variational inference.
Arguments
- z::Vector: Random vector from standard normal distribution.
- q_μ::Vector: Mean vector of the variational distribution.
- log_q_var::Vector: Log variance vector of the variational distribution.
- coef::Vector: Observed coefficients.
- SE::Vector or Σ::AbstractPDMat: Standard errors or covariance matrix.
- R::AbstractArray or SRSinv::Matrix: Correlation matrix or precomputed matrix.
- σ2_β::Vector: Vector of prior variances for effect sizes.
- p_causal::Vector: Vector of prior probabilities that each SNP is causal.
- σ2::Real: Global variance parameter.
- spike_σ2::Real: Optional variance parameter for the spike component.
- to: Timer object for benchmarking.
Returns
- Float64: Computed ELBO value.
Example
elbo(
    rand(Normal(0, 1), 3),
    [0.01, -0.003, 0.0018],
    [-9.234, -9.24, -9.24],
    [0.023, -0.0009, -0.0018],
    [0.0094, 0.00988, 0.0102],
    [1.0 0.03 0.017; 0.031 1.0 -0.03; 0.017 -0.02 1.0],
    [0.01, 0.01, 0.01],
    [0.10, 0.10, 0.10],
    0.01,
    TimerOutput()
)
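The arguments z, q_μ, and log_q_var suggest a reparameterized Monte Carlo estimate of the ELBO. The following is a minimal sketch of that idea for a mean-field Gaussian variational family, assuming a user-supplied log joint density. The names elbo_sketch and std_normal_logjoint are hypothetical, and the sketch ignores the spike component that PRSFNN's elbo handles.

```julia
# Hypothetical single-sample ELBO estimate for a mean-field Gaussian q.
function elbo_sketch(logjoint, z, q_μ, log_q_var)
    β = q_μ .+ exp.(0.5 .* log_q_var) .* z           # reparameterized draw from q
    entropy = 0.5 * sum(log(2π) + 1 .+ log_q_var)    # entropy of the Gaussian q
    return logjoint(β) + entropy
end

# Toy log joint: independent standard normal density on each coordinate.
std_normal_logjoint(β) = -0.5 * sum(log(2π) .+ β .^ 2)

value = elbo_sketch(std_normal_logjoint, [1.0, -1.0, 1.0], zeros(3), zeros(3))
```

When q matches the target exactly, as in this toy case, the ELBO is zero in expectation; averaging over several draws of z reduces the Monte Carlo noise.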
PRSFNN.joint_log_prob
— Function
joint_log_prob(β, coef, SE, R, σ2_β, p_causal, σ2, [to])
Compute the joint log probability of the model combining likelihood and prior.
Arguments
- β::Vector: Vector of effect sizes.
- coef::Vector: Observed coefficients.
- SE::Vector or Σ::AbstractPDMat: Standard errors or covariance matrix.
- R::Matrix or SRSinv::Matrix: Correlation matrix or precomputed matrix.
- σ2_β::Vector: Vector of prior variances for effect sizes.
- p_causal::Vector: Vector of prior probabilities that each SNP is causal.
- σ2::Real: Global variance parameter.
- spike_σ2::Real: Optional variance parameter for the spike component.
- to: Timer object for benchmarking.
Returns
- Float64: Joint log probability.
Example
joint_log_prob(
    [0.0011, 0.0052, 0.0013],
    [-0.019, 0.013, -0.0199],
    [0.0098, 0.0098, 0.0102],
    [1.0 0.03 0.017; 0.031 1.0 -0.03; 0.017 -0.02 1.0],
    [0.01, 0.01, 0.01],
    [0.10, 0.10, 0.10],
    0.01,
    TimerOutput()
)
PRSFNN.train_until_convergence
— Function
train_until_convergence(coef, SE, R, D, G; max_iter = 20, threshold = 0.1, N = 10_000)
Arguments
- coef::Vector: A length P vector of effect sizes
- SE::Vector: A length P vector of standard errors
- R::AbstractArray: A P x P correlation matrix
- D::Vector: A length P vector of the sum of squared genotypes
- G::AbstractArray: A P x K matrix of annotations
PRSFNN.fit_heritability_nn
— Function
fit_heritability_nn(model, q_μ, q_μ_sq, q_α, G)
Fit the heritability neural network model.
Arguments
- model::Chain: A neural network model
- q_μ::Vector: A length P vector of posterior means
- q_μ_sq::Vector: A length P vector of posterior variances
- q_α::Vector: A length P vector of posterior probabilities of being causal
- G::AbstractArray: A P x K matrix of annotations
Example
model = Chain(
    Dense(20 => 5, relu; init = Flux.glorot_normal(gain = 0.0005)),
    Dense(5 => 2)
)
G = rand(Normal(0, 1), 100, 20)
q_μ_sq = (G * rand(Normal(0, 0.10), 20)) .^ 2
q_α = 1.0 ./ (1.0 .+ exp.(-1.0 .* (-2.0 .+ q_μ_sq)))
trained_model = PRSFNN.fit_heritability_nn(
    model,
    q_μ_sq,
    q_α,
    G
)
yhat = transpose(trained_model(transpose(G)))
yhat[:, 1] .= exp.(yhat[:, 1])
yhat[:, 2] .= 1.0 ./ (1.0 .+ exp.(-yhat[:, 2]))
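The last two lines of the example map the raw network outputs onto the natural scale: the first output is exponentiated and the second is passed through a logistic function. A minimal sketch of that mapping follows, assuming the first column is a log prior variance and the second a logit prior causal probability. The function name to_prior_parameters is hypothetical.

```julia
logistic(x) = 1 / (1 + exp(-x))

# Hypothetical helper: convert a P x 2 matrix of raw outputs to prior parameters.
function to_prior_parameters(raw::AbstractMatrix)
    σ2_β = exp.(raw[:, 1])             # log variance -> variance (always positive)
    p_causal = logistic.(raw[:, 2])    # logit -> probability in (0, 1)
    return (σ2_β = σ2_β, p_causal = p_causal)
end

# One SNP whose raw outputs are log(0.001) and logit(0.1).
prior = to_prior_parameters([log(0.001) log(0.1 / 0.9)])
```

Working on the log and logit scales lets the network emit unconstrained real numbers while the implied variance and probability stay in their valid ranges.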
PRSFNN.log_prior
— Function
log_prior(β, σ2_β, p_causal, σ2, spike_σ2, to)
Calculate the log density of β based on a spike and slab prior.
Arguments
- β::Vector: Vector of effect sizes.
- σ2_β::Vector: Vector of prior variances for effect sizes.
- p_causal::Vector: Vector of prior probabilities that each SNP is causal.
- σ2::Real: Global variance parameter.
- spike_σ2::Real: Variance parameter for the spike component.
- to: Timer object for benchmarking.
Returns
- Float64: Log prior density.
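A spike and slab prior of this kind is a two-component normal mixture per SNP. The following is a minimal sketch of its log density, assuming a zero-mean slab with variance σ2_β[j] and a zero-mean spike with variance spike_σ2. How the global σ2 enters PRSFNN's log_prior is not specified here, so the sketch omits it. The names normal_logpdf and log_prior_sketch are hypothetical.

```julia
# Log density of Normal(0, v) at x.
normal_logpdf(x, v) = -0.5 * (log(2π * v) + x^2 / v)

# Hypothetical sketch of a spike and slab log prior (not PRSFNN's log_prior).
function log_prior_sketch(β, σ2_β, p_causal, spike_σ2)
    total = 0.0
    for j in eachindex(β)
        slab = log(p_causal[j]) + normal_logpdf(β[j], σ2_β[j])
        spike = log1p(-p_causal[j]) + normal_logpdf(β[j], spike_σ2)
        m = max(slab, spike)
        total += m + log(exp(slab - m) + exp(spike - m))   # stable log-sum-exp
    end
    return total
end

lp = log_prior_sketch([0.0011, 0.0052, 0.0013], fill(0.01, 3), fill(0.10, 3), 1e-6)
```

The log-sum-exp step matters because the spike density can be many orders of magnitude larger or smaller than the slab density, depending on the size of β[j].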
PRSFNN.estimate_sufficient_statistics
— Function
estimate_sufficient_statistics(X, Y)
Estimate the sufficient statistics for genetic association analysis from genotypes and phenotypes.
Arguments
- X::AbstractArray: Genotype matrix (N×P) where N is the number of samples and P is the number of SNPs.
- Y::Vector: Phenotype vector of length N.
Returns
A named tuple containing:
- coef::Vector: Regression coefficients for each SNP.
- SE::Vector: Standard errors of the regression coefficients.
- Z::Vector: Z-scores (coef/SE) for each SNP.
- R::Matrix: Correlation matrix of the genotypes (LD matrix).
- D::Vector: Sum of squares for each SNP (diagonal of X'X).
Description
This function computes the basic summary statistics needed for genetic association analysis, including marginal effect sizes, standard errors, test statistics, and the correlation structure between variants.
Example
X = randn(1000, 100) # 1000 samples, 100 SNPs
Y = randn(1000) # Phenotypes
stats = estimate_sufficient_statistics(X, Y)
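To make the returned fields concrete, the sketch below computes the same kinds of quantities with per-SNP simple linear regression. It is a sketch under stated assumptions: genotypes and phenotypes are centered, and the residual variance uses N - 2 degrees of freedom. The function name marginal_stats_sketch is hypothetical, and PRSFNN's estimate_sufficient_statistics may differ in these details.

```julia
using LinearAlgebra, Statistics, Random

# Hypothetical sketch of marginal (one SNP at a time) summary statistics.
function marginal_stats_sketch(X, Y)
    N, P = size(X)
    Xc = X .- mean(X; dims = 1)               # center genotypes
    Yc = Y .- mean(Y)                         # center phenotype
    D = vec(sum(abs2, Xc; dims = 1))          # per-SNP sum of squares
    coef = (Xc' * Yc) ./ D                    # marginal OLS slopes
    SE = similar(coef)
    for j in 1:P
        resid = Yc .- Xc[:, j] .* coef[j]
        SE[j] = sqrt(sum(abs2, resid) / (N - 2) / D[j])
    end
    return (coef = coef, SE = SE, Z = coef ./ SE, R = cor(X), D = D)
end

rng = MersenneTwister(7)
X = randn(rng, 200, 5)
Y = 0.3 .* X[:, 1] .+ randn(rng, 200)
stats_sketch = marginal_stats_sketch(X, Y)
```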
PRSFNN.compute_LD
— Function
compute_LD(LD_reference::String)
Compute linkage disequilibrium (LD) correlation matrix from a reference genotype file.
Arguments
- LD_reference::String: Path to the BED format genotype file
Returns
- R::Matrix{Float64}: LD correlation matrix for polymorphic variants
- sds::Vector{Float64}: Standard deviations of genotypes for each variant
- allele_freq::Vector{Float64}: Allele frequencies (mean/2) for each variant
- good_variants::Vector{Int}: Indices of polymorphic variants that passed QC filters
Details
This function reads genotype data from a BED file, converts it to floating point representation, filters out monomorphic variants and those with very low variance, and computes the correlation matrix between the remaining variants. Variants are filtered based on mean frequency and standard deviation thresholds.
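The filtering and correlation steps described above can be sketched on an in-memory dosage matrix. The example below assumes genotypes coded as 0/1/2 in a dense Float64 matrix and skips BED parsing entirely. The function name ld_from_genotypes_sketch and the min_sd threshold are hypothetical, not PRSFNN's actual QC settings.

```julia
using Statistics

# Hypothetical sketch: LD correlation matrix from an N x P dosage matrix.
function ld_from_genotypes_sketch(X::AbstractMatrix; min_sd = 1e-6)
    sds_all = vec(std(X; dims = 1))
    good_variants = findall(>(min_sd), sds_all)     # drop monomorphic columns
    Xg = X[:, good_variants]
    R = cor(Xg)                                     # pairwise correlation between variants
    allele_freq = vec(mean(Xg; dims = 1)) ./ 2      # mean dosage divided by 2
    return (R = R, sds = sds_all[good_variants],
            allele_freq = allele_freq, good_variants = good_variants)
end

# Four samples, three variants; the second variant is monomorphic.
X = Float64[0 1 2; 1 1 2; 2 1 0; 1 1 1]
ld = ld_from_genotypes_sketch(X)
```

Monomorphic variants have zero variance, so their correlations are undefined; removing them first keeps R free of NaN entries.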
PRSFNN.fit_genome_wide_nn
— Function
fit_genome_wide_nn(betas, annotation_files_dir, model_file; n_epochs, H, n_test, learning_rate_decay, patience)
Train a neural network model to predict variant effects using genome-wide annotation data.
Arguments
- betas: Path to file containing PRS beta values from previous analysis
- annotation_files_dir: Directory containing annotation parquet files
- model_file: Path where the trained model will be saved
- n_epochs: Number of training epochs (default: 1306)
- H: Number of hidden units in the neural network (default: 3)
- n_test: Number of test samples to use for evaluation (default: 30)
- learning_rate_decay: Learning rate decay factor (default: 0.98)
- patience: Number of epochs to wait before reducing learning rate (default: 30)
Returns
- model: The trained neural network model
- opt: The optimizer state
- global_σ2: Global residual variance estimate
PRSFNN.train_cavi
— Function
train_cavi(p_causal, σ2_β, coef, SE, R, XtX, Xty, to; P = 1_000, n_elbo = 10, max_iter = 5, N = 10_000, yty = 10_000, spike_σ2 = 1e-6, update_σ2 = true, σ2 = 1.0)
Performs Coordinate Ascent Variational Inference (CAVI) for a spike-and-slab model to estimate genetic variant effect sizes and posterior inclusion probabilities.
Arguments
- p_causal::Vector{Float64}: Prior probabilities for variants to be causal
- σ2_β::Vector{Float64}: Prior variances for effect sizes (slab component)
- coef::Vector{Float64}: Effect sizes from GWAS summary statistics
- SE::Vector{Float64}: Standard errors of the effect sizes
- R::Matrix{Float64}: LD correlation matrix
- XtX::Matrix{Float64}: Matrix equal to N times the covariance matrix of genotypes
- Xty::Vector{Float64}: Vector of inner products between genotypes and phenotype
- to::TimerOutput: Timer output object for performance profiling
Keyword Arguments
- P::Int64=1_000: Number of variants
- n_elbo::Int64=10: Number of Monte Carlo samples for ELBO estimation
- max_iter::Int64=5: Maximum number of iterations for convergence
- N::Float64=10_000: Sample size
- yty::Float64=10_000: Sum of squared phenotypes
- spike_σ2::Float64=1e-6: Variance for the spike component
- update_σ2::Bool=true: Whether to update the residual variance
- σ2::Float64=1.0: Initial residual variance
Returns
A named tuple containing:
- q_μ::Vector{Float64}: Posterior means for the slab component
- q_spike_μ::Vector{Float64}: Posterior means for the spike component
- q_α::Vector{Float64}: Posterior inclusion probabilities
- q_var::Vector{Float64}: Posterior variances for the slab component
- q_odds::Vector{Float64}: Posterior odds for variant inclusion
- loss::Float64: Final ELBO value
- se_loss::Float64: Standard error of the ELBO estimate
- σ2::Float64: Final residual variance estimate
Details
This function implements an iterative Coordinate Ascent Variational Inference algorithm for a spike-and-slab model, commonly used in Polygenic Risk Score (PRS) calculation. At each iteration, it updates the variational parameters to maximize the Evidence Lower Bound (ELBO).
The algorithm stops when insufficient improvement in ELBO is detected or when the maximum number of iterations is reached. The best parameters (with highest ELBO) are returned.
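The coordinate updates can be illustrated with a textbook CAVI loop for spike-and-slab regression on sufficient statistics. This is a simplified sketch, not PRSFNN's train_cavi: it assumes a point-mass spike rather than a small-variance spike, keeps σ2 fixed, and runs a fixed number of sweeps instead of monitoring the ELBO. The name cavi_sketch is hypothetical.

```julia
using LinearAlgebra

logistic(x) = 1 / (1 + exp(-x))
logit(p) = log(p) - log1p(-p)

# Hypothetical CAVI sketch with a point-mass spike and fixed residual variance.
function cavi_sketch(XtX, Xty, p_causal, σ2_β; σ2 = 1.0, max_iter = 50)
    P = length(Xty)
    q_μ = zeros(P)
    q_α = copy(p_causal)
    q_var = zeros(P)
    for _ in 1:max_iter
        for j in 1:P
            # Posterior variance of the slab component for variant j.
            q_var[j] = 1 / (XtX[j, j] / σ2 + 1 / σ2_β[j])
            # Remove the expected contribution of all other variants.
            r = Xty[j] - dot(view(XtX, :, j), q_α .* q_μ) + XtX[j, j] * q_α[j] * q_μ[j]
            q_μ[j] = q_var[j] * r / σ2
            # Posterior inclusion probability on the logit scale.
            q_α[j] = logistic(logit(p_causal[j]) +
                              0.5 * log(q_var[j] / σ2_β[j]) +
                              q_μ[j]^2 / (2 * q_var[j]))
        end
    end
    return (q_μ = q_μ, q_α = q_α, q_var = q_var)
end

# Two uncorrelated variants, N = 1000: the first carries signal, the second does not.
fit = cavi_sketch([1000.0 0.0; 0.0 1000.0], [500.0, 0.0], [0.1, 0.1], [0.1, 0.1])
```

In this toy case the first variant gets an inclusion probability near 1 and the second falls below its prior of 0.1, because the data provide no evidence for it.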
PRSFNN.simulate_raw
— Function
simulate_raw(; N = 10_000, P = 1_000, K = 100, h2 = 0.10)
Simulate genotype and phenotype data with genetic architecture influenced by functional annotations.
Arguments
- N::Int=10_000: Number of samples/individuals
- P::Int=1_000: Number of genetic variants/SNPs
- K::Int=100: Number of functional annotations
- h2::Float64=0.10: Desired narrow-sense heritability of the trait
Returns
A named tuple containing:
- X::Matrix{Float64}: Genotype matrix (N × P)
- β::Vector{Float64}: True causal effect sizes
- Y::Vector{Float64}: Simulated phenotypes
- Σ::Matrix{Float64}: Covariance matrix of genotypes (LD structure)
- s::Vector{Float64}: Per-SNP normalized variance scalars derived from annotations
- G::Matrix{Float64}: Matrix of functional annotations (P × K)
- γ::Vector{Int}: Binary indicators of causal status (1=causal, 0=non-causal)
- function_choices::Vector{Int}: Selected functional forms for each annotation
- phi::Vector{Function}: Functions mapping each annotation to effect size variance
- sigma_squared::Vector{Float64}: Per-SNP variance components before normalization
Details
This function simulates genotype and phenotype data with a realistic genetic architecture:
- Genotypes (X) are simulated with a realistic LD structure using random eigendecomposition
- 10% of variants are designated as causal (γ)
- Functional annotations (G) influence effect size variance through different functional forms
- Effect sizes follow a spike-and-slab distribution:
  - Causal variants: Normal(0, scaled_variance)
  - Non-causal variants: Normal(0, tiny_variance)
- Phenotypes are computed as Y = Xβ + ε, with noise calibrated to achieve target heritability
The simulation includes both continuous and binary annotations, with varying relationships to effect size variance (linear, quadratic, or null effects).
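The noise calibration step can be written down directly: if g = Xβ is the genetic component, choosing the noise variance as Var(g)(1 - h2)/h2 gives Var(g)/Var(Y) ≈ h2. The sketch below assumes independent standard normal genotypes and dense normal effect sizes for brevity. The function name simulate_phenotype_sketch is hypothetical and it is not the simulate_raw implementation.

```julia
using Random, Statistics

# Hypothetical sketch: add noise to Xβ so that the target heritability is reached.
function simulate_phenotype_sketch(X, β, h2; rng = Random.default_rng())
    g = X * β                                   # genetic component
    σ2_e = var(g) * (1 - h2) / h2               # noise variance implied by h2
    Y = g .+ sqrt(σ2_e) .* randn(rng, length(g))
    return (Y = Y, g = g)
end

rng = MersenneTwister(42)
X = randn(rng, 20_000, 50)
β = 0.1 .* randn(rng, 50)
sim = simulate_phenotype_sketch(X, β, 0.10; rng = rng)
realized_h2 = var(sim.g) / var(sim.Y)           # close to 0.10, up to sampling noise
```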
PRSFNN.infer_σ2
— Function
infer_σ2(coef, SE, XtX, Xty, N, P)
Arguments
- coef::Vector: A length P vector of effect sizes
- SE::Vector: A length P vector of standard errors
- XtX::AbstractArray: A P x P matrix equal to N times the covariance matrix of the genotypes
- Xty::Vector: A length P vector of the inner product between genotype and phenotype
- N: Number of samples
- P: Number of SNPs
PRSFNN.poet_cov
— Function
poet_cov(X::AbstractArray; K = 100, τ = 0.01, N = 1000)
Apply POET (Principal Orthogonal complEment Thresholding) method to estimate a sparse covariance matrix.
Arguments
- X::AbstractArray: Input covariance or correlation matrix
- K::Int=100: Number of principal components to retain
- τ::Float64=0.01: Thresholding parameter for sparsity
- N::Int=1000: Sample size used for bias correction
Returns
- Sigma::Matrix{Float64}: The POET-estimated covariance matrix
Details
Implements the POET method for high-dimensional covariance matrix estimation. The method decomposes the matrix into a low-rank component (representing systematic factors) and a sparse component (representing idiosyncratic noise). The sparse component is obtained by thresholding.
Reference: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5563862/
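The low-rank plus sparse decomposition described above can be sketched in a few lines. This is a simplified illustration, assuming hard thresholding of the off-diagonal residual entries and no sample-size bias correction, so it does not use the N argument. The function name poet_sketch is hypothetical and the code is not PRSFNN's poet_cov.

```julia
using LinearAlgebra

# Hypothetical POET-style estimator: top-K eigencomponents plus a thresholded residual.
function poet_sketch(S::AbstractMatrix; K = 2, τ = 0.01)
    E = eigen(Symmetric(Matrix(S)))                  # eigenvalues in ascending order
    idx = (size(S, 1) - K + 1):size(S, 1)            # positions of the K largest
    V = E.vectors[:, idx]
    low_rank = V * Diagonal(E.values[idx]) * V'      # systematic (factor) component
    resid = S - low_rank                             # principal orthogonal complement
    sparse_part = map(x -> abs(x) > τ ? x : 0.0, resid)    # hard threshold
    sparse_part[diagind(sparse_part)] .= diag(resid)       # never threshold the diagonal
    return low_rank + sparse_part
end

S = [1.0 0.5 0.2; 0.5 1.0 0.3; 0.2 0.3 1.0]
Sigma = poet_sketch(S; K = 1, τ = 0.05)
```

With τ = 0 nothing is thresholded and the input is recovered; larger τ removes more of the small residual correlations while the diagonal is always preserved.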