From 97991aaa1d6a5810821353e3a4e16da9927a198b Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <michiel.cottaar@ndcn.ox.ac.uk>
Date: Thu, 1 Feb 2024 14:40:30 +0000
Subject: [PATCH] Move scanner_constraints! to scanners.jl

---
 src/MRIBuilder.jl      |  4 ++--
 src/building_blocks.jl | 45 --------------------------------------
 src/scanners.jl        | 49 ++++++++++++++++++++++++++++++++++++++++++
 src/variables.jl       |  1 -
 4 files changed, 51 insertions(+), 48 deletions(-)

diff --git a/src/MRIBuilder.jl b/src/MRIBuilder.jl
index 3e1f516..7baf904 100644
--- a/src/MRIBuilder.jl
+++ b/src/MRIBuilder.jl
@@ -3,10 +3,10 @@ Builds and optimises NMR/MRI sequences.
 """
 module MRIBuilder
 
-include("build_sequences.jl")
-include("scanners.jl")
 include("variables.jl")
 include("building_blocks.jl")
+include("scanners.jl")
+include("build_sequences.jl")
 include("wait.jl")
 include("gradients/gradients.jl")
 include("pulses/pulses.jl")
diff --git a/src/building_blocks.jl b/src/building_blocks.jl
index ab01541..28b2107 100644
--- a/src/building_blocks.jl
+++ b/src/building_blocks.jl
@@ -1,7 +1,6 @@
 module BuildingBlocks
 import JuMP: has_values, value, Model, @constraint, @objective, owner_model, objective_function, optimize!, AbstractJuMPScalar
 import Printf: @sprintf
-import ..Scanners: Scanner
 import ..Variables: variables, start_time, duration, end_time, gradient_strength, slew_rate, effective_time, VariableType, qval_square
 
 """
@@ -111,50 +110,6 @@ These all have in common that they have no free variables and explicitly set any
 function fixed end
 
 
-"""
-    scanner_constraints!([model, ]building_block, scanner)
-
-Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}.
-"""
-function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner)
-    scanner_constraints!(owner_model(building_block), building_block, scanner)
-end
-
-function scanner_constraints!(model::Model, building_block::BuildingBlock, scanner::Scanner)
-    for func in [gradient_strength, slew_rate]
-        if isfinite(func(scanner))
-            scanner_constraints!(model, building_block, scanner, func)
-        end
-    end
-end
-
-function scanner_constraints!(model::Model, building_block::BuildingBlock, scanner::Scanner, func::Function)
-    if func in variables(building_block)
-        # apply constraint at this level
-        res_bb = func(building_block)
-        if res_bb isa AbstractVector
-            if isnothing(building_block.rotate)
-                # no rotation; apply constraint to each dimension independently
-                for expr in res_bb
-                    @constraint model expr <= func(scanner)
-                    @constraint model expr >= -func(scanner)
-                end
-            else
-                # with rotation: apply constraint to total squared
-                total_squared = sum(map(n->n^2, res_bb))
-                @constraint model total_squared <= func(scanner)^2
-            end
-        else
-            @constraint model res_bb <= func(scanner)
-            @constraint model res_bb >= -func(scanner)
-        end
-    elseif building_block isa ContainerBlock
-        # apply constraints at lower level
-        for (_, child_block) in get_children_blocks(building_block)
-            scanner_constraints!(model, child_block, scanner, func)
-        end
-    end
-end
 
 """
     variables(building_block)
diff --git a/src/scanners.jl b/src/scanners.jl
index daf3982..e707a0d 100644
--- a/src/scanners.jl
+++ b/src/scanners.jl
@@ -2,6 +2,9 @@
 Define general [`Scanner`](@ref) type and methods as well as some concrete scanners.
 """
 module Scanners
+import JuMP: Model, @constraint, owner_model
+import ..Variables: gradient_strength, slew_rate
+import ..BuildingBlocks: BuildingBlock, get_children_blocks
 
 const gyromagnetic_ratio = 42576.38476  # (kHz/T)
 
@@ -83,4 +86,50 @@ predefined_scanners = Dict(
     :Siemens_Terra => Siemens_Terra,
     :Siemens_Connectom => Siemens_Connectom,
 )
+
+"""
+    scanner_constraints!([model, ]building_block, scanner)
+
+Adds any constraints from a specific scanner to a [`BuildingBlock`]{@ref}.
+"""
+function scanner_constraints!(building_block::BuildingBlock, scanner::Scanner)
+    scanner_constraints!(owner_model(building_block), building_block, scanner)
+end
+
+function scanner_constraints!(model::Model, building_block::BuildingBlock, scanner::Scanner)
+    for func in [gradient_strength, slew_rate]
+        if isfinite(func(scanner))
+            scanner_constraints!(model, building_block, scanner, func)
+        end
+    end
+end
+
+function scanner_constraints!(model::Model, building_block::BuildingBlock, scanner::Scanner, func::Function)
+    if func in variables(building_block)
+        # apply constraint at this level
+        res_bb = func(building_block)
+        if res_bb isa AbstractVector
+            if isnothing(building_block.rotate)
+                # no rotation; apply constraint to each dimension independently
+                for expr in res_bb
+                    @constraint model expr <= func(scanner)
+                    @constraint model expr >= -func(scanner)
+                end
+            else
+                # with rotation: apply constraint to total squared
+                total_squared = sum(map(n->n^2, res_bb))
+                @constraint model total_squared <= func(scanner)^2
+            end
+        else
+            @constraint model res_bb <= func(scanner)
+            @constraint model res_bb >= -func(scanner)
+        end
+    elseif building_block isa ContainerBlock
+        # apply constraints at lower level
+        for (_, child_block) in get_children_blocks(building_block)
+            scanner_constraints!(model, child_block, scanner, func)
+        end
+    end
+end
+
 end
\ No newline at end of file
diff --git a/src/variables.jl b/src/variables.jl
index 7539c5b..01afc14 100644
--- a/src/variables.jl
+++ b/src/variables.jl
@@ -1,6 +1,5 @@
 module Variables
 import JuMP: @variable, Model, @objective, objective_function, owner_model, has_values, value, AbstractJuMPScalar
-import ..Scanners: gradient_strength, slew_rate
 
 all_variables_symbols = [
     :block => [
-- 
GitLab