// Per-volume correction of a 4D image against a B0 reference volume.
// For each volume i, the code below:
//   - block-matches it against the B0 volume (rigid stage, then a directional affine along --dir,
//     then a non-linear directional stage);
//   - composes the resulting transforms and resamples the volume onto the B0 geometry;
//   - writes the result back into the 4D image;
//   - rotates that volume's gradient vector.
// The image is written to --output and the gradients to --output-bvec.
// NOTE(review): this view is an extract with lines missing (original line numbers are embedded
// at the start of many lines). Several includes and the typedefs for AgregatorType, PyramidBMType,
// NonLinearPyramidBMType, GFReaderType and ResampleFilterType are not visible here.
1 #include <tclap/CmdLine.h> 6 #include <itkExtractImageFilter.h> 8 #include <itkImageRegionIterator.h> 9 #include <itkCompositeTransform.h> 10 #include <itkStationaryVelocityFieldTransform.h> 11 #include <rpiDisplacementFieldTransform.h> 16 int main(
int argc,
const char** argv)
// Spatial dimension of each volume; the input stacks the volumes along one extra (4th) axis.
18 const unsigned int Dimension = 3;
// 4D input image and the 3D volume type extracted from it.
20 typedef itk::Image <double,Dimension+1> InputImageType;
21 typedef itk::Image <double,Dimension> InputSubImageType;
22 typedef itk::ImageRegionIterator <InputImageType> InputImageIteratorType;
23 typedef itk::ImageRegionIterator <InputSubImageType> InputSubImageIteratorType;
// Affine container used for both the rigid and the directional-affine stage outputs.
27 typedef itk::AffineTransform<AgregatorType::ScalarType,Dimension> AffineTransformType;
28 typedef AffineTransformType::Pointer AffineTransformPointer;
31 TCLAP::CmdLine cmd(
"INRIA / IRISA - VisAGeS/Empenn Team",
' ',ANIMA_VERSION);
34 TCLAP::ValueArg<std::string> inputArg(
"i",
"input",
"Input 4D image",
true,
"",
"input 4D image",cmd);
35 TCLAP::ValueArg<std::string> inBVecArg(
"I",
"input-bvec",
"Input gradient vectors file",
true,
"",
"input gradients",cmd);
36 TCLAP::ValueArg<std::string> outArg(
"o",
"output",
"Output (corrected) image",
true,
"",
"output image",cmd);
37 TCLAP::ValueArg<std::string> outBVecArg(
"O",
"output-bvec",
"Output gradient vectors (bvec format)",
true,
"",
"output gradients",cmd);
39 TCLAP::ValueArg<unsigned int> directionArg(
"d",
"dir",
"Affine direction for directional transform output (default: 1 = Y axis)",
false,1,
"direction of directional affine",cmd);
40 TCLAP::ValueArg<unsigned int> b0Arg(
"b",
"b0",
"Index of the B0 reference image",
false,0,
"reference image index",cmd);
42 TCLAP::ValueArg<unsigned int> blockSizeArg(
"",
"bs",
"Block size (default: 5)",
false,5,
"block size",cmd);
43 TCLAP::ValueArg<unsigned int> blockSpacingArg(
"",
"sp",
"Block spacing (default: 5)",
false,5,
"block spacing",cmd);
44 TCLAP::ValueArg<unsigned int> nlBlockSpacingArg(
"",
"nsp",
"Block spacing (default: 3)",
false,3,
"non linear matching block spacing",cmd);
45 TCLAP::ValueArg<double> stdevThresholdArg(
"s",
"stdev",
"Threshold block standard deviation (default: 5)",
false,5,
"block minimal standard deviation",cmd);
46 TCLAP::ValueArg<double> percentageKeptArg(
"k",
"per-kept",
"Percentage of blocks with the highest variance kept (default: 0.8)",
false,0.8,
"percentage of blocks kept",cmd);
48 TCLAP::ValueArg<unsigned int> blockMetricArg(
"",
"metric",
"Similarity metric between blocks (0: squared correlation coefficient, 1: correlation coefficient, 2: mean squares, default: 0)",
false,0,
"similarity metric",cmd);
49 TCLAP::ValueArg<unsigned int> optimizerArg(
"",
"opt",
"Optimizer for optimal block search (0: Exhaustive, 1: Bobyqa, default: 1)",
false,1,
"optimizer",cmd);
51 TCLAP::ValueArg<unsigned int> maxIterationsArg(
"",
"mi",
"Maximum block match iterations (default: 10)",
false,10,
"maximum iterations",cmd);
52 TCLAP::ValueArg<double> minErrorArg(
"",
"me",
"Minimal distance between consecutive estimated transforms (default: 0.01)",
false,0.01,
"minimal distance between transforms",cmd);
54 TCLAP::ValueArg<unsigned int> optimizerMaxIterationsArg(
"",
"oi",
"Maximum iterations for local optimizer (default: 100)",
false,100,
"maximum local optimizer iterations",cmd);
56 TCLAP::ValueArg<double> searchRadiusArg(
"",
"sr",
"Search radius in pixels (exhaustive search window, rho start for bobyqa, default: 2)",
false,2,
"optimizer search radius",cmd);
57 TCLAP::ValueArg<double> finalRadiusArg(
"",
"fr",
"Final radius (rho end for bobyqa, default: 0.001)",
false,0.001,
"optimizer final radius",cmd);
58 TCLAP::ValueArg<double> searchStepArg(
"",
"st",
"Search step for exhaustive search (default: 2)",
false,2,
"exhaustive optimizer search step",cmd);
59 TCLAP::ValueArg<double> translateUpperBoundArg(
"",
"tub",
"Upper bound on translation for bobyqa (in voxels, default: 10)",
false,10,
"Bobyqa translate upper bound",cmd);
61 TCLAP::ValueArg<unsigned int> symmetryArg(
"",
"sym-reg",
"Registration symmetry type (0: asymmetric, 1: symmetric, 2: kissing, default: 0)",
false,0,
"symmetry type",cmd);
62 TCLAP::ValueArg<unsigned int> agregatorArg(
"a",
"agregator",
"Transformation agregator type (0: M-Estimation, 1: least squares, 2: least trimmed squares, default: 0)",
false,0,
"agregator type",cmd);
63 TCLAP::ValueArg<double> agregThresholdArg(
"",
"at",
"Agregator threshold value (for M-estimation or LTS)",
false,0.5,
"agregator threshold value",cmd);
64 TCLAP::ValueArg<double> extrapolationSigmaArg(
"",
"fs",
"Sigma for extrapolation of local pairings (default: 3)",
false,3,
"extrapolation sigma",cmd);
65 TCLAP::ValueArg<double> elasticSigmaArg(
"",
"es",
"Sigma for elastic regularization (default: 3)",
false,3,
"elastic regularization sigma",cmd);
66 TCLAP::ValueArg<double> outlierSigmaArg(
"",
"os",
"Sigma for outlier rejection among local pairings (default: 3)",
false,3,
"outlier rejection sigma",cmd);
67 TCLAP::ValueArg<double> seStoppingThresholdArg(
"",
"lst",
"LTS Stopping Threshold (default: 0.01)",
false,0.01,
"LTS stopping threshold",cmd);
69 TCLAP::ValueArg<unsigned int> numPyramidLevelsArg(
"p",
"pyr",
"Number of pyramid levels (default: 3)",
false,3,
"number of pyramid levels",cmd);
70 TCLAP::ValueArg<unsigned int> lastPyramidLevelArg(
"l",
"last-level",
"Index of the last pyramid level explored (default: 0)",
false,0,
"last pyramid level",cmd);
71 TCLAP::ValueArg<unsigned int> numThreadsArg(
"T",
"threads",
"Number of execution threads (default: 0 = all cores)",
false,0,
"number of threads",cmd);
77 catch (TCLAP::ArgException& e)
79 std::cerr <<
"Error: " << e.error() <<
"for argument " << e.argId() << std::endl;
83 InputImageType::Pointer inputImage = anima::readImage <InputImageType> (inputArg.getValue());
84 unsigned int numberOfImages = inputImage->GetLargestPossibleRegion().GetSize()[Dimension];
85 typedef itk::ExtractImageFilter <InputImageType, InputSubImageType> ExtractFilterType;
87 ExtractFilterType::Pointer referenceExtractFilter = ExtractFilterType::New();
88 referenceExtractFilter->SetInput(inputImage);
89 InputImageType::RegionType refExtractRegion = inputImage->GetLargestPossibleRegion();
90 refExtractRegion.SetIndex(Dimension,b0Arg.getValue());
91 refExtractRegion.SetSize(Dimension,0);
92 referenceExtractFilter->SetExtractionRegion(refExtractRegion);
93 referenceExtractFilter->SetDirectionCollapseToGuess();
94 referenceExtractFilter->Update();
97 GFReaderType gfReader;
99 gfReader.SetGradientIndependentNormalization(
false);
102 GFReaderType::GradientVectorType directions = gfReader.GetGradients();
// Process every volume of the 4D image in turn.
104 for (
unsigned int i = 0;i < numberOfImages;++i)
// NOTE(review): the statement governed by this test is not visible in this extract.
// Presumably the B0 reference volume itself is skipped — confirm.
106 if (i == b0Arg.getValue())
// Single-line progress report: "\033[K" clears to end of line, "\r" returns to column 0.
109 std::cout <<
"\033[K\rProcessing image " << i+1 <<
" out of " << numberOfImages << std::flush;
// Extract volume i as a 3D image (zero extent along the volume axis), like the B0 extraction.
111 ExtractFilterType::Pointer extractFilter = ExtractFilterType::New();
112 extractFilter->SetInput(inputImage);
113 InputImageType::RegionType extractRegion = inputImage->GetLargestPossibleRegion();
114 extractRegion.SetIndex(Dimension,i);
115 extractRegion.SetSize(Dimension,0);
116 extractFilter->SetExtractionRegion(extractRegion);
117 extractFilter->SetDirectionCollapseToGuess();
118 extractFilter->Update();
// Stage 1: pyramidal block matching producing a rigid transform (outRigid), stored in an
// affine transform object and initialized from gravity centers.
// NOTE(review): --metric, --opt, --sym-reg and -a are never applied in the visible code.
// The lines that set them are presumably among those missing from this extract — confirm.
121 PyramidBMType::Pointer matcher = PyramidBMType::New();
123 matcher->SetBlockSize(blockSizeArg.getValue());
124 matcher->SetBlockSpacing(blockSpacingArg.getValue());
125 matcher->SetStDevThreshold(stdevThresholdArg.getValue());
128 matcher->SetMaximumIterations(maxIterationsArg.getValue());
129 matcher->SetMinimalTransformError(minErrorArg.getValue());
130 matcher->SetFinalRadius(finalRadiusArg.getValue());
131 matcher->SetOptimizerMaximumIterations(optimizerMaxIterationsArg.getValue());
132 matcher->SetSearchRadius(searchRadiusArg.getValue());
133 matcher->SetStepSize(searchStepArg.getValue());
134 matcher->SetTranslateUpperBound(translateUpperBoundArg.getValue());
137 matcher->SetOutputTransformType(PyramidBMType::outRigid);
138 matcher->SetAffineDirection(directionArg.getValue());
139 matcher->SetAgregThreshold(agregThresholdArg.getValue());
140 matcher->SetSeStoppingThreshold(seStoppingThresholdArg.getValue());
141 matcher->SetNumberOfPyramidLevels(numPyramidLevelsArg.getValue());
142 matcher->SetLastPyramidLevel(lastPyramidLevelArg.getValue());
143 matcher->SetVerbose(
false);
// Only override the work-unit count when --threads is non-zero.
145 if (numThreadsArg.getValue() != 0)
146 matcher->SetNumberOfWorkUnits( numThreadsArg.getValue() );
148 matcher->SetPercentageKept( percentageKeptArg.getValue() );
149 matcher->SetTransformInitializationType(PyramidBMType::GravityCenters);
// For this stage the B0 volume is the floating image and volume i the reference.
151 matcher->SetFloatingImage(referenceExtractFilter->GetOutput());
152 matcher->SetReferenceImage(extractFilter->GetOutput());
// Start from identity; this object is handed to the matcher as its output transform.
154 AffineTransformPointer rigidTrsf = AffineTransformType::New();
155 rigidTrsf->SetIdentity();
156 matcher->SetOutputTransform(rigidTrsf.GetPointer());
// NOTE(review): the try block running this stage, and whatever follows the error print, are
// not visible in this extract.
162 catch (itk::ExceptionObject &e)
164 std::cerr << e << std::endl;
// NOTE(review): the dynamic_cast result is not null-checked. rigidTrsf is dereferenced later
// (GetMatrix / GetInverse), so a non-affine output transform would crash there.
168 rigidTrsf = dynamic_cast <AffineTransformType *> (matcher->GetOutputTransform().GetPointer());
// Output image of the rigid stage; used as the reference for the next two stages.
170 InputSubImageType::Pointer rigidReference = matcher->GetOutputImage();
// Stage 2: reuse the same matcher for a directional affine along --dir (outAffine output).
// The rigid-stage output image is now the reference and volume i the floating image.
173 matcher->SetReferenceImage(rigidReference);
174 matcher->SetFloatingImage(extractFilter->GetOutput());
175 matcher->SetTransform(PyramidBMType::Directional_Affine);
176 matcher->SetOutputTransformType(PyramidBMType::outAffine);
178 AffineTransformPointer tmpTrsfDirectional = AffineTransformType::New();
179 tmpTrsfDirectional->SetIdentity();
180 matcher->SetOutputTransform(tmpTrsfDirectional.GetPointer());
// NOTE(review): the try block running this stage is not visible in this extract. Unlike
// rigidTrsf, tmpTrsfDirectional is not re-read from GetOutputTransform() in the visible lines.
186 catch (itk::ExceptionObject &e)
188 std::cerr << e << std::endl;
// Stage 3: non-linear directional block matching. Its output transform is later read as a
// stationary velocity field (SVFTransformPointer). The rigid-stage output image is the
// reference; the directional-affine stage's output image is the floating image.
194 NonLinearPyramidBMType::Pointer nonLinearMatcher = NonLinearPyramidBMType::New();
196 nonLinearMatcher->SetReferenceImage(rigidReference);
197 nonLinearMatcher->SetFloatingImage(matcher->GetOutputImage());
200 nonLinearMatcher->SetBlockSize(blockSizeArg.getValue());
// Uses the dedicated non-linear block spacing (--nsp) rather than --sp.
201 nonLinearMatcher->SetBlockSpacing(nlBlockSpacingArg.getValue());
202 nonLinearMatcher->SetStDevThreshold(stdevThresholdArg.getValue());
// Same directional constraint (axis --dir) as the affine stage.
203 nonLinearMatcher->SetTransform(NonLinearPyramidBMType::Directional_Affine);
204 nonLinearMatcher->SetAffineDirection(directionArg.getValue());
207 nonLinearMatcher->SetMaximumIterations(maxIterationsArg.getValue());
208 nonLinearMatcher->SetMinimalTransformError(minErrorArg.getValue());
209 nonLinearMatcher->SetFinalRadius(finalRadiusArg.getValue());
210 nonLinearMatcher->SetOptimizerMaximumIterations(optimizerMaxIterationsArg.getValue());
211 nonLinearMatcher->SetSearchRadius(searchRadiusArg.getValue());
212 nonLinearMatcher->SetStepSize(searchStepArg.getValue());
213 nonLinearMatcher->SetTranslateUpperBound(translateUpperBoundArg.getValue());
// BCH composition order and exponentiation order are fixed here, not exposed as options.
216 nonLinearMatcher->SetBCHCompositionOrder(1);
217 nonLinearMatcher->SetExponentiationOrder(0);
218 nonLinearMatcher->SetExtrapolationSigma(extrapolationSigmaArg.getValue());
219 nonLinearMatcher->SetElasticSigma(elasticSigmaArg.getValue());
220 nonLinearMatcher->SetOutlierSigma(outlierSigmaArg.getValue());
221 nonLinearMatcher->SetNumberOfPyramidLevels(numPyramidLevelsArg.getValue());
222 nonLinearMatcher->SetLastPyramidLevel(lastPyramidLevelArg.getValue());
223 nonLinearMatcher->SetVerbose(
false);
// Only override the work-unit count when --threads is non-zero.
225 if (numThreadsArg.getValue() != 0)
226 nonLinearMatcher->SetNumberOfWorkUnits(numThreadsArg.getValue());
228 nonLinearMatcher->SetPercentageKept(percentageKeptArg.getValue());
// NOTE(review): the enclosing try (and what follows the error print) is not visible here.
232 nonLinearMatcher->Update();
234 catch (itk::ExceptionObject &e)
236 std::cerr << e << std::endl;
// Build the composite transform used to resample volume i. Transforms are added in this order:
//   1. the directional affine;
//   2. the dense displacement field from the SVF;
//   3. the inverse rigid transform.
// NOTE(review): per ITK's CompositeTransform documentation the last-added transform is applied
// first to a point — confirm this matches the intended mapping.
241 typedef itk::CompositeTransform <AgregatorType::ScalarType,Dimension> GeneralTransformType;
242 GeneralTransformType::Pointer transformSerie = GeneralTransformType::New();
243 transformSerie->AddTransform(tmpTrsfDirectional);
245 typedef itk::StationaryVelocityFieldTransform <AgregatorType::ScalarType,Dimension> SVFTransformType;
246 typedef SVFTransformType::Pointer SVFTransformPointer;
248 typedef rpi::DisplacementFieldTransform <AgregatorType::ScalarType,Dimension> DenseTransformType;
249 typedef DenseTransformType::Pointer DenseTransformPointer;
251 SVFTransformPointer svfPointer = nonLinearMatcher->GetOutputTransform();
// NOTE(review): dispTrsf is created empty here. The lines that fill it from svfPointer
// (presumably an SVF exponential, cf. the GetSVFExponential signature at the end of this
// extract) are not visible — confirm.
253 DenseTransformPointer dispTrsf = DenseTransformType::New();
256 transformSerie->AddTransform(dispTrsf.GetPointer());
// Rotate gradient direction i by the rigid transform's 3x3 matrix: tmpDir = M * d.
259 AffineTransformType::MatrixType rigidMatrix = rigidTrsf->GetMatrix();
260 vnl_vector_fixed <double,3> tmpDir(0.0);
261 for (
unsigned int j = 0;j < 3;++j)
263 for (
unsigned int k = 0;k < 3;++k)
264 tmpDir[j] += rigidMatrix(j,k) * directions[i][k];
267 directions[i] = tmpDir;
// Inverse of the rigid transform, appended last to the composite.
269 AffineTransformPointer rigidTrsfInverse = AffineTransformType::New();
270 rigidTrsf->GetInverse(rigidTrsfInverse);
271 transformSerie->AddTransform(rigidTrsfInverse.GetPointer());
// Resample volume i through the composite onto the B0 volume's grid (size, origin, spacing,
// direction all taken from the extracted reference).
// NOTE(review): ResampleFilterType is declared outside this extract.
274 ResampleFilterType::Pointer scalarResampler = ResampleFilterType::New();
276 InputSubImageType::SizeType size = referenceExtractFilter->GetOutput()->GetLargestPossibleRegion().GetSize();
277 InputSubImageType::PointType origin = referenceExtractFilter->GetOutput()->GetOrigin();
278 InputSubImageType::SpacingType spacing = referenceExtractFilter->GetOutput()->GetSpacing();
279 InputSubImageType::DirectionType direction = referenceExtractFilter->GetOutput()->GetDirection();
281 scalarResampler->SetTransform(transformSerie);
282 scalarResampler->SetSize(size);
283 scalarResampler->SetOutputOrigin(origin);
284 scalarResampler->SetOutputSpacing(spacing);
285 scalarResampler->SetOutputDirection(direction);
287 scalarResampler->SetInput(extractFilter->GetOutput());
288 if (numThreadsArg.getValue() != 0)
289 scalarResampler->SetNumberOfWorkUnits(numThreadsArg.getValue());
290 scalarResampler->Update();
// Copy the resampled volume back, in place, into slab i (extent 1 along the volume axis) of the
// 4D input image, which is what gets written as the output.
292 InputSubImageType::RegionType regionSubImage = scalarResampler->GetOutput()->GetLargestPossibleRegion();
293 InputImageType::RegionType regionImage = inputImage->GetLargestPossibleRegion();
294 regionImage.SetIndex(Dimension,i);
295 regionImage.SetSize(Dimension,1);
297 InputImageIteratorType outIterator(inputImage,regionImage);
298 InputSubImageIteratorType inIterator(scalarResampler->GetOutput(),regionSubImage);
// NOTE(review): the iterator increments (++inIterator / ++outIterator) are not visible in this
// extract. As shown, this loop would never advance — confirm they are present.
300 while (!inIterator.IsAtEnd())
302 outIterator.Set(inIterator.Get());
309 std::cout << std::endl;
311 anima::writeImage <InputImageType> (outArg.getValue(),inputImage);
314 std::ofstream outputFile(outBVecArg.getValue());
315 for (
unsigned int i = 0;i < 3;++i)
317 for (
unsigned int j = 0;j < directions.size();++j)
318 outputFile << directions[j][i] <<
" ";
320 outputFile << std::endl;
// NOTE(review): the three lines below are signature-only fragments with no bodies and no
// terminating ';', so they do not compile as written.
// - The first repeats main's signature.
// - The other two look like helpers defined elsewhere: a gradient-file-name setter and an
//   SVF-to-displacement exponential.
// Confirm where they belong and remove them if they are extraction residue.
int main(int argc, const char **argv)
void SetGradientFileName(std::string fName)
void GetSVFExponential(itk::StationaryVelocityFieldTransform< ScalarType, NDimensions > *baseTrsf, rpi::DisplacementFieldTransform< ScalarType, NDimensions > *resultTransform, unsigned int exponentiationOrder, unsigned int numThreads, bool invert)