diff --git a/data_processing/.gitignore b/data_processing/.gitignore
new file mode 100644
index 000000000..2f7896d1d
--- /dev/null
+++ b/data_processing/.gitignore
@@ -0,0 +1 @@
+target/
diff --git a/data_processing/README.md b/data_processing/README.md
new file mode 100644
index 000000000..c17f8581f
--- /dev/null
+++ b/data_processing/README.md
@@ -0,0 +1,9 @@
+## Package
+```shell
+mvn clean scala:compile assembly:single
+```
+
+## Dependencies
+* Spark 3.0.1
+* Java 8
+* Scala 2.12
diff --git a/data_processing/pom.xml b/data_processing/pom.xml
new file mode 100644
index 000000000..4308b76ac
--- /dev/null
+++ b/data_processing/pom.xml
@@ -0,0 +1,115 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>com.bytedance.aml.enterprise</groupId>
+    <artifactId>sm4spark</artifactId>
+    <version>0.0.1-SNAPSHOT</version>
+
+    <properties>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+        <spark.version>3.0.3</spark.version>
+
+        <java.version>1.8</java.version>
+        <scala.binary.version>2.12</scala.binary.version>
+        <scala.version>2.12.10</scala.version>
+
+        <maven.compiler.source>1.8</maven.compiler.source>
+        <maven.compiler.target>1.8</maven.compiler.target>
+        <maven.compiler.release>8</maven.compiler.release>
+    </properties>
+
+    <dependencies>
+
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-library</artifactId>
+            <version>${scala.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>commons-codec</groupId>
+            <artifactId>commons-codec</artifactId>
+            <version>1.15</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.12</artifactId>
+            <version>${spark.version}</version>
+            <scope>provided</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>com.google.guava</groupId>
+                    <artifactId>guava</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <configuration>
+                    <source>${java.version}</source>
+                    <target>${java.version}</target>
+                </configuration>
+            </plugin>
+
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+                <version>4.3.0</version>
+                <executions>
+                    <execution>
+                        <id>scala-compile-first</id>
+                        <phase>process-resources</phase>
+                        <goals>
+                            <goal>add-source</goal>
+                            <goal>compile</goal>
+                        </goals>
+                    </execution>
+                    <execution>
+                        <id>scala-test-compile</id>
+                        <phase>process-test-resources</phase>
+                        <goals>
+                            <goal>testCompile</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <scalaVersion>${scala.version}</scalaVersion>
+                </configuration>
+            </plugin>
+
+            <plugin>
+                <artifactId>maven-assembly-plugin</artifactId>
+                <configuration>
+                    <finalName>${project.artifactId}-${project.version}-RELEASE</finalName>
+                    <archive>
+                        <manifest>
+                            <mainClass>fully.qualified.MainClass</mainClass>
+                        </manifest>
+                    </archive>
+                    <descriptorRefs>
+                        <descriptorRef>jar-with-dependencies</descriptorRef>
+                    </descriptorRefs>
+                </configuration>
+                <executions>
+                    <execution>
+                        <id>make-assembly</id>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>single</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>
\ No newline at end of file
diff --git a/data_processing/src/main/scala/com/bytedance/aml/enterprise/sparkudaf/Hist.scala b/data_processing/src/main/scala/com/bytedance/aml/enterprise/sparkudaf/Hist.scala
new file mode 100644
index 000000000..5f1cfb38a
--- /dev/null
+++ b/data_processing/src/main/scala/com/bytedance/aml/enterprise/sparkudaf/Hist.scala
@@ -0,0 +1,25 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.bytedance.aml.enterprise.sparkudaf
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.sql.functions
+
+object Hist{
+ def getFunc: UserDefinedFunction = functions.udaf(HistUDAF, ExpressionEncoder[HistIn])
+
+}
\ No newline at end of file
diff --git a/data_processing/src/main/scala/com/bytedance/aml/enterprise/sparkudaf/HistUDAF.scala b/data_processing/src/main/scala/com/bytedance/aml/enterprise/sparkudaf/HistUDAF.scala
new file mode 100644
index 000000000..44cabc257
--- /dev/null
+++ b/data_processing/src/main/scala/com/bytedance/aml/enterprise/sparkudaf/HistUDAF.scala
@@ -0,0 +1,65 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.bytedance.aml.enterprise.sparkudaf
+
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.{Encoder, Encoders}
+
+case class HistIn(var value: Double, var min: Double, var max: Double, var binsNum: Int, var interval: Double)
+case class Bucket(var bins: Array[Double], var counts: Array[Int])
+
+object HistUDAF extends Aggregator[HistIn, Bucket, Bucket]{
+
+ def zero: Bucket = Bucket(bins = new Array[Double](0), counts = new Array[Int](0))
+
+ def reduce(buffer: Bucket, data: HistIn): Bucket = {
+ if (buffer.bins.length == 0) {
+ buffer.bins = new Array[Double](data.binsNum + 1)
+ for (i <- 0 until data.binsNum) {
+ buffer.bins(i) = i * data.interval + data.min
+ }
+ buffer.bins(data.binsNum) = data.max
+ buffer.counts = new Array[Int](data.binsNum)
+ }
+ if (data.interval != 0.0){
+ var bucket_idx = ((data.value - data.min) / data.interval).toInt
+ if (bucket_idx < 0) {
+ bucket_idx = 0
+ } else if (bucket_idx > (data.binsNum - 1)){
+ bucket_idx = data.binsNum - 1
+ }
+ buffer.counts(bucket_idx) += 1
+ }
+ buffer
+ }
+
+
+ def merge(b1: Bucket, b2: Bucket): Bucket = {
+ (b1.bins.length, b2.bins.length) match {
+ case (_, 0) => b1
+ case (0, _) => b2
+ case _ => b1.counts = (b1.counts zip b2.counts) map (x => x._1 + x._2)
+ b1
+ }
+ }
+
+ def finish(reduction: Bucket): Bucket = reduction
+
+ def bufferEncoder: Encoder[Bucket] = Encoders.product
+
+ def outputEncoder: Encoder[Bucket] = Encoders.product
+
+}
\ No newline at end of file
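
Below is a minimal usage sketch (not part of this diff) showing how the UDAF returned by `Hist.getFunc` could be applied from a Spark job. The input DataFrame, column name, bounds, and bin count are illustrative assumptions.

```scala
// Hypothetical example: aggregating a numeric column into a fixed-bin histogram
// with the Hist UDAF defined above. Everything except Hist/HistIn/Bucket is an
// assumption made for illustration.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit}
import com.bytedance.aml.enterprise.sparkudaf.Hist

object HistExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("hist-example").getOrCreate()
    import spark.implicits._

    val df = Seq(0.5, 1.2, 2.2, 3.7, 4.9).toDF("value")

    // Histogram parameters: 5 equal-width bins over [0.0, 5.0).
    val min = 0.0
    val max = 5.0
    val binsNum = 5
    val interval = (max - min) / binsNum

    // The UDAF takes the five HistIn fields as input columns, in order:
    // value, min, max, binsNum, interval.
    val hist = Hist.getFunc
    val result = df.agg(
      hist(col("value"), lit(min), lit(max), lit(binsNum), lit(interval)).as("hist")
    )

    // One row with a struct column {bins: array<double>, counts: array<int>}.
    result.show(truncate = false)
  }
}
```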
diff --git a/docs/licenses/LICENCE-BurntSushi_toml.txt b/docs/licenses/LICENCE-BurntSushi_toml.txt
new file mode 100644
index 000000000..01b574320
--- /dev/null
+++ b/docs/licenses/LICENCE-BurntSushi_toml.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2013 TOML authors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-Go-Logrus.txt b/docs/licenses/LICENCE-Go-Logrus.txt
new file mode 100644
index 000000000..f090cb42f
--- /dev/null
+++ b/docs/licenses/LICENCE-Go-Logrus.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2014 Simon Eskildsen
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-Go-Testify.txt b/docs/licenses/LICENCE-Go-Testify.txt
new file mode 100644
index 000000000..4b0421cf9
--- /dev/null
+++ b/docs/licenses/LICENCE-Go-Testify.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-GoDoc-Text.txt b/docs/licenses/LICENCE-GoDoc-Text.txt
new file mode 100644
index 000000000..77113a54b
--- /dev/null
+++ b/docs/licenses/LICENCE-GoDoc-Text.txt
@@ -0,0 +1,31 @@
+Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
+Upstream-Name: github.com/kr/text
+Source: https://github.com/kr/text/
+
+Files: *
+Copyright: 2013 Keith Rarick
+License: Expat
+
+Files: debian/*
+Copyright: 2013 Tonnerre Lombard
+License: Expat
+
+License: Expat
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+ .
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+ .
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE
diff --git a/docs/licenses/LICENCE-Microsoft-go-winio.txt b/docs/licenses/LICENCE-Microsoft-go-winio.txt
new file mode 100644
index 000000000..fa365be22
--- /dev/null
+++ b/docs/licenses/LICENCE-Microsoft-go-winio.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2015 Microsoft
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE
diff --git a/docs/licenses/LICENCE-Python_six.txt b/docs/licenses/LICENCE-Python_six.txt
new file mode 100644
index 000000000..01de9e22d
--- /dev/null
+++ b/docs/licenses/LICENCE-Python_six.txt
@@ -0,0 +1,18 @@
+Copyright (c) 2010-2018 Benjamin Peterson
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE
diff --git a/docs/licenses/LICENCE-armon_go-socks5.txt b/docs/licenses/LICENCE-armon_go-socks5.txt
new file mode 100644
index 000000000..94fadc2a9
--- /dev/null
+++ b/docs/licenses/LICENCE-armon_go-socks5.txt
@@ -0,0 +1,20 @@
+The MIT License (MIT)
+
+Copyright (c) 2014 Armon Dadgar
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE
diff --git a/docs/licenses/LICENCE-benbjohnson-clock.txt b/docs/licenses/LICENCE-benbjohnson-clock.txt
new file mode 100644
index 000000000..0dfeb1d6a
--- /dev/null
+++ b/docs/licenses/LICENCE-benbjohnson-clock.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2014 Ben Johnson
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE
diff --git a/docs/licenses/LICENCE-beorn7-perks.txt b/docs/licenses/LICENCE-beorn7-perks.txt
new file mode 100644
index 000000000..9316a10d2
--- /dev/null
+++ b/docs/licenses/LICENCE-beorn7-perks.txt
@@ -0,0 +1,20 @@
+Copyright (C) 2013 Blake Mizerany
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE
diff --git a/docs/licenses/LICENCE-cespare_xxhash.txt b/docs/licenses/LICENCE-cespare_xxhash.txt
new file mode 100644
index 000000000..341bd91f0
--- /dev/null
+++ b/docs/licenses/LICENCE-cespare_xxhash.txt
@@ -0,0 +1,22 @@
+Copyright (c) 2016 Caleb Spare
+
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE
diff --git a/docs/licenses/LICENCE-charset-normalizer.txt b/docs/licenses/LICENCE-charset-normalizer.txt
new file mode 100644
index 000000000..a86dd9559
--- /dev/null
+++ b/docs/licenses/LICENCE-charset-normalizer.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 TAHRI Ahmed R.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-cpuguy83-go-md2man.txt b/docs/licenses/LICENCE-cpuguy83-go-md2man.txt
new file mode 100644
index 000000000..1cade6cef
--- /dev/null
+++ b/docs/licenses/LICENCE-cpuguy83-go-md2man.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2014 Brian Goff
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-create-react-app.txt b/docs/licenses/LICENCE-create-react-app.txt
new file mode 100644
index 000000000..a73b785a6
--- /dev/null
+++ b/docs/licenses/LICENCE-create-react-app.txt
@@ -0,0 +1,26 @@
+Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name Facebook nor the names of its contributors may be used to
+ endorse or promote products derived from this software without specific
+ prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-dsnet_compress.txt b/docs/licenses/LICENCE-dsnet_compress.txt
new file mode 100644
index 000000000..945b396cf
--- /dev/null
+++ b/docs/licenses/LICENCE-dsnet_compress.txt
@@ -0,0 +1,24 @@
+Copyright © 2015, Joe Tsai and The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+* Neither the copyright holder nor the names of its contributors may be used to
+endorse or promote products derived from this software without specific prior
+written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY
+DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-evanphx_json-patch.txt b/docs/licenses/LICENCE-evanphx_json-patch.txt
new file mode 100644
index 000000000..050fe60f0
--- /dev/null
+++ b/docs/licenses/LICENCE-evanphx_json-patch.txt
@@ -0,0 +1,25 @@
+Copyright (c) 2014, Evan Phoenix
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+* Neither the name of the Evan Phoenix nor the names of its contributors
+ may be used to endorse or promote products derived from this software
+ without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-frankban_quicktest.txt b/docs/licenses/LICENCE-frankban_quicktest.txt
new file mode 100644
index 000000000..23a294c75
--- /dev/null
+++ b/docs/licenses/LICENCE-frankban_quicktest.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2017 Canonical Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-fsnotify.txt b/docs/licenses/LICENCE-fsnotify.txt
new file mode 100644
index 000000000..fb03ade75
--- /dev/null
+++ b/docs/licenses/LICENCE-fsnotify.txt
@@ -0,0 +1,25 @@
+Copyright © 2012 The Go Authors. All rights reserved.
+Copyright © fsnotify Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice, this
+ list of conditions and the following disclaimer in the documentation and/or
+ other materials provided with the distribution.
+* Neither the name of Google Inc. nor the names of its contributors may be used
+ to endorse or promote products derived from this software without specific
+ prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-go-ansiterm.txt b/docs/licenses/LICENCE-go-ansiterm.txt
new file mode 100644
index 000000000..b86c36e25
--- /dev/null
+++ b/docs/licenses/LICENCE-go-ansiterm.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2015 Microsoft Corporation
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE
diff --git a/docs/licenses/LICENCE-go-check-check.txt b/docs/licenses/LICENCE-go-check-check.txt
new file mode 100644
index 000000000..9ac6ae0a6
--- /dev/null
+++ b/docs/licenses/LICENCE-go-check-check.txt
@@ -0,0 +1,23 @@
+BSD Two Clause License
+======================
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE AUTHOR "AS IS" AND ANY EXPRESS OR IMPLIED
+WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
+SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
+OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
diff --git a/docs/licenses/LICENCE-go-inf-inf.txt b/docs/licenses/LICENCE-go-inf-inf.txt
new file mode 100644
index 000000000..e923f606e
--- /dev/null
+++ b/docs/licenses/LICENCE-go-inf-inf.txt
@@ -0,0 +1,28 @@
+Copyright (c) 2012 Péter Surányi. Portions Copyright (c) 2009 The Go
+Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-go-restful.txt b/docs/licenses/LICENCE-go-restful.txt
new file mode 100644
index 000000000..812a5c834
--- /dev/null
+++ b/docs/licenses/LICENCE-go-restful.txt
@@ -0,0 +1,22 @@
+Copyright (c) 2012,2013 Ernest Micklei
+
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-go-spew.txt b/docs/licenses/LICENCE-go-spew.txt
new file mode 100644
index 000000000..223583735
--- /dev/null
+++ b/docs/licenses/LICENCE-go-spew.txt
@@ -0,0 +1,15 @@
+ISC License
+
+Copyright (c) 2012-2016 Dave Collins
+
+Permission to use, copy, modify, and/or distribute this software for any
+purpose with or without fee is hereby granted, provided that the above
+copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE
diff --git a/docs/licenses/LICENCE-go-tomb-tomb.txt b/docs/licenses/LICENCE-go-tomb-tomb.txt
new file mode 100644
index 000000000..db0834849
--- /dev/null
+++ b/docs/licenses/LICENCE-go-tomb-tomb.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2010-2011 - Gustavo Niemeyer
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-go-zap.txt b/docs/licenses/LICENCE-go-zap.txt
new file mode 100644
index 000000000..82a1dd0dc
--- /dev/null
+++ b/docs/licenses/LICENCE-go-zap.txt
@@ -0,0 +1,19 @@
+Copyright (c) 2016-2017 Uber Technologies, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE
diff --git a/docs/licenses/LICENCE-go_uber_org_goleak.txt b/docs/licenses/LICENCE-go_uber_org_goleak.txt
new file mode 100644
index 000000000..a0e4cc690
--- /dev/null
+++ b/docs/licenses/LICENCE-go_uber_org_goleak.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2018 Uber Technologies, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE
diff --git a/docs/licenses/LICENCE-go_uber_org_multierr.txt b/docs/licenses/LICENCE-go_uber_org_multierr.txt
new file mode 100644
index 000000000..fe9e5258b
--- /dev/null
+++ b/docs/licenses/LICENCE-go_uber_org_multierr.txt
@@ -0,0 +1,19 @@
+Copyright (c) 2017 Uber Technologies, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE
diff --git a/docs/licenses/LICENCE-gogo-protobuf.txt b/docs/licenses/LICENCE-gogo-protobuf.txt
new file mode 100644
index 000000000..748f3b3ee
--- /dev/null
+++ b/docs/licenses/LICENCE-gogo-protobuf.txt
@@ -0,0 +1,28 @@
+Copyright 2010 The Go Authors. All rights reserved.
+https://github.com/golang/protobuf
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-golang-github-spf13-pflag-dev.txt b/docs/licenses/LICENCE-golang-github-spf13-pflag-dev.txt
new file mode 100644
index 000000000..e6a8ddc0d
--- /dev/null
+++ b/docs/licenses/LICENCE-golang-github-spf13-pflag-dev.txt
@@ -0,0 +1,28 @@
+Copyright (c) 2012 Alex Ogier. All rights reserved.
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-golang-jwt_jwt.txt b/docs/licenses/LICENCE-golang-jwt_jwt.txt
new file mode 100644
index 000000000..95135bb75
--- /dev/null
+++ b/docs/licenses/LICENCE-golang-jwt_jwt.txt
@@ -0,0 +1,8 @@
+Copyright (c) 2012 Dave Grijalva
+Copyright (c) 2021 golang-jwt maintainers
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE
diff --git a/docs/licenses/LICENCE-golang-protobuf.txt b/docs/licenses/LICENCE-golang-protobuf.txt
new file mode 100644
index 000000000..ed122f2d6
--- /dev/null
+++ b/docs/licenses/LICENCE-golang-protobuf.txt
@@ -0,0 +1,27 @@
+Copyright 2010 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-golang-snappy-go-dev.txt b/docs/licenses/LICENCE-golang-snappy-go-dev.txt
new file mode 100644
index 000000000..cf9059d9d
--- /dev/null
+++ b/docs/licenses/LICENCE-golang-snappy-go-dev.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2011 The Snappy-Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-golang_org_x_net.txt b/docs/licenses/LICENCE-golang_org_x_net.txt
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/docs/licenses/LICENCE-golang_org_x_net.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-golang_org_x_oauth2.txt b/docs/licenses/LICENCE-golang_org_x_oauth2.txt
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/docs/licenses/LICENCE-golang_org_x_oauth2.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-golang_org_x_sync.txt b/docs/licenses/LICENCE-golang_org_x_sync.txt
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/docs/licenses/LICENCE-golang_org_x_sync.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-golang_org_x_term.txt b/docs/licenses/LICENCE-golang_org_x_term.txt
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/docs/licenses/LICENCE-golang_org_x_term.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-golang_org_x_time.txt b/docs/licenses/LICENCE-golang_org_x_time.txt
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/docs/licenses/LICENCE-golang_org_x_time.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-golang_text.txt b/docs/licenses/LICENCE-golang_text.txt
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/docs/licenses/LICENCE-golang_text.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-gomega.txt b/docs/licenses/LICENCE-gomega.txt
new file mode 100644
index 000000000..9415ee72c
--- /dev/null
+++ b/docs/licenses/LICENCE-gomega.txt
@@ -0,0 +1,20 @@
+Copyright (c) 2013-2014 Onsi Fakhouri
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-google_go-cmp.txt b/docs/licenses/LICENCE-google_go-cmp.txt
new file mode 100644
index 000000000..32017f8fa
--- /dev/null
+++ b/docs/licenses/LICENCE-google_go-cmp.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2017 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-google_golang_org_protobuf.txt b/docs/licenses/LICENCE-google_golang_org_protobuf.txt
new file mode 100644
index 000000000..0f646931a
--- /dev/null
+++ b/docs/licenses/LICENCE-google_golang_org_protobuf.txt
@@ -0,0 +1,28 @@
+Copyright 2010 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
diff --git a/docs/licenses/LICENCE-google_uuid.txt b/docs/licenses/LICENCE-google_uuid.txt
new file mode 100644
index 000000000..3726ed0a0
--- /dev/null
+++ b/docs/licenses/LICENCE-google_uuid.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009,2014 Google Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-goproxy.txt b/docs/licenses/LICENCE-goproxy.txt
new file mode 100644
index 000000000..2067e567c
--- /dev/null
+++ b/docs/licenses/LICENCE-goproxy.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2012 Elazar Leibovich. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Elazar Leibovich. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-gorilla_mux.txt b/docs/licenses/LICENCE-gorilla_mux.txt
new file mode 100644
index 000000000..5da121e53
--- /dev/null
+++ b/docs/licenses/LICENCE-gorilla_mux.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-idna.txt b/docs/licenses/LICENCE-idna.txt
new file mode 100644
index 000000000..cc7d6baac
--- /dev/null
+++ b/docs/licenses/LICENCE-idna.txt
@@ -0,0 +1,31 @@
+BSD 3-Clause License
+
+Copyright (c) 2013-2022, Kim Davies and contributors.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-josharian_intern.txt b/docs/licenses/LICENCE-josharian_intern.txt
new file mode 100644
index 000000000..0096c79c6
--- /dev/null
+++ b/docs/licenses/LICENCE-josharian_intern.txt
@@ -0,0 +1,23 @@
+2020 Roger Shimizu
+License: Expat
+Comment: Debian packaging is licensed under the same terms as upstream
+
+License: Expat
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+ .
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+ .
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
diff --git a/docs/licenses/LICENCE-jsoniter-go.txt b/docs/licenses/LICENCE-jsoniter-go.txt
new file mode 100644
index 000000000..f6dfb8773
--- /dev/null
+++ b/docs/licenses/LICENCE-jsoniter-go.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2016 json-iterator
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-kr-fs.txt b/docs/licenses/LICENCE-kr-fs.txt
new file mode 100644
index 000000000..76427ff52
--- /dev/null
+++ b/docs/licenses/LICENCE-kr-fs.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-kr_pretty.txt b/docs/licenses/LICENCE-kr_pretty.txt
new file mode 100644
index 000000000..480a32805
--- /dev/null
+++ b/docs/licenses/LICENCE-kr_pretty.txt
@@ -0,0 +1,19 @@
+Copyright 2012 Keith Rarick
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-mailru_easyjson.txt b/docs/licenses/LICENCE-mailru_easyjson.txt
new file mode 100644
index 000000000..620fb1f5b
--- /dev/null
+++ b/docs/licenses/LICENCE-mailru_easyjson.txt
@@ -0,0 +1,7 @@
+Copyright (c) 2016 Mail.Ru Group
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-melbahja_goph.txt b/docs/licenses/LICENCE-melbahja_goph.txt
new file mode 100644
index 000000000..42d540c38
--- /dev/null
+++ b/docs/licenses/LICENCE-melbahja_goph.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020-present Mohamed El Bahja
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-mergo.txt b/docs/licenses/LICENCE-mergo.txt
new file mode 100644
index 000000000..068cab72d
--- /dev/null
+++ b/docs/licenses/LICENCE-mergo.txt
@@ -0,0 +1,28 @@
+Copyright (c) 2013 Dario Castañé. All rights reserved.
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-mholt_archiver.txt b/docs/licenses/LICENCE-mholt_archiver.txt
new file mode 100644
index 000000000..54bc89fa0
--- /dev/null
+++ b/docs/licenses/LICENCE-mholt_archiver.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2016 Matthew Holt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-morikuni_aec.txt b/docs/licenses/LICENCE-morikuni_aec.txt
new file mode 100644
index 000000000..7504d0682
--- /dev/null
+++ b/docs/licenses/LICENCE-morikuni_aec.txt
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2016 Taihei Morikuni
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-munnerz_goautoneg.txt b/docs/licenses/LICENCE-munnerz_goautoneg.txt
new file mode 100644
index 000000000..bbc7b897c
--- /dev/null
+++ b/docs/licenses/LICENCE-munnerz_goautoneg.txt
@@ -0,0 +1,31 @@
+Copyright (c) 2011, Open Knowledge Foundation Ltd.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+ Neither the name of the Open Knowledge Foundation Ltd. nor the
+ names of its contributors may be used to endorse or promote
+ products derived from this software without specific prior written
+ permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-nwaples_rardecode.txt b/docs/licenses/LICENCE-nwaples_rardecode.txt
new file mode 100644
index 000000000..160337a36
--- /dev/null
+++ b/docs/licenses/LICENCE-nwaples_rardecode.txt
@@ -0,0 +1,23 @@
+Copyright (c) 2015, Nicholas Waples
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-nxadm_tail.txt b/docs/licenses/LICENCE-nxadm_tail.txt
new file mode 100644
index 000000000..595de48cd
--- /dev/null
+++ b/docs/licenses/LICENCE-nxadm_tail.txt
@@ -0,0 +1,21 @@
+# The MIT License (MIT)
+
+# © Copyright 2015 Hewlett Packard Enterprise Development LP
+Copyright (c) 2014 ActiveState
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-onsi_ginkgo.txt b/docs/licenses/LICENCE-onsi_ginkgo.txt
new file mode 100644
index 000000000..9415ee72c
--- /dev/null
+++ b/docs/licenses/LICENCE-onsi_ginkgo.txt
@@ -0,0 +1,20 @@
+Copyright (c) 2013-2014 Onsi Fakhouri
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-pierrec-lz4.txt b/docs/licenses/LICENCE-pierrec-lz4.txt
new file mode 100644
index 000000000..bb8c35c0b
--- /dev/null
+++ b/docs/licenses/LICENCE-pierrec-lz4.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2015, Pierre Curto
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of xxHash nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-pkg_errors.txt b/docs/licenses/LICENCE-pkg_errors.txt
new file mode 100644
index 000000000..141995377
--- /dev/null
+++ b/docs/licenses/LICENCE-pkg_errors.txt
@@ -0,0 +1,23 @@
+Copyright (c) 2015, Dave Cheney
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-pmezard-go-difflib.txt b/docs/licenses/LICENCE-pmezard-go-difflib.txt
new file mode 100644
index 000000000..a635f8b06
--- /dev/null
+++ b/docs/licenses/LICENCE-pmezard-go-difflib.txt
@@ -0,0 +1,35 @@
+Copyright: 2013 Patrick Mézard
+License: BSD-3-clause
+
+Files: debian/*
+Copyright: 2016 Dmitry Smirnov
+License: BSD-3-clause
+
+License: BSD-3-clause
+
+Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+ .
+ Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ .
+ Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ .
+ The names of its contributors may not be used to endorse or promote
+ products derived from this software without specific prior written
+ permission.
+ .
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
+ IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
+ TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+ PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+ TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-purell.txt b/docs/licenses/LICENCE-purell.txt
new file mode 100644
index 000000000..8cf42fe5b
--- /dev/null
+++ b/docs/licenses/LICENCE-purell.txt
@@ -0,0 +1,12 @@
+Copyright (c) 2012, Martin Angers
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+* Neither the name of the author nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-pypi_setuptools.txt b/docs/licenses/LICENCE-pypi_setuptools.txt
new file mode 100644
index 000000000..323d2c18e
--- /dev/null
+++ b/docs/licenses/LICENCE-pypi_setuptools.txt
@@ -0,0 +1,19 @@
+Copyright (C) 2016 Jason R Coombs
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-python-certifi.txt b/docs/licenses/LICENCE-python-certifi.txt
new file mode 100644
index 000000000..383c7a63d
--- /dev/null
+++ b/docs/licenses/LICENCE-python-certifi.txt
@@ -0,0 +1,409 @@
+Mozilla Public License
+Version 2.0
+======================
+
+
+1. Definitions
+--------------
+
+ 1.1. "Contributor"
+
+ means each individual or legal entity that creates, contributes to the creation
+ of, or owns Covered Software.
+
+ 1.2. "Contributor Version"
+
+ means the combination of the Contributions of others (if any) used by a
+ Contributor and that particular Contributor's Contribution.
+
+ 1.3. "Contribution"
+
+ means Covered Software of a particular Contributor.
+
+ 1.4. "Covered Software"
+
+ means Source Code Form to which the initial Contributor has attached the notice
+ in Exhibit A, the Executable Form of such Source Code Form, and Modifications
+ of such Source Code Form, in each case including portions thereof.
+
+ 1.5. "Incompatible With Secondary Licenses"
+
+ means
+
+ a.
+
+ that the initial Contributor has attached the notice described in Exhibit B
+ to the Covered Software; or
+
+ b.
+
+ that the Covered Software was made available under the terms of version 1.1
+ or earlier of the License, but not also under the terms of a Secondary
+ License.
+
+ 1.6. "Executable Form"
+
+ means any form of the work other than Source Code Form.
+
+ 1.7. "Larger Work"
+
+ means a work that combines Covered Software with other material, in a separate
+ file or files, that is not Covered Software.
+
+ 1.8. "License"
+
+ means this document.
+
+ 1.9. "Licensable"
+
+ means having the right to grant, to the maximum extent possible, whether at the
+ time of the initial grant or subsequently, any and all of the rights conveyed
+ by this License.
+
+ 1.10. "Modifications"
+
+ means any of the following:
+
+ a.
+
+ any file in Source Code Form that results from an addition to, deletion
+ from, or modification of the contents of Covered Software; or
+
+ b.
+
+ any new file in Source Code Form that contains any Covered Software.
+
+ 1.11. "Patent Claims" of a Contributor
+
+ means any patent claim(s), including without limitation, method, process, and
+ apparatus claims, in any patent Licensable by such Contributor that would be
+ infringed, but for the grant of the License, by the making, using, selling,
+ offering for sale, having made, import, or transfer of either its Contributions
+ or its Contributor Version.
+
+ 1.12. "Secondary License"
+
+ means either the GNU General Public License, Version 2.0, the GNU Lesser
+ General Public License, Version 2.1, the GNU Affero General Public License,
+ Version 3.0, or any later versions of those licenses.
+
+ 1.13. "Source Code Form"
+
+ means the form of the work preferred for making modifications.
+
+ 1.14. "You" (or "Your")
+
+ means an individual or a legal entity exercising rights under this License. For
+ legal entities, "You" includes any entity that controls, is controlled by, or
+ is under common control with You. For purposes of this definition, "control"
+ means (a) the power, direct or indirect, to cause the direction or management
+ of such entity, whether by contract or otherwise, or (b) ownership of more than
+ fifty percent (50%) of the outstanding shares or beneficial ownership of such
+ entity.
+
+
+2. License Grants and Conditions
+--------------------------------
+
+
+ 2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive
+ license:
+
+ a.
+
+ under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available, modify,
+ display, perform, distribute, and otherwise exploit its Contributions,
+ either on an unmodified basis, with Modifications, or as part of a Larger
+ Work; and
+
+ b.
+
+ under Patent Claims of such Contributor to make, use, sell, offer for sale,
+ have made, import, and otherwise transfer either its Contributions or its
+ Contributor Version.
+
+
+ 2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution become
+ effective for each Contribution on the date the Contributor first distributes
+ such Contribution.
+
+
+ 2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under this
+ License. No additional rights or licenses will be implied from the distribution
+ or licensing of Covered Software under this License. Notwithstanding
+ Section 2.1(b) above, no patent license is granted by a Contributor:
+
+ a.
+
+ for any code that a Contributor has removed from Covered Software; or
+
+ b.
+
+ for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+ c.
+
+ under Patent Claims infringed by Covered Software in the absence of its
+ Contributions.
+
+ This License does not grant any rights in the trademarks, service marks, or
+ logos of any Contributor (except as may be necessary to comply with the notice
+ requirements in Section 3.4).
+
+
+ 2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to distribute
+ the Covered Software under a subsequent version of this License (see
+ Section 10.2) or under the terms of a Secondary License (if permitted under the
+ terms of Section 3.3).
+
+
+ 2.5. Representation
+
+ Each Contributor represents that the Contributor believes its Contributions are
+ its original creation(s) or it has sufficient rights to grant the rights to its
+ Contributions conveyed by this License.
+
+
+ 2.6. Fair Use
+
+ This License is not intended to limit any rights You have under applicable
+ copyright doctrines of fair use, fair dealing, or other equivalents.
+
+
+ 2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
+ Section 2.1.
+
+
+3. Responsibilities
+-------------------
+
+
+ 3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under the
+ terms of this License. You must inform recipients that the Source Code Form of
+ the Covered Software is governed by the terms of this License, and how they can
+ obtain a copy of this License. You may not attempt to alter or restrict the
+ recipients' rights in the Source Code Form.
+
+
+ 3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ a.
+
+ such Covered Software must also be made available in Source Code Form, as
+ described in Section 3.1, and You must inform recipients of the Executable
+ Form how they can obtain a copy of such Source Code Form by reasonable
+ means in a timely manner, at a charge no more than the cost of distribution
+ to the recipient; and
+
+ b.
+
+ You may distribute such Executable Form under the terms of this License, or
+ sublicense it under different terms, provided that the license for the
+ Executable Form does not attempt to limit or alter the recipients' rights
+ in the Source Code Form under this License.
+
+
+ 3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for the
+ Covered Software. If the Larger Work is a combination of Covered Software with
+ a work governed by one or more Secondary Licenses, and the Covered Software is
+ not Incompatible With Secondary Licenses, this License permits You to
+ additionally distribute such Covered Software under the terms of such Secondary
+ License(s), so that the recipient of the Larger Work may, at their option,
+ further distribute the Covered Software under the terms of either this License
+ or such Secondary License(s).
+
+
+ 3.4. Notices
+
+ You may not remove or alter the substance of any license notices (including
+ copyright notices, patent notices, disclaimers of warranty, or limitations of
+ liability) contained within the Source Code Form of the Covered Software,
+ except that You may alter any license notices to the extent required to remedy
+ known factual inaccuracies.
+
+
+ 3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support, indemnity
+ or liability obligations to one or more recipients of Covered Software.
+ However, You may do so only on Your own behalf, and not on behalf of any
+ Contributor. You must make it absolutely clear that any such warranty, support,
+ indemnity, or liability obligation is offered by You alone, and You hereby
+ agree to indemnify every Contributor for any liability incurred by such
+ Contributor as a result of warranty, support, indemnity or liability terms You
+ offer. You may include additional disclaimers of warranty and limitations of
+ liability specific to any jurisdiction.
+
+
+4. Inability to Comply Due to Statute or Regulation
+---------------------------------------------------
+
+If it is impossible for You to comply with any of the terms of this License with
+respect to some or all of the Covered Software due to statute, judicial order, or
+regulation then You must: (a) comply with the terms of this License to the
+maximum extent possible; and (b) describe the limitations and the code they
+affect. Such description must be placed in a text file included with all
+distributions of the Covered Software under this License. Except to the extent
+prohibited by statute or regulation, such description must be sufficiently
+detailed for a recipient of ordinary skill to be able to understand it.
+
+
+5. Termination
+--------------
+
+ 5.1. The rights granted under this License will terminate automatically if You
+ fail to comply with any of its terms. However, if You become compliant, then
+ the rights granted under this License from a particular Contributor are
+ reinstated (a) provisionally, unless and until such Contributor explicitly and
+ finally terminates Your grants, and (b) on an ongoing basis, if such
+ Contributor fails to notify You of the non-compliance by some reasonable means
+ prior to 60 days after You have come back into compliance. Moreover, Your
+ grants from a particular Contributor are reinstated on an ongoing basis if such
+ Contributor notifies You of the non-compliance by some reasonable means, this
+ is the first time You have received notice of non-compliance with this License
+ from such Contributor, and You become compliant prior to 30 days after Your
+ receipt of the notice.
+
+ 5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions, counter-claims, and
+ cross-claims) alleging that a Contributor Version directly or indirectly
+ infringes any patent, then the rights granted to You by any and all
+ Contributors for the Covered Software under Section 2.1 of this License shall
+ terminate.
+
+ 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
+ license agreements (excluding distributors and resellers) which have been
+ validly granted by You or Your distributors under this License prior to
+ termination shall survive termination.
+
+
+6. Disclaimer of Warranty
+-------------------------
+
+Covered Software is provided under this License on an "as is" basis, without
+warranty of any kind, either expressed, implied, or statutory, including, without
+limitation, warranties that the Covered Software is free of defects,
+merchantable, fit for a particular purpose or non-infringing. The entire risk as
+to the quality and performance of the Covered Software is with You. Should any
+Covered Software prove defective in any respect, You (not any Contributor) assume
+the cost of any necessary servicing, repair, or correction. This disclaimer of
+warranty constitutes an essential part of this License. No use of any Covered
+Software is authorized under this License except under this disclaimer.
+
+
+7. Limitation of Liability
+--------------------------
+
+Under no circumstances and under no legal theory, whether tort (including
+negligence), contract, or otherwise, shall any Contributor, or anyone who
+distributes Covered Software as permitted above, be liable to You for any direct,
+indirect, special, incidental, or consequential damages of any character
+including, without limitation, damages for lost profits, loss of goodwill, work
+stoppage, computer failure or malfunction, or any and all other commercial
+damages or losses, even if such party shall have been informed of the possibility
+of such damages. This limitation of liability shall not apply to liability for
+death or personal injury resulting from such party's negligence to the extent
+applicable law prohibits such limitation. Some jurisdictions do not allow the
+exclusion or limitation of incidental or consequential damages, so this exclusion
+and limitation may not apply to You.
+
+
+8. Litigation
+-------------
+
+Any litigation relating to this License may be brought only in the courts of a
+jurisdiction where the defendant maintains its principal place of business and
+such litigation shall be governed by laws of that jurisdiction, without reference
+to its conflict-of-law provisions. Nothing in this Section shall prevent a
+party's ability to bring cross-claims or counter-claims.
+
+
+9. Miscellaneous
+----------------
+
+This License represents the complete agreement concerning the subject matter
+hereof. If any provision of this License is held to be unenforceable, such
+provision shall be reformed only to the extent necessary to make it enforceable.
+Any law or regulation which provides that the language of a contract shall be
+construed against the drafter shall not be used to construe this License against
+a Contributor.
+
+
+10. Versions of the License
+---------------------------
+
+
+ 10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section 10.3,
+ no one other than the license steward has the right to modify or publish new
+ versions of this License. Each version will be given a distinguishing version
+ number.
+
+
+ 10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version of the
+ License under which You originally received the Covered Software, or under the
+ terms of any subsequent version published by the license steward.
+
+
+ 10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to create a
+ new license for such software, you may create and use a modified version of
+ this License if you rename the license and remove any references to the name of
+ the license steward (except to note that such modified license differs from
+ this License).
+
+
+ 10.4. Distributing Source Code Form that is Incompatible With Secondary
+ Licenses
+
+ If You choose to distribute Source Code Form that is Incompatible With
+ Secondary Licenses under the terms of this version of the License, the notice
+ described in Exhibit B of this License must be attached.
+
+
+Exhibit A - Source Code Form License Notice
+-------------------------------------------
+
+ This Source Code Form is subject to the terms of the Mozilla Public License,
+ v. 2.0. If a copy of the MPL was not distributed with this file, You can
+ obtain one at http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular file, then
+You may include the notice in a location (such as a LICENSE file in a relevant
+directory) where a recipient would be likely to look for such a notice.
+
+You may add additional accurate notices of copyright ownership.
+
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+---------------------------------------------------------
+
+ This Source Code Form is "Incompatible With Secondary Licenses", as defined
+ by the Mozilla Public License, v. 2.0.
+
diff --git a/docs/licenses/LICENCE-rogpeppe_go-internal.txt b/docs/licenses/LICENCE-rogpeppe_go-internal.txt
new file mode 100644
index 000000000..49ea0f928
--- /dev/null
+++ b/docs/licenses/LICENCE-rogpeppe_go-internal.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2018 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-sftp.txt b/docs/licenses/LICENCE-sftp.txt
new file mode 100644
index 000000000..9ac6ae0a6
--- /dev/null
+++ b/docs/licenses/LICENCE-sftp.txt
@@ -0,0 +1,23 @@
+BSD Two Clause License
+======================
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE AUTHOR "AS IS" AND ANY EXPRESS OR IMPLIED
+WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
+SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
+OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
diff --git a/docs/licenses/LICENCE-sigs_k8s_io_json.txt b/docs/licenses/LICENCE-sigs_k8s_io_json.txt
new file mode 100644
index 000000000..e5adf7f0c
--- /dev/null
+++ b/docs/licenses/LICENCE-sigs_k8s_io_json.txt
@@ -0,0 +1,238 @@
+Files other than internal/golang/* licensed under:
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+------------------
+
+internal/golang/* files licensed under:
+
+
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/docs/licenses/LICENCE-uber-go_atomic.txt b/docs/licenses/LICENCE-uber-go_atomic.txt
new file mode 100644
index 000000000..12cd09580
--- /dev/null
+++ b/docs/licenses/LICENCE-uber-go_atomic.txt
@@ -0,0 +1,19 @@
+Copyright (c) 2016 Uber Technologies, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE
diff --git a/docs/licenses/LICENCE-ulikunitz_xz.txt b/docs/licenses/LICENCE-ulikunitz_xz.txt
new file mode 100644
index 000000000..d358ed04d
--- /dev/null
+++ b/docs/licenses/LICENCE-ulikunitz_xz.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-urfAVE_cli.txt b/docs/licenses/LICENCE-urfAVE_cli.txt
new file mode 100644
index 000000000..2c84c78a1
--- /dev/null
+++ b/docs/licenses/LICENCE-urfAVE_cli.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 urfave/cli maintainers
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-urlesc.txt b/docs/licenses/LICENCE-urlesc.txt
new file mode 100644
index 000000000..76427ff52
--- /dev/null
+++ b/docs/licenses/LICENCE-urlesc.txt
@@ -0,0 +1,27 @@
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
diff --git a/docs/licenses/LICENCE-urllib3.txt b/docs/licenses/LICENCE-urllib3.txt
new file mode 100644
index 000000000..429a1767e
--- /dev/null
+++ b/docs/licenses/LICENCE-urllib3.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2008-2020 Andrey Petrov and contributors (see CONTRIBUTORS.txt)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/docs/licenses/LICENCE-xrash_smetrics.txt b/docs/licenses/LICENCE-xrash_smetrics.txt
new file mode 100644
index 000000000..80445682f
--- /dev/null
+++ b/docs/licenses/LICENCE-xrash_smetrics.txt
@@ -0,0 +1,21 @@
+Copyright (C) 2016 Felipe da Cunha Gonçalves
+All Rights Reserved.
+
+MIT LICENSE
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/docs/licenses/LICENCE-yaml-for-Go.txt b/docs/licenses/LICENCE-yaml-for-Go.txt
new file mode 100644
index 000000000..2683e4bb1
--- /dev/null
+++ b/docs/licenses/LICENCE-yaml-for-Go.txt
@@ -0,0 +1,50 @@
+
+This project is covered by two different licenses: MIT and Apache.
+
+#### MIT License ####
+
+The following files were ported to Go from C files of libyaml, and thus
+are still covered by their original MIT license, with the additional
+copyright staring in 2011 when the project was ported over:
+
+ apic.go emitterc.go parserc.go readerc.go scannerc.go
+ writerc.go yamlh.go yamlprivateh.go
+
+Copyright (c) 2006-2010 Kirill Simonov
+Copyright (c) 2006-2011 Kirill Simonov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+### Apache License ###
+
+All the remaining project files are covered by the Apache license:
+
+Copyright (c) 2011-2019 Canonical Ltd
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/operator/.dockerignore b/operator/.dockerignore
new file mode 100644
index 000000000..0f046820f
--- /dev/null
+++ b/operator/.dockerignore
@@ -0,0 +1,4 @@
+# More info: https://docs.docker.com/engine/reference/builder/#dockerignore-file
+# Ignore build and test binaries.
+bin/
+testbin/
diff --git a/operator/.gitignore b/operator/.gitignore
new file mode 100644
index 000000000..4db1fbd44
--- /dev/null
+++ b/operator/.gitignore
@@ -0,0 +1,25 @@
+
+# Binaries for programs and plugins
+*.exe
+*.exe~
+*.dll
+*.so
+*.dylib
+bin
+testbin/*
+
+# Test binary, build with `go test -c`
+*.test
+
+# Output of the go coverage tool, specifically when used with LiteIDE
+*.out
+
+# Kubernetes Generated files - skip generated files, except for vendored files
+!vendor/**/zz_generated.*
+
+# editor and IDE paraphernalia
+.idea
+*.swp
+*.swo
+*~
+/config/
diff --git a/operator/BUILD.bazel b/operator/BUILD.bazel
new file mode 100755
index 000000000..70ff958ed
--- /dev/null
+++ b/operator/BUILD.bazel
@@ -0,0 +1,5 @@
+filegroup(
+ name = "srcs",
+ srcs = glob(["**"]),
+ visibility = ["//visibility:public"],
+)
diff --git a/operator/Dockerfile b/operator/Dockerfile
new file mode 100644
index 000000000..4152680b7
--- /dev/null
+++ b/operator/Dockerfile
@@ -0,0 +1,27 @@
+# Build the manager binary
+FROM golang:1.16 as builder
+
+WORKDIR /workspace
+# Copy the Go Modules manifests
+COPY go.mod go.mod
+COPY go.sum go.sum
+# cache deps before building and copying source so that we don't need to re-download as much
+# and so that source changes don't invalidate our downloaded layer
+RUN go mod download
+
+# Copy the go source
+COPY main.go main.go
+COPY api/ api/
+COPY controllers/ controllers/
+
+# Build
+RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o manager main.go
+
+# Use distroless as minimal base image to package the manager binary
+# Refer to https://github.com/GoogleContainerTools/distroless for more details
+FROM gcr.io/distroless/static:nonroot
+WORKDIR /
+COPY --from=builder /workspace/manager .
+USER 65532:65532
+
+ENTRYPOINT ["/manager"]
diff --git a/operator/Makefile b/operator/Makefile
new file mode 100644
index 000000000..fddbad234
--- /dev/null
+++ b/operator/Makefile
@@ -0,0 +1,130 @@
+
+# Image URL to use all building/pushing image targets
+IMG ?= controller:latest
+# ENVTEST_K8S_VERSION refers to the version of kubebuilder assets to be downloaded by envtest binary.
+ENVTEST_K8S_VERSION = 1.19.2
+
+# Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set)
+ifeq (,$(shell go env GOBIN))
+GOBIN=$(shell go env GOPATH)/bin
+else
+GOBIN=$(shell go env GOBIN)
+endif
+
+# Setting SHELL to bash allows bash commands to be executed by recipes.
+# This is a requirement for 'setup-envtest.sh' in the test target.
+# Options are set to exit when a recipe line exits non-zero or a piped command fails.
+SHELL = /usr/bin/env bash -o pipefail
+.SHELLFLAGS = -ec
+
+.PHONY: all
+all: build
+
+##@ General
+
+# The help target prints out all targets with their descriptions organized
+# beneath their categories. The categories are represented by '##@' and the
+# target descriptions by '##'. The awk command is responsible for reading the
+# entire set of makefiles included in this invocation, looking for lines of the
+# file as xyz: ## something, and then pretty-format the target and help. Then,
+# if there's a line with ##@ something, that gets pretty-printed as a category.
+# More info on the usage of ANSI control characters for terminal formatting:
+# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters
+# More info on the awk command:
+# http://linuxcommand.org/lc3_adv_awk.php
+
+.PHONY: help
+help: ## Display this help.
+ @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
+
+##@ Development
+
+.PHONY: manifests
+manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects.
+ $(CONTROLLER_GEN) rbac:roleName=manager-role crd webhook paths="./..." output:crd:artifacts:config=config/crd/bases
+
+.PHONY: generate
+generate: controller-gen ## Generate code containing DeepCopy, DeepCopyInto, and DeepCopyObject method implementations.
+ $(CONTROLLER_GEN) object:headerFile="hack/boilerplate.go.txt" paths="./..."
+
+.PHONY: fmt
+fmt: ## Run go fmt against code.
+ go fmt ./...
+
+.PHONY: vet
+vet: ## Run go vet against code.
+ go vet ./...
+
+.PHONY: test
+test: manifests generate fmt vet envtest ## Run tests.
+ KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) -p path)" go test ./... -coverprofile cover.out
+
+##@ Build
+
+.PHONY: build
+build: generate fmt vet ## Build manager binary.
+ go build -o bin/manager main.go
+
+.PHONY: run
+run: manifests generate fmt vet ## Run a controller from your host.
+ go run ./main.go
+
+.PHONY: docker-build
+docker-build: test ## Build docker image with the manager.
+ docker build -t ${IMG} .
+
+.PHONY: docker-push
+docker-push: ## Push docker image with the manager.
+ docker push ${IMG}
+
+##@ Deployment
+
+ifndef ignore-not-found
+ ignore-not-found = false
+endif
+
+.PHONY: install
+install: manifests kustomize ## Install CRDs into the K8s cluster specified in ~/.kube/config.
+ $(KUSTOMIZE) build config/crd | kubectl apply -f -
+
+.PHONY: uninstall
+uninstall: manifests kustomize ## Uninstall CRDs from the K8s cluster specified in ~/.kube/config. Call with ignore-not-found=true to ignore resource not found errors during deletion.
+ $(KUSTOMIZE) build config/crd | kubectl delete --ignore-not-found=$(ignore-not-found) -f -
+
+.PHONY: deploy
+deploy: manifests kustomize ## Deploy controller to the K8s cluster specified in ~/.kube/config.
+ cd config/manager && $(KUSTOMIZE) edit set image controller=${IMG}
+ $(KUSTOMIZE) build config/default | kubectl apply -f -
+
+.PHONY: undeploy
+undeploy: ## Undeploy controller from the K8s cluster specified in ~/.kube/config. Call with ignore-not-found=true to ignore resource not found errors during deletion.
+ $(KUSTOMIZE) build config/default | kubectl delete --ignore-not-found=$(ignore-not-found) -f -
+
+CONTROLLER_GEN = $(shell pwd)/bin/controller-gen
+.PHONY: controller-gen
+controller-gen: ## Download controller-gen locally if necessary.
+ $(call go-get-tool,$(CONTROLLER_GEN),sigs.k8s.io/controller-tools/cmd/controller-gen@v0.7.0)
+
+KUSTOMIZE = $(shell pwd)/bin/kustomize
+.PHONY: kustomize
+kustomize: ## Download kustomize locally if necessary.
+ $(call go-get-tool,$(KUSTOMIZE),sigs.k8s.io/kustomize/kustomize/v3@v3.8.7)
+
+ENVTEST = $(shell pwd)/bin/setup-envtest
+.PHONY: envtest
+envtest: ## Download envtest-setup locally if necessary.
+ $(call go-get-tool,$(ENVTEST),sigs.k8s.io/controller-runtime/tools/setup-envtest@latest)
+
+# go-get-tool will 'go get' any package $2 and install it to $1.
+PROJECT_DIR := $(shell dirname $(abspath $(lastword $(MAKEFILE_LIST))))
+define go-get-tool
+@[ -f $(1) ] || { \
+set -e ;\
+TMP_DIR=$$(mktemp -d) ;\
+cd $$TMP_DIR ;\
+go mod init tmp ;\
+echo "Downloading $(2)" ;\
+GOBIN=$(PROJECT_DIR)/bin go get $(2) ;\
+rm -rf $$TMP_DIR ;\
+}
+endef
diff --git a/operator/PROJECT b/operator/PROJECT
new file mode 100644
index 000000000..067ff0f40
--- /dev/null
+++ b/operator/PROJECT
@@ -0,0 +1,16 @@
+domain: k8s.io
+layout:
+- go.kubebuilder.io/v3
+projectName: flapp-operator
+repo: fedlearner.net
+resources:
+- api:
+ crdVersion: v1
+ namespaced: true
+ controller: true
+ domain: k8s.io
+ group: fedlearner
+ kind: FedApp
+ path: fedlearner.net/operator/api/v1alpha1
+ version: v1alpha1
+version: "3"
diff --git a/operator/README.md b/operator/README.md
new file mode 100644
index 000000000..16ba93d46
--- /dev/null
+++ b/operator/README.md
@@ -0,0 +1,32 @@
+# Generate yamls
+`make manifests`
+
+To generate yamls in ./config, such as ./config/crd/bases and ./config/rbac
+
+After generation, add `api-approved.kubernetes.io: https://github.com/kubernetes/kubernetes/pull/78458` to the CRD annotations to avoid Kubernetes warnings.
+# Test Controller locally
+`make install `
+
+To install the CRDs and RBAC into the cluster specified in your .kube config.
+
+
+`make run`
+
+To run a controller locally in your terminal that watches and updates resources in the cluster from your .kube config.
+
+# Run in cluster
+`make docker-build docker-push IMG=<registry>/<image>:tag`
+`make deploy IMG=<registry>/<image>:tag`
+
+# Integration test
+`make test`
+
+# Further development
+Framework documentation: https://book.kubebuilder.io/
+
+You only need to read and modify /api/v1alpha1/fedapp_types.go (definitions) and /controllers/fedapp_controller.go (control logic).
+
+Future CRDs and their corresponding controllers should all live under this directory (Project); follow the instructions in the documentation to scaffold a new CRD.
+
+# Cluster dependencies
+- Versions 0.1.2 and above use a non-headless service, so running TensorFlow requires hairpin mode to be enabled in the cluster.
\ No newline at end of file
diff --git a/operator/api/v1alpha1/fedapp_types.go b/operator/api/v1alpha1/fedapp_types.go
new file mode 100644
index 000000000..fbfc74c03
--- /dev/null
+++ b/operator/api/v1alpha1/fedapp_types.go
@@ -0,0 +1,185 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package v1alpha1
+
+import (
+ corev1 "k8s.io/api/core/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
+// NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized.
+
+// FedReplicaType can be any string.
+type FedReplicaType string
+
+// RestartPolicy describes how the replicas should be restarted.
+// Only one of the following restart policies may be specified.
+// If none of the following policies is specified, the default one
+// is RestartPolicyOnFailure.
+type RestartPolicy string
+
+const (
+ RestartPolicyAlways RestartPolicy = "Always"
+ RestartPolicyOnFailure RestartPolicy = "OnFailure"
+ RestartPolicyNever RestartPolicy = "Never"
+ // RestartPolicyExitCode policy means that user should add exit code by themselves,
+ // The controller will check these exit codes to
+ // determine the behavior when an error occurs:
+ // - 1-127: permanent error, do not restart.
+ // - 128-255: retryable error, will restart the pod.
+ RestartPolicyExitCode RestartPolicy = "ExitCode"
+)
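+
+// Illustrative sketch (an assumption, not part of the API types or the
+// controller's actual implementation): under the ExitCode policy a controller
+// could classify a container's exit code roughly like this:
+//
+//	func isRetryable(exitCode int32) bool {
+//		// 1-127: permanent error, do not restart; 128-255: retryable.
+//		return exitCode >= 128
+//	}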
+
+// ReplicaSpec is a description of the replica
+type ReplicaSpec struct {
+ // Replicas is the desired number of replicas of the given template.
+ // +kubebuilder:default=0
+ // +kubebuilder:validation:Maximum=200
+ // +kubebuilder:validation:Minimum=0
+ Replicas *int64 `json:"replicas,omitempty"`
+
+ // Template is the object that describes the pod that
+ // will be created for this replica.
+ Template corev1.PodTemplateSpec `json:"template,omitempty"`
+
+ // Restart policy for all replicas within the app.
+ // One of Always, OnFailure, Never and ExitCode.
+ // +kubebuilder:default=OnFailure
+ RestartPolicy RestartPolicy `json:"restartPolicy,omitempty"`
+
+ // Optional number of retries before marking this job failed.
+ // +kubebuilder:default=1
+ // +kubebuilder:validation:Maximum=100
+ // +kubebuilder:validation:Minimum=1
+ BackoffLimit *int64 `json:"backoffLimit,omitempty"`
+
+ // Whether all pods of this replica must succeed in order to mark the FedApp as complete.
+ // +kubebuilder:default=true
+ MustSuccess *bool `json:"mustSuccess,omitempty"`
+
+ // +kubebuilder:default:={"containerPort": 50051, "name": "flapp-port", "protocol": "TCP"}
+ Port *corev1.ContainerPort `json:"port,omitempty"`
+}
+
+// FedReplicaSpecs is the mapping from FedReplicaType to ReplicaSpec
+type FedReplicaSpecs map[FedReplicaType]ReplicaSpec
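+
+// Illustrative sketch (an assumption, not generated code): a minimal entry
+// that relies on the kubebuilder defaults declared above could look like
+//
+//	specs := FedReplicaSpecs{
+//		"Worker": ReplicaSpec{
+//			Replicas: pointer.Int64(2),
+//			Template: corev1.PodTemplateSpec{ /* pod spec */ },
+//		},
+//	}
+//
+// leaving RestartPolicy, BackoffLimit, MustSuccess and Port at their defaults
+// (pointer here refers to k8s.io/utils/pointer).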
+
+// FedAppSpec defines the desired state of FedApp
+type FedAppSpec struct {
+ // INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
+ // Important: Run "make" to regenerate code after modifying this file
+ // Defines replica spec for replica type
+ FedReplicaSpecs FedReplicaSpecs `json:"fedReplicaSpecs"`
+
+ // TTLSecondsAfterFinished is the TTL to clean up jobs.
+ // It may take extra ReconcilePeriod seconds for the cleanup, since
+ // reconcile gets called periodically.
+ // Defaults to 86400 (one day).
+ // +kubebuilder:default=86400
+ // +optional
+ TTLSecondsAfterFinished *int64 `json:"ttlSecondsAfterFinished,omitempty"`
+
+ // Specifies the duration in seconds relative to the startTime that the job may be active
+ // before the system tries to terminate it; value must be a positive integer.
+ // +optional
+ ActiveDeadlineSeconds *int64 `json:"activeDeadlineSeconds,omitempty"`
+}
+
+// FedAppStatus defines the observed state of FedApp
+type FedAppStatus struct {
+ // INSERT ADDITIONAL STATUS FIELD - define observed state of cluster
+ // Important: Run "make" to regenerate code after modifying this file
+
+ // +optional
+ StartTime *metav1.Time `json:"startTime"`
+
+ Conditions []FedAppCondition `json:"conditions,omitempty"`
+
+ // Record the names of pods that have terminated; a workaround for pod GC being too fast.
+ // TODO: when pod GC is so fast that the fedapp controller
+ // doesn't have enough time to record pods in the TerminatedPodsMap field,
+ // use a finalizer to avoid it.
+ // +optional
+ TerminatedPodsMap TerminatedPodsMap `json:"terminatedPodsMap,omitempty"`
+}
+type empty struct{}
+type PodSet map[string]empty
+type TerminatedPodsMap map[FedReplicaType]*TerminatedPods
+
+// TerminatedPods holds name of Pods that have terminated.
+type TerminatedPods struct {
+ // Succeeded holds name of succeeded Pods.
+ // +optional
+ Succeeded []PodSet `json:"succeeded,omitempty"`
+
+ // Failed holds name of failed Pods.
+ // +optional
+ Failed []PodSet `json:"failed,omitempty"`
+}
+
+// FedAppConditionType is a valid value for FedAppCondition.Type
+type FedAppConditionType string
+
+// These are valid conditions of a job.
+const (
+ // Succeeded means the job has completed its execution.
+ // true: completed, false: failed, unknown: running
+ Succeeded FedAppConditionType = "succeeded"
+)
+
+// FedAppCondition describes current state of a job.
+type FedAppCondition struct {
+ // Type of job condition.
+ Type FedAppConditionType `json:"type"`
+ // Status of the condition, one of True, False, Unknown.
+ Status corev1.ConditionStatus `json:"status"`
+
+ // Last time the condition transitioned from one status to another.
+ // +optional
+ LastTransitionTime metav1.Time `json:"lastTransitionTime"`
+ // (brief) reason for the condition's last transition.
+ // +optional
+ Reason string `json:"reason"`
+ // Human readable message indicating details about last transition.
+ // +optional
+ Message string `json:"message"`
+}
+
+//+kubebuilder:object:root=true
+//+kubebuilder:subresource:status
+
+// FedApp is the Schema for the fedapps API
+type FedApp struct {
+ metav1.TypeMeta `json:",inline"`
+ metav1.ObjectMeta `json:"metadata,omitempty"`
+
+ Spec FedAppSpec `json:"spec,omitempty"`
+ Status FedAppStatus `json:"status,omitempty"`
+}
+
+//+kubebuilder:object:root=true
+
+// FedAppList contains a list of FedApp
+type FedAppList struct {
+ metav1.TypeMeta `json:",inline"`
+ metav1.ListMeta `json:"metadata,omitempty"`
+ Items []FedApp `json:"items"`
+}
+
+func init() {
+ SchemeBuilder.Register(&FedApp{}, &FedAppList{})
+}
diff --git a/operator/api/v1alpha1/groupversion_info.go b/operator/api/v1alpha1/groupversion_info.go
new file mode 100644
index 000000000..d39699efd
--- /dev/null
+++ b/operator/api/v1alpha1/groupversion_info.go
@@ -0,0 +1,35 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Package v1alpha1 contains API Schema definitions for the fedlearner v1alpha1 API group
+//+kubebuilder:object:generate=true
+//+groupName=fedlearner.k8s.io
+package v1alpha1
+
+import (
+ "k8s.io/apimachinery/pkg/runtime/schema"
+ "sigs.k8s.io/controller-runtime/pkg/scheme"
+)
+
+var (
+ // GroupVersion is group version used to register these objects
+ GroupVersion = schema.GroupVersion{Group: "fedlearner.k8s.io", Version: "v1alpha1"}
+
+ // SchemeBuilder is used to add go types to the GroupVersionKind scheme
+ SchemeBuilder = &scheme.Builder{GroupVersion: GroupVersion}
+
+ // AddToScheme adds the types in this group-version to the given scheme.
+ AddToScheme = SchemeBuilder.AddToScheme
+)
diff --git a/operator/api/v1alpha1/zz_generated.deepcopy.go b/operator/api/v1alpha1/zz_generated.deepcopy.go
new file mode 100644
index 000000000..9330b4627
--- /dev/null
+++ b/operator/api/v1alpha1/zz_generated.deepcopy.go
@@ -0,0 +1,322 @@
+//go:build !ignore_autogenerated
+// +build !ignore_autogenerated
+
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Code generated by controller-gen. DO NOT EDIT.
+
+package v1alpha1
+
+import (
+ "k8s.io/api/core/v1"
+ runtime "k8s.io/apimachinery/pkg/runtime"
+)
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *FedApp) DeepCopyInto(out *FedApp) {
+ *out = *in
+ out.TypeMeta = in.TypeMeta
+ in.ObjectMeta.DeepCopyInto(&out.ObjectMeta)
+ in.Spec.DeepCopyInto(&out.Spec)
+ in.Status.DeepCopyInto(&out.Status)
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FedApp.
+func (in *FedApp) DeepCopy() *FedApp {
+ if in == nil {
+ return nil
+ }
+ out := new(FedApp)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
+func (in *FedApp) DeepCopyObject() runtime.Object {
+ if c := in.DeepCopy(); c != nil {
+ return c
+ }
+ return nil
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *FedAppCondition) DeepCopyInto(out *FedAppCondition) {
+ *out = *in
+ in.LastTransitionTime.DeepCopyInto(&out.LastTransitionTime)
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FedAppCondition.
+func (in *FedAppCondition) DeepCopy() *FedAppCondition {
+ if in == nil {
+ return nil
+ }
+ out := new(FedAppCondition)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *FedAppList) DeepCopyInto(out *FedAppList) {
+ *out = *in
+ out.TypeMeta = in.TypeMeta
+ in.ListMeta.DeepCopyInto(&out.ListMeta)
+ if in.Items != nil {
+ in, out := &in.Items, &out.Items
+ *out = make([]FedApp, len(*in))
+ for i := range *in {
+ (*in)[i].DeepCopyInto(&(*out)[i])
+ }
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FedAppList.
+func (in *FedAppList) DeepCopy() *FedAppList {
+ if in == nil {
+ return nil
+ }
+ out := new(FedAppList)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
+func (in *FedAppList) DeepCopyObject() runtime.Object {
+ if c := in.DeepCopy(); c != nil {
+ return c
+ }
+ return nil
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *FedAppSpec) DeepCopyInto(out *FedAppSpec) {
+ *out = *in
+ if in.FedReplicaSpecs != nil {
+ in, out := &in.FedReplicaSpecs, &out.FedReplicaSpecs
+ *out = make(FedReplicaSpecs, len(*in))
+ for key, val := range *in {
+ (*out)[key] = *val.DeepCopy()
+ }
+ }
+ if in.TTLSecondsAfterFinished != nil {
+ in, out := &in.TTLSecondsAfterFinished, &out.TTLSecondsAfterFinished
+ *out = new(int64)
+ **out = **in
+ }
+ if in.ActiveDeadlineSeconds != nil {
+ in, out := &in.ActiveDeadlineSeconds, &out.ActiveDeadlineSeconds
+ *out = new(int64)
+ **out = **in
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FedAppSpec.
+func (in *FedAppSpec) DeepCopy() *FedAppSpec {
+ if in == nil {
+ return nil
+ }
+ out := new(FedAppSpec)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *FedAppStatus) DeepCopyInto(out *FedAppStatus) {
+ *out = *in
+ if in.StartTime != nil {
+ in, out := &in.StartTime, &out.StartTime
+ *out = (*in).DeepCopy()
+ }
+ if in.Conditions != nil {
+ in, out := &in.Conditions, &out.Conditions
+ *out = make([]FedAppCondition, len(*in))
+ for i := range *in {
+ (*in)[i].DeepCopyInto(&(*out)[i])
+ }
+ }
+ if in.TerminatedPodsMap != nil {
+ in, out := &in.TerminatedPodsMap, &out.TerminatedPodsMap
+ *out = make(TerminatedPodsMap, len(*in))
+ for key, val := range *in {
+ var outVal *TerminatedPods
+ if val == nil {
+ (*out)[key] = nil
+ } else {
+ in, out := &val, &outVal
+ *out = new(TerminatedPods)
+ (*in).DeepCopyInto(*out)
+ }
+ (*out)[key] = outVal
+ }
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FedAppStatus.
+func (in *FedAppStatus) DeepCopy() *FedAppStatus {
+ if in == nil {
+ return nil
+ }
+ out := new(FedAppStatus)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in FedReplicaSpecs) DeepCopyInto(out *FedReplicaSpecs) {
+ {
+ in := &in
+ *out = make(FedReplicaSpecs, len(*in))
+ for key, val := range *in {
+ (*out)[key] = *val.DeepCopy()
+ }
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new FedReplicaSpecs.
+func (in FedReplicaSpecs) DeepCopy() FedReplicaSpecs {
+ if in == nil {
+ return nil
+ }
+ out := new(FedReplicaSpecs)
+ in.DeepCopyInto(out)
+ return *out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in PodSet) DeepCopyInto(out *PodSet) {
+ {
+ in := &in
+ *out = make(PodSet, len(*in))
+ for key, val := range *in {
+ (*out)[key] = val
+ }
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PodSet.
+func (in PodSet) DeepCopy() PodSet {
+ if in == nil {
+ return nil
+ }
+ out := new(PodSet)
+ in.DeepCopyInto(out)
+ return *out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *ReplicaSpec) DeepCopyInto(out *ReplicaSpec) {
+ *out = *in
+ if in.Replicas != nil {
+ in, out := &in.Replicas, &out.Replicas
+ *out = new(int64)
+ **out = **in
+ }
+ in.Template.DeepCopyInto(&out.Template)
+ if in.BackoffLimit != nil {
+ in, out := &in.BackoffLimit, &out.BackoffLimit
+ *out = new(int64)
+ **out = **in
+ }
+ if in.MustSuccess != nil {
+ in, out := &in.MustSuccess, &out.MustSuccess
+ *out = new(bool)
+ **out = **in
+ }
+ if in.Port != nil {
+ in, out := &in.Port, &out.Port
+ *out = new(v1.ContainerPort)
+ **out = **in
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ReplicaSpec.
+func (in *ReplicaSpec) DeepCopy() *ReplicaSpec {
+ if in == nil {
+ return nil
+ }
+ out := new(ReplicaSpec)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *TerminatedPods) DeepCopyInto(out *TerminatedPods) {
+ *out = *in
+ if in.Succeeded != nil {
+ in, out := &in.Succeeded, &out.Succeeded
+ *out = make([]PodSet, len(*in))
+ for i := range *in {
+ if (*in)[i] != nil {
+ in, out := &(*in)[i], &(*out)[i]
+ *out = make(PodSet, len(*in))
+ for key, val := range *in {
+ (*out)[key] = val
+ }
+ }
+ }
+ }
+ if in.Failed != nil {
+ in, out := &in.Failed, &out.Failed
+ *out = make([]PodSet, len(*in))
+ for i := range *in {
+ if (*in)[i] != nil {
+ in, out := &(*in)[i], &(*out)[i]
+ *out = make(PodSet, len(*in))
+ for key, val := range *in {
+ (*out)[key] = val
+ }
+ }
+ }
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TerminatedPods.
+func (in *TerminatedPods) DeepCopy() *TerminatedPods {
+ if in == nil {
+ return nil
+ }
+ out := new(TerminatedPods)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in TerminatedPodsMap) DeepCopyInto(out *TerminatedPodsMap) {
+ {
+ in := &in
+ *out = make(TerminatedPodsMap, len(*in))
+ for key, val := range *in {
+ var outVal *TerminatedPods
+ if val == nil {
+ (*out)[key] = nil
+ } else {
+ in, out := &val, &outVal
+ *out = new(TerminatedPods)
+ (*in).DeepCopyInto(*out)
+ }
+ (*out)[key] = outVal
+ }
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TerminatedPodsMap.
+func (in TerminatedPodsMap) DeepCopy() TerminatedPodsMap {
+ if in == nil {
+ return nil
+ }
+ out := new(TerminatedPodsMap)
+ in.DeepCopyInto(out)
+ return *out
+}
diff --git a/operator/controllers/cluster_spec.go b/operator/controllers/cluster_spec.go
new file mode 100644
index 000000000..a30dc07aa
--- /dev/null
+++ b/operator/controllers/cluster_spec.go
@@ -0,0 +1,59 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "strings"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+)
+
+const (
+ ServiceFormat = "%s.%s.svc"
+)
+
+type ClusterSpec struct {
+ Services map[fedlearnerv2.FedReplicaType][]string `json:"clusterSpec"`
+}
+
+func GenIndexName(appName string, rt string, index int) string {
+ n := appName + "-" + rt + "-" + strconv.Itoa(index)
+ return strings.Replace(n, "/", "-", -1)
+}
+
+func NewClusterSpec(namespace string, app *fedlearnerv2.FedApp) ClusterSpec {
+ clusterSpec := ClusterSpec{
+ Services: make(map[fedlearnerv2.FedReplicaType][]string),
+ }
+ for rtype, spec := range app.Spec.FedReplicaSpecs {
+ rt := strings.ToLower(string(rtype))
+ replicas := int(*spec.Replicas)
+ port := spec.Port.ContainerPort
+
+ for index := 0; index < replicas; index++ {
+ serviceName := fmt.Sprintf(ServiceFormat, GenIndexName(app.Name, rt, index), namespace)
+ clusterSpec.Services[rtype] = append(clusterSpec.Services[rtype], fmt.Sprintf("%s:%d", serviceName, port))
+ }
+ }
+ return clusterSpec
+}
+
+func (cs ClusterSpec) Marshal() ([]byte, error) {
+ return json.Marshal(cs)
+}
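+
+// Illustrative example (hypothetical names, not part of the original change):
+// for a FedApp named "demo" in namespace "ns" whose FedReplicaSpecs declare a
+// "Worker" replica type with 2 replicas on the default port 50051,
+// NewClusterSpec yields addresses of the form
+//
+//	demo-worker-0.ns.svc:50051
+//	demo-worker-1.ns.svc:50051
+//
+// and Marshal serializes them under the "clusterSpec" JSON key:
+//
+//	{"clusterSpec":{"Worker":["demo-worker-0.ns.svc:50051","demo-worker-1.ns.svc:50051"]}}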
diff --git a/operator/controllers/fedapp_controller.go b/operator/controllers/fedapp_controller.go
new file mode 100644
index 000000000..d83033166
--- /dev/null
+++ b/operator/controllers/fedapp_controller.go
@@ -0,0 +1,346 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "context"
+ "time"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ v1 "k8s.io/api/core/v1"
+ networking "k8s.io/api/networking/v1beta1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/runtime"
+ ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+ "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+// FedAppReconciler reconciles a FedApp object
+type FedAppReconciler struct {
+ client.Client
+ Scheme *runtime.Scheme
+}
+
+const (
+ flReplicaTypeLabel = "fl-replica-type"
+ flReplicaIndexLabel = "fl-replica-index"
+ AppNameLabel = "app-name"
+
+ // Env key in pod
+ serviceID = "SERVICE_ID"
+ clusterSpec = "CLUSTER_SPEC"
+ replicaIndex = "INDEX"
+)
+
+//+kubebuilder:rbac:groups=fedlearner.k8s.io,resources=fedapps,verbs=get;list;watch;create;update;patch;delete
+//+kubebuilder:rbac:groups=fedlearner.k8s.io,resources=fedapps/status,verbs=get;update;patch
+//+kubebuilder:rbac:groups=fedlearner.k8s.io,resources=fedapps/finalizers,verbs=update
+//+kubebuilder:rbac:groups=core,resources=pods,verbs=get;list;watch;create;delete
+//+kubebuilder:rbac:groups=core,resources=services,verbs=get;list;create;delete
+//+kubebuilder:rbac:groups=networking,resources=ingress,verbs=get;create;delete
+
+// Reconcile is part of the main kubernetes reconciliation loop which aims to
+// move the current state of the cluster closer to the desired state. It
+// compares the state specified by the FedApp object against the actual
+// cluster state, and then performs operations to make the cluster state
+// reflect the state specified by the user.
+//
+// For more details, check Reconcile and its Result here:
+// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.10.0/pkg/reconcile
+func (r *FedAppReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
+ log := log.FromContext(ctx)
+
+ startTime := time.Now()
+ defer func() {
+		log.Info("Finished syncing FedApp", "FedApp", req.NamespacedName.Name, "duration", time.Since(startTime).String())
+ }()
+ var app fedlearnerv2.FedApp
+ if err := r.Get(ctx, req.NamespacedName, &app); err != nil {
+ log.Info("unable to fetch FedApp")
+ // we'll ignore not-found errors, since they can't be fixed by an immediate
+ // requeue (we'll need to wait for a new notification), and we can get them
+ // on deleted requests.
+ return ctrl.Result{}, client.IgnoreNotFound(err)
+ }
+
+ // if job was finished previously, we don't want to redo the termination
+ if isAppFinished(&app) {
+ // release all resource
+ // r.releaseAppResource(ctx, req)
+ // Complete ttl check
+ if app.Spec.TTLSecondsAfterFinished != nil {
+ now := metav1.Now()
+ var finished time.Time
+ for _, c := range app.Status.Conditions {
+ if c.Type == fedlearnerv2.Succeeded {
+ finished = c.LastTransitionTime.Time
+ break
+ }
+ }
+ duration := now.Time.Sub(finished)
+ allowedDuration := time.Duration(*app.Spec.TTLSecondsAfterFinished) * time.Second
+ if duration >= allowedDuration {
+ log.Info("FedApp TTLSecondsAfterFinished terminating")
+ if err := r.Delete(ctx, &app); client.IgnoreNotFound(err) != nil {
+ return ctrl.Result{}, err
+ }
+ return ctrl.Result{}, nil
+ }
+ return ctrl.Result{RequeueAfter: time.Duration(*app.Spec.TTLSecondsAfterFinished) * time.Second}, nil
+
+ }
+ return ctrl.Result{}, nil
+ }
+
+ if app.Status.TerminatedPodsMap == nil {
+ app.Status.TerminatedPodsMap = InitTerminatedPodsMap(app)
+ }
+
+ if app.Status.StartTime == nil {
+ // Check if pods of last execution have all been deleted.
+ var childPods v1.PodList
+ if err := r.List(ctx, &childPods, client.InNamespace(req.Namespace), client.MatchingFields{ownerKey: req.Name}); err != nil {
+ log.Error(err, "unable to list child Pods")
+ return ctrl.Result{}, err
+ }
+ if len(childPods.Items) > 0 {
+ log.Info("Delete all pods for last Execution.")
+ for _, pod := range childPods.Items {
+ log.Info("Delete Pod", "Pod", pod.Name)
+ if err := r.Delete(ctx, &pod, client.PropagationPolicy(metav1.DeletePropagationBackground)); client.IgnoreNotFound(err) != nil {
+ log.Error(err, "Failed to delete pods")
+ return ctrl.Result{}, err
+ }
+ }
+
+ return ctrl.Result{}, nil
+ }
+ now := metav1.Now()
+ app.Status.StartTime = &now
+ if err := r.Status().Update(ctx, &app); err != nil {
+ log.Error(err, "unable to update FedApp status StartTime")
+ return ctrl.Result{}, err
+ }
+
+ // enqueue a sync to check if job past ActiveDeadlineSeconds
+ if app.Spec.ActiveDeadlineSeconds != nil {
+ log.Info("FedApp has ActiveDeadlineSeconds will sync after", "ActiveDeadlineSeconds", *app.Spec.ActiveDeadlineSeconds)
+ return ctrl.Result{RequeueAfter: time.Duration(*app.Spec.ActiveDeadlineSeconds) * time.Second}, nil
+ }
+ return ctrl.Result{}, nil
+ }
+ if app.Spec.ActiveDeadlineSeconds != nil && app.Status.StartTime != nil {
+ now := metav1.Now()
+ start := app.Status.StartTime.Time
+ duration := now.Time.Sub(start)
+ allowedDuration := time.Duration(*app.Spec.ActiveDeadlineSeconds) * time.Second
+ if duration >= allowedDuration {
+ log.Info("FedApp has running exceeding activeDeadlineSeconds")
+ app.Status.Conditions, _ = ensureConditionStatus(app.Status.Conditions, fedlearnerv2.Succeeded, v1.ConditionFalse, "DeadlineExceeded", "FedApp was active longer than specified deadline")
+ if err := r.Status().Update(ctx, &app); err != nil {
+ log.Error(err, "unable to update FedApp status DeadlineExceeded")
+ return ctrl.Result{}, err
+ }
+			// Release the resources on the next reconcile.
+			// We can not release resources here synchronously, because the status update request's response can't ensure the status has been persisted in etcd.
+			// And if we deleted the pods synchronously, k8s can't guarantee the delete request is processed after the status update request, so the next reconcile might create new pods.
+ return ctrl.Result{}, nil
+ }
+
+ }
+
+ // sync service
+ if err := r.syncServices(ctx, &app); err != nil {
+ log.Error(err, "unable to sync service")
+ return ctrl.Result{}, err
+ }
+ // sync ingress
+ if err := r.syncIngress(ctx, &app); err != nil {
+ log.Error(err, "unable to sync ingress")
+ return ctrl.Result{}, err
+ }
+ // sync pod
+ completed := true
+ var childPods v1.PodList
+ if err := r.List(ctx, &childPods, client.InNamespace(req.Namespace), client.MatchingFields{ownerKey: req.Name}); err != nil {
+ log.Error(err, "unable to list child Pods")
+ return ctrl.Result{}, err
+ }
+ for rtype, spec := range app.Spec.FedReplicaSpecs {
+ replicaResult, err := r.SyncReplicas(ctx, &app, rtype, &childPods, &spec)
+ if replicaResult.isFailed {
+ log.Info("FedApp failed")
+			// Don't clean up resources synchronously, because we must wait for the update request to finish in order to keep reconcile idempotent.
+ return ctrl.Result{}, nil
+ }
+ if err != nil {
+ return ctrl.Result{}, err
+ }
+
+ completed = completed && replicaResult.isCompleted
+
+ }
+ if completed {
+ app.Status.Conditions, _ = ensureConditionStatus(app.Status.Conditions, fedlearnerv2.Succeeded, v1.ConditionTrue, "Completed", "")
+ if err := r.Status().Update(ctx, &app); err != nil {
+ log.Error(err, "unable to update FedApp status Completed")
+ return ctrl.Result{}, err
+ }
+ }
+ if err := r.Status().Update(ctx, &app); err != nil {
+ log.Error(err, "unable to update FedApp status when reconcile finished")
+ return ctrl.Result{}, err
+ }
+ return ctrl.Result{}, nil
+}
+
+type ReplicaResult struct {
+ isFailed bool
+ isCompleted bool
+}
+
+// ensureConditionStatus appends or updates an existing condition of the
+// given type with the given status value. The function returns a bool to let the
+// caller know if the list was changed (either appended or updated).
+func ensureConditionStatus(list []fedlearnerv2.FedAppCondition, cType fedlearnerv2.FedAppConditionType, status v1.ConditionStatus, reason, message string) ([]fedlearnerv2.FedAppCondition, bool) {
+ for i := range list {
+ if list[i].Type == cType {
+ if list[i].Status != status || list[i].Reason != reason || list[i].Message != message {
+ list[i].Status = status
+ list[i].LastTransitionTime = metav1.Now()
+ list[i].Reason = reason
+ list[i].Message = message
+ return list, true
+ }
+ return list, false
+ }
+ }
+
+ return append(list, *newCondition(cType, status, reason, message)), true
+}
+
+func newCondition(conditionType fedlearnerv2.FedAppConditionType, status v1.ConditionStatus, reason, message string) *fedlearnerv2.FedAppCondition {
+ return &fedlearnerv2.FedAppCondition{
+ Type: conditionType,
+ Status: status,
+ LastTransitionTime: metav1.Now(),
+ Reason: reason,
+ Message: message,
+ }
+}
+
+// isAppFinished checks whether the given FedApp has finished execution.
+// It does not discriminate between successful and failed terminations.
+func isAppFinished(app *fedlearnerv2.FedApp) bool {
+ for _, c := range app.Status.Conditions {
+ if c.Type == fedlearnerv2.Succeeded && (c.Status == v1.ConditionTrue || c.Status == v1.ConditionFalse) {
+ return true
+ }
+ }
+ return false
+}
+
+func (r *FedAppReconciler) releaseAppResource(ctx context.Context, req ctrl.Request) error {
+ log := log.FromContext(ctx)
+ var ingress networking.Ingress
+ err := r.Get(ctx, req.NamespacedName, &ingress)
+ if client.IgnoreNotFound(err) != nil {
+ log.Error(err, "Get Ingress failed")
+ return err
+ }
+ if err == nil {
+ log.Info("Delete Ingress")
+ if err := r.Delete(ctx, &ingress, client.PropagationPolicy(metav1.DeletePropagationBackground)); client.IgnoreNotFound(err) != nil {
+ log.Error(err, "Delete Ingress failed")
+ return err
+ }
+ }
+
+ var childPods v1.PodList
+ if err := r.List(ctx, &childPods, client.InNamespace(req.Namespace), client.MatchingFields{ownerKey: req.Name}); err != nil {
+ log.Error(err, "unable to list child Pods")
+ return err
+ }
+ for _, pod := range childPods.Items {
+ log.Info("Delete Pod", "Pod", pod.Name)
+ if err := r.Delete(ctx, &pod, client.PropagationPolicy(metav1.DeletePropagationBackground)); client.IgnoreNotFound(err) != nil {
+ log.Error(err, "Failed to delete pods")
+ return err
+ }
+ }
+ var childServices v1.ServiceList
+ if err := r.List(ctx, &childServices, client.InNamespace(req.Namespace), client.MatchingFields{ownerKey: req.Name}); err != nil {
+		log.Error(err, "unable to list child Services")
+ return err
+ }
+ for _, service := range childServices.Items {
+		log.Info("Delete Service", "Service", service.Name)
+		if err := r.Delete(ctx, &service, client.PropagationPolicy(metav1.DeletePropagationBackground)); client.IgnoreNotFound(err) != nil {
+			log.Error(err, "Failed to delete service")
+ return err
+ }
+ }
+ return nil
+}
+
+var (
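+	// ownerKey is the field-index key used to look up child Pods and Services by their controlling FedApp's name.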
+ ownerKey = ".metadata.controller"
+ apiGVStr = fedlearnerv2.GroupVersion.String()
+)
+
+// SetupWithManager sets up the controller with the Manager.
+func (r *FedAppReconciler) SetupWithManager(mgr ctrl.Manager) error {
+	// For a more efficient lookup, these Pods and Services will be indexed locally on their controller's (FedApp) name.
+ if err := mgr.GetFieldIndexer().IndexField(context.Background(), &v1.Pod{}, ownerKey, func(rawObj client.Object) []string {
+		// grab the pod object, extract the owner...
+ pod := rawObj.(*v1.Pod)
+ owner := metav1.GetControllerOf(pod)
+ if owner == nil {
+ return nil
+ }
+ // ...make sure it's a FedApp...
+ if owner.APIVersion != apiGVStr || owner.Kind != "FedApp" {
+ return nil
+ }
+
+ // ...and if so, return it
+ return []string{owner.Name}
+ }); err != nil {
+ return err
+ }
+ if err := mgr.GetFieldIndexer().IndexField(context.Background(), &v1.Service{}, ownerKey, func(rawObj client.Object) []string {
+		// grab the service object, extract the owner...
+ service := rawObj.(*v1.Service)
+ owner := metav1.GetControllerOf(service)
+ if owner == nil {
+ return nil
+ }
+ // ...make sure it's a FedApp...
+ if owner.APIVersion != apiGVStr || owner.Kind != "FedApp" {
+ return nil
+ }
+
+ // ...and if so, return it
+ return []string{owner.Name}
+ }); err != nil {
+ return err
+ }
+ return ctrl.NewControllerManagedBy(mgr).
+ For(&fedlearnerv2.FedApp{}).
+ Owns(&v1.Pod{}).
+ Complete(r)
+}
diff --git a/operator/controllers/fedapp_controller_test.go b/operator/controllers/fedapp_controller_test.go
new file mode 100644
index 000000000..cd33e07f8
--- /dev/null
+++ b/operator/controllers/fedapp_controller_test.go
@@ -0,0 +1,154 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "context"
+ "time"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+ corev1 "k8s.io/api/core/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/types"
+)
+
+var _ = Describe("Fedapp controller", func() {
+
+ // Define utility constants for object names and testing timeouts/durations and intervals.
+ const (
+ FedAppName = "test-fedapp"
+ FailedAppName = "failed-test-fedapp"
+ FedAppNamespace = "default"
+ FedReplicaType = "Worker"
+ timeout = time.Second * 10
+ interval = time.Millisecond * 250
+ )
+ var replicas int64 = 2
+ Context("When updating FedApp Status", func() {
+ It("Should FedApp created successfully", func() {
+ By("By creating a new Fedapp")
+ ctx := context.Background()
+ fedapp := &fedlearnerv2.FedApp{
+ TypeMeta: metav1.TypeMeta{
+ APIVersion: "fedlearner.k8s.io/v1alpha1",
+ Kind: "FedApp",
+ },
+ ObjectMeta: metav1.ObjectMeta{
+ Name: FedAppName,
+ Namespace: FedAppNamespace,
+ },
+ Spec: fedlearnerv2.FedAppSpec{
+ FedReplicaSpecs: fedlearnerv2.FedReplicaSpecs{
+ FedReplicaType: fedlearnerv2.ReplicaSpec{
+ Replicas: &replicas,
+ Template: corev1.PodTemplateSpec{
+ Spec: corev1.PodSpec{
+ Containers: []corev1.Container{
+ {
+ Name: "test-container",
+ Image: "test-image",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ Expect(k8sClient.Create(ctx, fedapp)).Should(Succeed())
+ fedappLookupKey := types.NamespacedName{Name: FedAppName, Namespace: FedAppNamespace}
+ createdFedApp := &fedlearnerv2.FedApp{}
+
+ // We'll need to retry getting this newly created FedApp, given that creation may not immediately happen.
+ Eventually(func() bool {
+ err := k8sClient.Get(ctx, fedappLookupKey, createdFedApp)
+ if err == nil && createdFedApp.Status.StartTime != nil {
+ return true
+ }
+ return false
+ }, timeout, interval).Should(BeTrue(), "should have startTime in the status")
+			// Make sure the default Port value from the CRD defaulting was properly applied.
+ Expect(createdFedApp.Spec.FedReplicaSpecs[FedReplicaType].Port.ContainerPort).Should(Equal(int32(50051)))
+			By("By checking Pod, Service and Ingress are created successfully")
+ var childPods corev1.PodList
+ Eventually(func() (int, error) {
+ err := k8sClient.List(ctx, &childPods)
+ if err != nil {
+ return -1, err
+ }
+ return len(childPods.Items), nil
+ }, timeout, interval).Should(Equal(2), "should create pods")
+
+ })
+ It("Should FedApp create pod failed", func() {
+ By("By creating a new Fedapp")
+ ctx := context.Background()
+ fedapp := &fedlearnerv2.FedApp{
+ TypeMeta: metav1.TypeMeta{
+ APIVersion: "fedlearner.k8s.io/v1alpha1",
+ Kind: "FedApp",
+ },
+ ObjectMeta: metav1.ObjectMeta{
+ Name: FailedAppName,
+ Namespace: FedAppNamespace,
+ },
+ Spec: fedlearnerv2.FedAppSpec{
+ FedReplicaSpecs: fedlearnerv2.FedReplicaSpecs{
+ FedReplicaType: fedlearnerv2.ReplicaSpec{
+ Replicas: &replicas,
+ Template: corev1.PodTemplateSpec{
+ Spec: corev1.PodSpec{
+ Containers: []corev1.Container{
+ {
+ Name: "test-container",
+ Image: " failed-image",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ Expect(k8sClient.Create(ctx, fedapp)).Should(Succeed())
+ fedappLookupKey := types.NamespacedName{Name: FailedAppName, Namespace: FedAppNamespace}
+ createdFedApp := &fedlearnerv2.FedApp{}
+
+ // We'll need to retry getting this newly created FedApp, given that creation may not immediately happen.
+ Eventually(func() bool {
+ err := k8sClient.Get(ctx, fedappLookupKey, createdFedApp)
+ if err == nil {
+ conditions := createdFedApp.Status.Conditions
+ for i := range conditions {
+ if conditions[i].Type != fedlearnerv2.Succeeded {
+ continue
+ }
+ if conditions[i].Status != corev1.ConditionFalse {
+ break
+ }
+ if conditions[i].Reason == "CreatePodFailed" {
+ return true
+ }
+ }
+ }
+ return false
+ }, timeout, interval).Should(BeTrue(), "should be failed for create pod failed")
+ })
+ })
+
+})
diff --git a/operator/controllers/ingress.go b/operator/controllers/ingress.go
new file mode 100644
index 000000000..7d5ce6b11
--- /dev/null
+++ b/operator/controllers/ingress.go
@@ -0,0 +1,91 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "context"
+ "strings"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ networking "k8s.io/api/networking/v1beta1"
+ "k8s.io/apimachinery/pkg/api/errors"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/types"
+ "k8s.io/apimachinery/pkg/util/intstr"
+ ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+func (r *FedAppReconciler) syncIngress(ctx context.Context, app *fedlearnerv2.FedApp) error {
+ log := log.FromContext(ctx)
+ var ingress networking.Ingress
+ err := r.Get(ctx, types.NamespacedName{Name: app.Name, Namespace: app.Namespace}, &ingress)
+ if errors.IsNotFound(err) {
+ ingressName := app.Name
+ labels := GenLabels(app)
+ annotations := map[string]string{
+ //"kubernetes.io/ingress.class": ingressClassName,
+ "nginx.ingress.kubernetes.io/backend-protocol": "GRPC",
+ "nginx.ingress.kubernetes.io/configuration-snippet": "grpc_next_upstream_tries 5;",
+ "nginx.ingress.kubernetes.io/http2-insecure-port": "true",
+ }
+
+ newIngress := &networking.Ingress{
+ ObjectMeta: metav1.ObjectMeta{
+ Name: ingressName,
+ Namespace: app.Namespace,
+ Labels: labels,
+ Annotations: annotations,
+ },
+ // Explicitly set IngressClassName to nil for k8s backward compatibility
+ Spec: networking.IngressSpec{
+ IngressClassName: nil,
+ },
+ }
+ for rtype, spec := range app.Spec.FedReplicaSpecs {
+ replicas := int(*spec.Replicas)
+ rt := strings.ToLower(string(rtype))
+ for index := 0; index < replicas; index++ {
+ path := networking.HTTPIngressPath{
+ Backend: networking.IngressBackend{
+ ServiceName: GenIndexName(app.Name, rt, index),
+ ServicePort: intstr.FromString(spec.Port.Name),
+ },
+ }
+ host := GenIndexName(app.Name, rt, index) + IngressExtraHostSuffix
+ rule := networking.IngressRule{
+ Host: host,
+ IngressRuleValue: networking.IngressRuleValue{
+ HTTP: &networking.HTTPIngressRuleValue{
+ Paths: []networking.HTTPIngressPath{path},
+ },
+ },
+ }
+ newIngress.Spec.Rules = append(newIngress.Spec.Rules, rule)
+ }
+ }
+ if err := ctrl.SetControllerReference(app, newIngress, r.Scheme); err != nil {
+ return err
+ }
+ log.Info("Create Ingress", "Ingress", newIngress.Name)
+ err := r.Create(ctx, newIngress)
+ if err != nil && errors.IsAlreadyExists(err) {
+ return nil
+ }
+ return err
+ }
+ return err
+}
diff --git a/operator/controllers/options.go b/operator/controllers/options.go
new file mode 100644
index 000000000..e046ad613
--- /dev/null
+++ b/operator/controllers/options.go
@@ -0,0 +1,20 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+var (
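+	// IngressExtraHostSuffix is appended to every generated per-replica ingress host (see syncIngress).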
+ IngressExtraHostSuffix string
+)
diff --git a/operator/controllers/pod.go b/operator/controllers/pod.go
new file mode 100644
index 000000000..217d08341
--- /dev/null
+++ b/operator/controllers/pod.go
@@ -0,0 +1,123 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "context"
+ "strconv"
+ "strings"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ v1 "k8s.io/api/core/v1"
+ ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+func (r *FedAppReconciler) CreatePod(ctx context.Context, app *fedlearnerv2.FedApp, spec *fedlearnerv2.ReplicaSpec, index int, rt string, podSliceHasFailed int) error {
+ log := log.FromContext(ctx)
+ podTemplate := spec.Template.DeepCopy()
+ labels := GenLabels(app)
+ labels[flReplicaTypeLabel] = rt
+ labels[flReplicaIndexLabel] = strconv.Itoa(index)
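+	// Suffix the pod name with the number of previously failed attempts so every retry gets a distinct, deterministic name.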
+ podTemplate.Name = GenIndexName(app.Name, rt, index) + "-retry-" + strconv.Itoa(podSliceHasFailed)
+ podTemplate.Namespace = app.Namespace
+ if podTemplate.Labels == nil {
+ podTemplate.Labels = make(map[string]string)
+ }
+ for key, value := range labels {
+ podTemplate.Labels[key] = value
+ }
+ // The controller will restart pod according to FedReplicaSpec
+ podTemplate.Spec.RestartPolicy = v1.RestartPolicyNever
+
+ clusterSpecValue, err := makeClusterSpec(app.Namespace, app)
+ if err != nil {
+ log.Error(err, "unable to make cluster spec")
+ return err
+ }
+ for idx := range podTemplate.Spec.Containers {
+ container := &podTemplate.Spec.Containers[idx]
+ container.Env = ensureEnv(container.Env, v1.EnvVar{
+ Name: replicaIndex,
+ Value: strconv.Itoa(index),
+ })
+ container.Env = ensureEnv(container.Env, v1.EnvVar{
+ Name: serviceID,
+ Value: GenIndexName(app.Name, rt, index),
+ })
+ container.Env = ensureEnv(container.Env, v1.EnvVar{
+ Name: clusterSpec,
+ Value: clusterSpecValue,
+ })
+
+ // If pod use host network, overwrite all port to 0 to support autoport.
+ if podTemplate.Spec.HostNetwork {
+ for i := range container.Ports {
+ container.Ports[i].ContainerPort = 0
+ }
+ }
+
+ }
+
+ pod := &v1.Pod{
+ ObjectMeta: podTemplate.ObjectMeta,
+ Spec: podTemplate.Spec,
+ }
+ if err := ctrl.SetControllerReference(app, pod, r.Scheme); err != nil {
+ return err
+ }
+ log.Info("Create Pod", "Pod", pod.Name)
+ if err = r.Create(ctx, pod); err != nil {
+ return err
+ }
+ return nil
+}
+
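+// ensureEnv overwrites an existing env var with the same name, or appends it,
+// so controller-managed values take precedence over those in the pod template.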
+func ensureEnv(envVars []v1.EnvVar, item v1.EnvVar) []v1.EnvVar {
+ for idx := range envVars {
+ if envVars[idx].Name == item.Name {
+ envVars[idx] = item
+ return envVars
+ }
+ }
+ envVars = append(envVars, item)
+ return envVars
+}
+
+func GenLabels(app *fedlearnerv2.FedApp) map[string]string {
+ return map[string]string{
+ AppNameLabel: strings.Replace(app.Name, "/", "-", -1),
+ }
+}
+
+func makeClusterSpec(namespace string, app *fedlearnerv2.FedApp) (string, error) {
+ clusterSpec := NewClusterSpec(namespace, app)
+ bytes, err := clusterSpec.Marshal()
+ if err != nil {
+ return "", err
+ }
+ return string(bytes), nil
+}
+
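+// AllPodsFailed returns true when every pod in the slice has failed; an empty
+// slice also returns true, signalling that a new (or retry) pod should be created.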
+func AllPodsFailed(podSlice []*v1.Pod) bool {
+ for _, pod := range podSlice {
+ // TODO: support restart policy
+ if pod.Status.Phase != v1.PodFailed {
+ return false
+ }
+ }
+ return true
+}
diff --git a/operator/controllers/replica.go b/operator/controllers/replica.go
new file mode 100644
index 000000000..78b0aee42
--- /dev/null
+++ b/operator/controllers/replica.go
@@ -0,0 +1,124 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "context"
+ "strconv"
+ "strings"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ v1 "k8s.io/api/core/v1"
+ "k8s.io/apimachinery/pkg/api/errors"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/labels"
+ "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+func (r *FedAppReconciler) SyncReplicas(ctx context.Context, app *fedlearnerv2.FedApp, rtype fedlearnerv2.FedReplicaType,
+ childPods *v1.PodList, spec *fedlearnerv2.ReplicaSpec) (ReplicaResult, error) {
+ log := log.FromContext(ctx)
+ rt := strings.ToLower(string(rtype))
+ pods, err := filterPodsForReplicaType(childPods, rt)
+ if err != nil {
+		log.Error(err, "filter pods error")
+ return ReplicaResult{}, err
+ }
+ replicas := int(*spec.Replicas)
+ podSlices := make([][]*v1.Pod, replicas)
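+	// podSlices[i] collects every pod (including failed retries) whose index label equals i.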
+ for _, pod := range pods {
+ val, ok := pod.Labels[flReplicaIndexLabel]
+ if !ok {
+			log.Info("The pod does not have the index label.")
+ continue
+ }
+ index, err := strconv.Atoi(val)
+ if err != nil {
+ log.Error(err, "Error when strconv.Atoi.")
+ continue
+ }
+ if index < 0 || index >= replicas {
+ log.Info("The label index is not expected", "index", index)
+ } else {
+ podSlices[index] = append(podSlices[index], pod)
+ }
+ }
+ SyncTerminatedPods(app, rtype, podSlices)
+	// Old FedApp CRDs in some environments do not have the terminatedPodsMap field.
+	// Can't update the status here, because terminatedPodsMap would be nil after the FedApp is updated.
+ terminatedPods := *app.Status.TerminatedPodsMap[rtype]
+ failedPodsNames := GetAllFailedPodsNames(terminatedPods)
+ if len(failedPodsNames) >= int(*spec.BackoffLimit) {
+ // TODO(xiangyuxuan.prs): remove failed pod name, and add pod details in fedapp status.
+ app.Status.Conditions, _ = ensureConditionStatus(app.Status.Conditions, fedlearnerv2.Succeeded, v1.ConditionFalse, "BackoffLimitExceeded", "FedApp has reached the specified backoff limit: "+strings.Join(failedPodsNames, ", "))
+ if err := r.Status().Update(ctx, app); err != nil {
+ log.Error(err, "unable to update FedApp status BackoffLimitExceeded")
+ return ReplicaResult{}, err
+ }
+		// Requeue to release the resources
+ return ReplicaResult{isFailed: true}, nil
+ }
+
+ for index, podSlice := range podSlices {
+ if IfSliceHasSucceeded(terminatedPods, index) {
+ continue
+ }
+ needCreate := AllPodsFailed(podSlice)
+ if !needCreate {
+ continue
+ }
+ sliceHasFailedNum := len(terminatedPods.Failed[index])
+ if err := r.CreatePod(ctx, app, spec, index, rt, sliceHasFailedNum); !errors.IsAlreadyExists(err) {
+ if err == nil {
+ return ReplicaResult{}, nil
+ }
+ log.Error(err, "create Pod failed")
+ app.Status.Conditions, _ = ensureConditionStatus(app.Status.Conditions, fedlearnerv2.Succeeded, v1.ConditionFalse, "CreatePodFailed", err.Error())
+ if err := r.Status().Update(ctx, app); err != nil {
+ log.Error(err, "unable to update FedApp status CreatePodFailed")
+ return ReplicaResult{}, err
+ }
+ return ReplicaResult{isFailed: true}, nil
+ }
+
+ }
+ replicaCompleted := AllSliceCompletedOnce(terminatedPods, replicas)
+ return ReplicaResult{isCompleted: !*spec.MustSuccess || replicaCompleted}, nil
+}
+
+// filterPodsForReplicaType returns pods belong to a replicaType.
+func filterPodsForReplicaType(childPods *v1.PodList, replicaType string) ([]*v1.Pod, error) {
+ var result []*v1.Pod
+
+ replicaSelector := &metav1.LabelSelector{
+ MatchLabels: make(map[string]string),
+ }
+
+ replicaSelector.MatchLabels[flReplicaTypeLabel] = replicaType
+
+ selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
+ if err != nil {
+ return nil, err
+ }
+ pods := childPods.Items
+ for i := range pods {
+ if !selector.Matches(labels.Set(pods[i].Labels)) {
+ continue
+ }
+ result = append(result, &pods[i])
+ }
+ return result, nil
+}
diff --git a/operator/controllers/service.go b/operator/controllers/service.go
new file mode 100644
index 000000000..be50d098e
--- /dev/null
+++ b/operator/controllers/service.go
@@ -0,0 +1,177 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "context"
+ "strconv"
+ "strings"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ v1 "k8s.io/api/core/v1"
+ "k8s.io/apimachinery/pkg/api/errors"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/labels"
+ "k8s.io/apimachinery/pkg/util/intstr"
+ ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+ "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+func (r *FedAppReconciler) syncServices(ctx context.Context, app *fedlearnerv2.FedApp) error {
+ log := log.FromContext(ctx)
+ var services v1.ServiceList
+ if err := r.List(ctx, &services, client.InNamespace(app.Namespace), client.MatchingFields{ownerKey: app.Name}); err != nil {
+		log.Error(err, "unable to list child Services")
+ return err
+ }
+ for rtype, spec := range app.Spec.FedReplicaSpecs {
+ rt := strings.ToLower(string(rtype))
+ replicas := int(*spec.Replicas)
+ // Get all services for the type rt.
+ typeServices, err := filterServicesForReplicaType(&services, rt)
+ if err != nil {
+ return err
+ }
+ serviceSlices := makeServiceSlicesByIndex(ctx, typeServices, replicas)
+ for index, serviceSlice := range serviceSlices {
+ if len(serviceSlice) == 0 {
+ log.Info("need to create new service for", string(rtype), strconv.Itoa(index))
+ if err = r.createNewService(ctx, app, rtype, spec, index); err != nil {
+ return err
+ }
+ }
+ }
+
+ }
+ return nil
+}
+
+// createNewService creates a new service for the given index and type.
+func (r *FedAppReconciler) createNewService(ctx context.Context, app *fedlearnerv2.FedApp, rtype fedlearnerv2.FedReplicaType, spec fedlearnerv2.ReplicaSpec, index int) error {
+ log := log.FromContext(ctx)
+ rt := strings.ToLower(string(rtype))
+
+	// Append flReplicaTypeLabel and flReplicaIndexLabel labels.
+ labels := GenLabels(app)
+ labels[flReplicaTypeLabel] = rt
+ labels[flReplicaIndexLabel] = strconv.Itoa(index)
+ ports := GetPortsFromFedReplicaSpecs(app.Spec.FedReplicaSpecs[rtype])
+ var servicePorts []v1.ServicePort
+ for _, port := range ports {
+ servicePorts = append(servicePorts, v1.ServicePort{
+ Name: port.Name,
+ Port: port.ContainerPort,
+ TargetPort: intstr.IntOrString{
+				Type:   intstr.String, // the target port is referenced by its name
+ StrVal: port.Name,
+ },
+ })
+ }
+ service := &v1.Service{
+ Spec: v1.ServiceSpec{
+ Selector: labels,
+ Ports: servicePorts,
+ },
+ }
+
+ service.Name = GenIndexName(app.Name, rt, index)
+ service.Namespace = app.Namespace
+ service.Labels = labels
+ if err := ctrl.SetControllerReference(app, service, r.Scheme); err != nil {
+ return err
+ }
+ log.Info("Create Service", "Service", service.Name)
+ err := r.Create(ctx, service)
+ if err != nil && errors.IsAlreadyExists(err) {
+ return nil
+ }
+ return err
+}
+
+// GetPortsFromFedReplicaSpecs gets the ports of all containers in the replica spec, including the replica's default Port.
+func GetPortsFromFedReplicaSpecs(replicaSpec fedlearnerv2.ReplicaSpec) []v1.ContainerPort {
+ var ports []v1.ContainerPort
+ containers := replicaSpec.Template.Spec.Containers
+ for _, container := range containers {
+ for _, port := range container.Ports {
+ if PortNotInPortList(port, ports) {
+ ports = append(ports, port)
+ }
+ }
+ }
+ if PortNotInPortList(*replicaSpec.Port, ports) {
+ ports = append(ports, *replicaSpec.Port)
+ }
+ return ports
+}
+
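+// PortNotInPortList reports whether neither the port's name nor its container
+// port already appears in the given list.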
+func PortNotInPortList(port v1.ContainerPort, ports []v1.ContainerPort) bool {
+ for _, p := range ports {
+ if p.Name == port.Name || p.ContainerPort == port.ContainerPort {
+ return false
+ }
+ }
+ return true
+}
+
+// filterServicesForReplicaType returns service belong to a replicaType.
+func filterServicesForReplicaType(servicesList *v1.ServiceList, replicaType string) ([]*v1.Service, error) {
+ var result []*v1.Service
+ replicaSelector := &metav1.LabelSelector{
+ MatchLabels: make(map[string]string),
+ }
+
+ replicaSelector.MatchLabels[flReplicaTypeLabel] = replicaType
+ services := servicesList.Items
+ for index := range services {
+ selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
+ if err != nil {
+ return nil, err
+ }
+ if !selector.Matches(labels.Set(services[index].Labels)) {
+ continue
+ }
+ result = append(result, &services[index])
+ }
+ return result, nil
+}
+
+// makeServiceSlicesByIndex returns a slice, which element is the slice of service.
+// Assume the return object is serviceSlices, then serviceSlices[i] is an
+// array of pointers to services corresponding to Services for replica i.
+func makeServiceSlicesByIndex(ctx context.Context, services []*v1.Service, replicas int) [][]*v1.Service {
+ log := log.FromContext(ctx)
+ serviceSlices := make([][]*v1.Service, replicas)
+ for _, service := range services {
+ if _, ok := service.Labels[flReplicaIndexLabel]; !ok {
+			log.Info("The service does not have the index label.")
+ continue
+ }
+ index, err := strconv.Atoi(service.Labels[flReplicaIndexLabel])
+ if err != nil {
+ log.Error(err, "Error when strconv.Atoi.")
+ continue
+ }
+		if index < 0 || index >= replicas {
+			log.Info("The label index is not expected", "index", index, "replicas", replicas)
+			continue
+		}
+		serviceSlices[index] = append(serviceSlices[index], service)
+ }
+ return serviceSlices
+}
diff --git a/operator/controllers/suite_test.go b/operator/controllers/suite_test.go
new file mode 100644
index 000000000..552077940
--- /dev/null
+++ b/operator/controllers/suite_test.go
@@ -0,0 +1,98 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ "path/filepath"
+ "testing"
+
+ ctrl "sigs.k8s.io/controller-runtime"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+ "k8s.io/client-go/kubernetes/scheme"
+ "k8s.io/client-go/rest"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+ "sigs.k8s.io/controller-runtime/pkg/envtest"
+ "sigs.k8s.io/controller-runtime/pkg/envtest/printer"
+ logf "sigs.k8s.io/controller-runtime/pkg/log"
+ "sigs.k8s.io/controller-runtime/pkg/log/zap"
+
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ //+kubebuilder:scaffold:imports
+)
+
+// These tests use Ginkgo (BDD-style Go testing framework). Refer to
+// http://onsi.github.io/ginkgo/ to learn more about Ginkgo.
+
+var cfg *rest.Config
+var k8sClient client.Client
+var testEnv *envtest.Environment
+
+func TestAPIs(t *testing.T) {
+ RegisterFailHandler(Fail)
+
+ RunSpecsWithDefaultAndCustomReporters(t,
+ "Controller Suite",
+ []Reporter{printer.NewlineReporter{}})
+}
+
+var _ = BeforeSuite(func() {
+ logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)))
+
+ By("bootstrapping test environment")
+ testEnv = &envtest.Environment{
+ CRDDirectoryPaths: []string{filepath.Join("..", "deploy_charts", "fedapp.yaml")},
+ ErrorIfCRDPathMissing: true,
+ }
+
+ cfg, err := testEnv.Start()
+ Expect(err).NotTo(HaveOccurred())
+ Expect(cfg).NotTo(BeNil())
+
+ err = fedlearnerv2.AddToScheme(scheme.Scheme)
+ Expect(err).NotTo(HaveOccurred())
+
+ //+kubebuilder:scaffold:scheme
+
+ k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
+ Expect(err).NotTo(HaveOccurred())
+ Expect(k8sClient).NotTo(BeNil())
+
+ k8sManager, err := ctrl.NewManager(cfg, ctrl.Options{
+ Scheme: scheme.Scheme,
+ })
+ Expect(err).ToNot(HaveOccurred())
+
+ err = (&FedAppReconciler{
+ Client: k8sManager.GetClient(),
+ Scheme: k8sManager.GetScheme(),
+ }).SetupWithManager(k8sManager)
+ Expect(err).ToNot(HaveOccurred())
+
+ go func() {
+ defer GinkgoRecover()
+ err = k8sManager.Start(ctrl.SetupSignalHandler())
+ Expect(err).ToNot(HaveOccurred(), "failed to run manager")
+ }()
+
+}, 60)
+
+var _ = AfterSuite(func() {
+ By("tearing down the test environment")
+ err := testEnv.Stop()
+ Expect(err).NotTo(HaveOccurred())
+})
diff --git a/operator/controllers/terminated_pods.go b/operator/controllers/terminated_pods.go
new file mode 100644
index 000000000..06deb4161
--- /dev/null
+++ b/operator/controllers/terminated_pods.go
@@ -0,0 +1,90 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package controllers
+
+import (
+ fedlearnerv2 "fedlearner.net/operator/api/v1alpha1"
+ v1 "k8s.io/api/core/v1"
+)
+
+func GetAllFailedPodsNames(tPods fedlearnerv2.TerminatedPods) []string {
+ result := []string{}
+ for _, podSet := range tPods.Failed {
+ result = append(result, getPodNames(podSet)...)
+ }
+ return result
+}
+
+func getPodNames(set fedlearnerv2.PodSet) []string {
+ result := []string{}
+ for p := range set {
+ result = append(result, p)
+ }
+ return result
+}
+
+func AllSliceCompletedOnce(tPods fedlearnerv2.TerminatedPods, replicas int) bool {
+ for i := 0; i < replicas; i++ {
+ if !IfSliceHasSucceeded(tPods, i) {
+ return false
+ }
+ }
+ return true
+}
+
+func IfSliceHasSucceeded(tPods fedlearnerv2.TerminatedPods, index int) bool {
+ return len(tPods.Succeeded[index]) > 0
+}
+
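+// SyncTerminatedPods records the names of succeeded and failed pods of each
+// replica index into the FedApp status.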
+func SyncTerminatedPods(app *fedlearnerv2.FedApp, rtype fedlearnerv2.FedReplicaType, podSlices [][]*v1.Pod) {
+ for index, podSlice := range podSlices {
+ for _, pod := range podSlice {
+ // TODO: support restart policy
+ if pod.Status.Phase == v1.PodSucceeded {
+ setAdd(app.Status.TerminatedPodsMap[rtype].Succeeded[index], pod.Name)
+ }
+ if pod.Status.Phase == v1.PodFailed {
+ setAdd(app.Status.TerminatedPodsMap[rtype].Failed[index], pod.Name)
+ }
+
+ }
+ }
+}
+
+func setAdd(set fedlearnerv2.PodSet, name string) {
+ // TODO: remove finalizer when pod created with finalizer.
+ set[name] = struct{}{}
+}
+
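+// InitTerminatedPodsMap allocates one empty Succeeded and one empty Failed
+// PodSet per replica index for every replica type.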
+func InitTerminatedPodsMap(app fedlearnerv2.FedApp) map[fedlearnerv2.FedReplicaType]*fedlearnerv2.TerminatedPods {
+ terminatedPodsMap := fedlearnerv2.TerminatedPodsMap{}
+
+ for rtype, spec := range app.Spec.FedReplicaSpecs {
+ replicas := int(*spec.Replicas)
+ succeeded := make([]fedlearnerv2.PodSet, replicas)
+ failed := make([]fedlearnerv2.PodSet, replicas)
+ for i := range succeeded {
+ succeeded[i] = fedlearnerv2.PodSet{}
+ }
+ for i := range failed {
+ failed[i] = fedlearnerv2.PodSet{}
+ }
+ terminatedPodsMap[rtype] = &fedlearnerv2.TerminatedPods{
+ Succeeded: succeeded, Failed: failed,
+ }
+ }
+ return terminatedPodsMap
+}
diff --git a/operator/deploy_charts/fedapp.yaml b/operator/deploy_charts/fedapp.yaml
new file mode 100644
index 000000000..eafb0eb2e
--- /dev/null
+++ b/operator/deploy_charts/fedapp.yaml
@@ -0,0 +1,7508 @@
+
+---
+apiVersion: apiextensions.k8s.io/v1
+kind: CustomResourceDefinition
+metadata:
+ annotations:
+ api-approved.kubernetes.io: https://github.com/kubernetes/kubernetes/pull/78458
+ controller-gen.kubebuilder.io/version: v0.7.0
+ creationTimestamp: null
+ name: fedapps.fedlearner.k8s.io
+spec:
+ group: fedlearner.k8s.io
+ names:
+ kind: FedApp
+ listKind: FedAppList
+ plural: fedapps
+ singular: fedapp
+ scope: Namespaced
+ versions:
+ - name: v1alpha1
+ schema:
+ openAPIV3Schema:
+ description: FedApp is the Schema for the fedapps API
+ properties:
+ apiVersion:
+ description: 'APIVersion defines the versioned schema of this representation
+ of an object. Servers should convert recognized schemas to the latest
+ internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources'
+ type: string
+ kind:
+ description: 'Kind is a string value representing the REST resource this
+ object represents. Servers may infer this from the endpoint the client
+ submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds'
+ type: string
+ metadata:
+ type: object
+ spec:
+ description: FedAppSpec defines the desired state of FedApp
+ properties:
+ activeDeadlineSeconds:
+ description: Specifies the duration in seconds relative to the startTime
+ that the job may be active before the system tries to terminate
+ it; value must be positive integer.
+ format: int64
+ type: integer
+ fedReplicaSpecs:
+ additionalProperties:
+ description: ReplicaSpec is a description of the replica
+ properties:
+ backoffLimit:
+ default: 1
+ description: Optional number of retries before marking this
+ job failed.
+ format: int64
+ maximum: 100
+ minimum: 1
+ type: integer
+ mustSuccess:
+ default: true
+                      description: Whether the success of all pods of this replica
+                        is required for marking the fedapp as complete.
+ type: boolean
+ port:
+ default:
+ containerPort: 50051
+ name: flapp-port
+ protocol: TCP
+ description: ContainerPort represents a network port in a single
+ container.
+ properties:
+ containerPort:
+ description: Number of port to expose on the pod's IP address.
+ This must be a valid port number, 0 < x < 65536.
+ format: int32
+ type: integer
+ hostIP:
+ description: What host IP to bind the external port to.
+ type: string
+ hostPort:
+ description: Number of port to expose on the host. If specified,
+ this must be a valid port number, 0 < x < 65536. If HostNetwork
+ is specified, this must match ContainerPort. Most containers
+ do not need this.
+ format: int32
+ type: integer
+ name:
+ description: If specified, this must be an IANA_SVC_NAME
+ and unique within the pod. Each named port in a pod must
+ have a unique name. Name for the port that can be referred
+ to by services.
+ type: string
+ protocol:
+ default: TCP
+ description: Protocol for port. Must be UDP, TCP, or SCTP.
+ Defaults to "TCP".
+ type: string
+ required:
+ - containerPort
+ type: object
+ replicas:
+ default: 0
+ description: Replicas is the desired number of replicas of the
+ given template.
+ format: int64
+ maximum: 200
+ minimum: 0
+ type: integer
+ restartPolicy:
+ default: OnFailure
+ description: Restart policy for all replicas within the app.
+ One of Always, OnFailure, Never and ExitCode.
+ type: string
+ template:
+ description: Template is the object that describes the pod that
+ will be created for this replica.
+ properties:
+ metadata:
+ description: 'Standard object''s metadata. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#metadata'
+ type: object
+ spec:
+ description: 'Specification of the desired behavior of the
+ pod. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status'
+ properties:
+ activeDeadlineSeconds:
+ description: Optional duration in seconds the pod may
+ be active on the node relative to StartTime before
+ the system will actively try to mark it failed and
+ kill associated containers. Value must be a positive
+ integer.
+ format: int64
+ type: integer
+ affinity:
+ description: If specified, the pod's scheduling constraints
+ properties:
+ nodeAffinity:
+ description: Describes node affinity scheduling
+ rules for the pod.
+ properties:
+ preferredDuringSchedulingIgnoredDuringExecution:
+ description: The scheduler will prefer to schedule
+ pods to nodes that satisfy the affinity expressions
+ specified by this field, but it may choose
+ a node that violates one or more of the expressions.
+ The node that is most preferred is the one
+ with the greatest sum of weights, i.e. for
+ each node that meets all of the scheduling
+ requirements (resource request, requiredDuringScheduling
+ affinity expressions, etc.), compute a sum
+ by iterating through the elements of this
+ field and adding "weight" to the sum if the
+ node matches the corresponding matchExpressions;
+ the node(s) with the highest sum are the most
+ preferred.
+ items:
+ description: An empty preferred scheduling
+ term matches all objects with implicit weight
+ 0 (i.e. it's a no-op). A null preferred
+ scheduling term matches no objects (i.e.
+ is also a no-op).
+ properties:
+ preference:
+ description: A node selector term, associated
+ with the corresponding weight.
+ properties:
+ matchExpressions:
+ description: A list of node selector
+ requirements by node's labels.
+ items:
+ description: A node selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: The label key that
+ the selector applies to.
+ type: string
+ operator:
+ description: Represents a key's
+ relationship to a set of values.
+ Valid operators are In, NotIn,
+ Exists, DoesNotExist. Gt,
+ and Lt.
+ type: string
+ values:
+ description: An array of string
+ values. If the operator is
+ In or NotIn, the values array
+ must be non-empty. If the
+ operator is Exists or DoesNotExist,
+ the values array must be empty.
+ If the operator is Gt or Lt,
+ the values array must have
+ a single element, which will
+ be interpreted as an integer.
+ This array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchFields:
+ description: A list of node selector
+ requirements by node's fields.
+ items:
+ description: A node selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: The label key that
+ the selector applies to.
+ type: string
+ operator:
+ description: Represents a key's
+ relationship to a set of values.
+ Valid operators are In, NotIn,
+ Exists, DoesNotExist. Gt,
+ and Lt.
+ type: string
+ values:
+ description: An array of string
+ values. If the operator is
+ In or NotIn, the values array
+ must be non-empty. If the
+ operator is Exists or DoesNotExist,
+ the values array must be empty.
+ If the operator is Gt or Lt,
+ the values array must have
+ a single element, which will
+ be interpreted as an integer.
+ This array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ type: object
+ weight:
+ description: Weight associated with matching
+ the corresponding nodeSelectorTerm,
+ in the range 1-100.
+ format: int32
+ type: integer
+ required:
+ - preference
+ - weight
+ type: object
+ type: array
+ requiredDuringSchedulingIgnoredDuringExecution:
+ description: If the affinity requirements specified
+ by this field are not met at scheduling time,
+ the pod will not be scheduled onto the node.
+ If the affinity requirements specified by
+ this field cease to be met at some point during
+ pod execution (e.g. due to an update), the
+ system may or may not try to eventually evict
+ the pod from its node.
+ properties:
+ nodeSelectorTerms:
+ description: Required. A list of node selector
+ terms. The terms are ORed.
+ items:
+ description: A null or empty node selector
+ term matches no objects. The requirements
+ of them are ANDed. The TopologySelectorTerm
+ type implements a subset of the NodeSelectorTerm.
+ properties:
+ matchExpressions:
+ description: A list of node selector
+ requirements by node's labels.
+ items:
+ description: A node selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: The label key that
+ the selector applies to.
+ type: string
+ operator:
+ description: Represents a key's
+ relationship to a set of values.
+ Valid operators are In, NotIn,
+ Exists, DoesNotExist. Gt,
+ and Lt.
+ type: string
+ values:
+ description: An array of string
+ values. If the operator is
+ In or NotIn, the values array
+ must be non-empty. If the
+ operator is Exists or DoesNotExist,
+ the values array must be empty.
+ If the operator is Gt or Lt,
+ the values array must have
+ a single element, which will
+ be interpreted as an integer.
+ This array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchFields:
+ description: A list of node selector
+ requirements by node's fields.
+ items:
+ description: A node selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: The label key that
+ the selector applies to.
+ type: string
+ operator:
+ description: Represents a key's
+ relationship to a set of values.
+ Valid operators are In, NotIn,
+ Exists, DoesNotExist. Gt,
+ and Lt.
+ type: string
+ values:
+ description: An array of string
+ values. If the operator is
+ In or NotIn, the values array
+ must be non-empty. If the
+ operator is Exists or DoesNotExist,
+ the values array must be empty.
+ If the operator is Gt or Lt,
+ the values array must have
+ a single element, which will
+ be interpreted as an integer.
+ This array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ type: object
+ type: array
+ required:
+ - nodeSelectorTerms
+ type: object
+ type: object
+ podAffinity:
+ description: Describes pod affinity scheduling rules
+ (e.g. co-locate this pod in the same node, zone,
+ etc. as some other pod(s)).
+ properties:
+ preferredDuringSchedulingIgnoredDuringExecution:
+ description: The scheduler will prefer to schedule
+ pods to nodes that satisfy the affinity expressions
+ specified by this field, but it may choose
+ a node that violates one or more of the expressions.
+ The node that is most preferred is the one
+ with the greatest sum of weights, i.e. for
+ each node that meets all of the scheduling
+ requirements (resource request, requiredDuringScheduling
+ affinity expressions, etc.), compute a sum
+ by iterating through the elements of this
+ field and adding "weight" to the sum if the
+ node has pods which matches the corresponding
+ podAffinityTerm; the node(s) with the highest
+ sum are the most preferred.
+ items:
+ description: The weights of all of the matched
+ WeightedPodAffinityTerm fields are added
+ per-node to find the most preferred node(s)
+ properties:
+ podAffinityTerm:
+ description: Required. A pod affinity
+ term, associated with the corresponding
+ weight.
+ properties:
+ labelSelector:
+ description: A label query over a
+ set of resources, in this case pods.
+ properties:
+ matchExpressions:
+ description: matchExpressions
+ is a list of label selector
+ requirements. The requirements
+ are ANDed.
+ items:
+ description: A label selector
+ requirement is a selector
+ that contains values, a key,
+ and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the
+ label key that the selector
+ applies to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to
+ a set of values. Valid
+ operators are In, NotIn,
+ Exists and DoesNotExist.
+ type: string
+ values:
+ description: values is an
+ array of string values.
+ If the operator is In
+ or NotIn, the values array
+ must be non-empty. If
+ the operator is Exists
+ or DoesNotExist, the values
+ array must be empty. This
+ array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a
+ map of {key,value} pairs. A
+ single {key,value} in the matchLabels
+ map is equivalent to an element
+ of matchExpressions, whose key
+ field is "key", the operator
+ is "In", and the values array
+ contains only "value". The requirements
+ are ANDed.
+ type: object
+ type: object
+ namespaceSelector:
+ description: A label query over the
+ set of namespaces that the term
+ applies to. The term is applied
+ to the union of the namespaces selected
+ by this field and the ones listed
+ in the namespaces field. null selector
+ and null or empty namespaces list
+ means "this pod's namespace". An
+ empty selector ({}) matches all
+ namespaces. This field is beta-level
+ and is only honored when PodAffinityNamespaceSelector
+ feature is enabled.
+ properties:
+ matchExpressions:
+ description: matchExpressions
+ is a list of label selector
+ requirements. The requirements
+ are ANDed.
+ items:
+ description: A label selector
+ requirement is a selector
+ that contains values, a key,
+ and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the
+ label key that the selector
+ applies to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to
+ a set of values. Valid
+ operators are In, NotIn,
+ Exists and DoesNotExist.
+ type: string
+ values:
+ description: values is an
+ array of string values.
+ If the operator is In
+ or NotIn, the values array
+ must be non-empty. If
+ the operator is Exists
+ or DoesNotExist, the values
+ array must be empty. This
+ array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a
+ map of {key,value} pairs. A
+ single {key,value} in the matchLabels
+ map is equivalent to an element
+ of matchExpressions, whose key
+ field is "key", the operator
+ is "In", and the values array
+ contains only "value". The requirements
+ are ANDed.
+ type: object
+ type: object
+ namespaces:
+ description: namespaces specifies
+ a static list of namespace names
+ that the term applies to. The term
+ is applied to the union of the namespaces
+ listed in this field and the ones
+ selected by namespaceSelector. null
+ or empty namespaces list and null
+ namespaceSelector means "this pod's
+ namespace"
+ items:
+ type: string
+ type: array
+ topologyKey:
+ description: This pod should be co-located
+ (affinity) or not co-located (anti-affinity)
+ with the pods matching the labelSelector
+ in the specified namespaces, where
+ co-located is defined as running
+ on a node whose value of the label
+ with key topologyKey matches that
+ of any node on which any of the
+ selected pods is running. Empty
+ topologyKey is not allowed.
+ type: string
+ required:
+ - topologyKey
+ type: object
+ weight:
+ description: weight associated with matching
+ the corresponding podAffinityTerm, in
+ the range 1-100.
+ format: int32
+ type: integer
+ required:
+ - podAffinityTerm
+ - weight
+ type: object
+ type: array
+ requiredDuringSchedulingIgnoredDuringExecution:
+ description: If the affinity requirements specified
+ by this field are not met at scheduling time,
+ the pod will not be scheduled onto the node.
+ If the affinity requirements specified by
+ this field cease to be met at some point during
+ pod execution (e.g. due to a pod label update),
+ the system may or may not try to eventually
+ evict the pod from its node. When there are
+ multiple elements, the lists of nodes corresponding
+ to each podAffinityTerm are intersected, i.e.
+ all terms must be satisfied.
+ items:
+ description: Defines a set of pods (namely
+ those matching the labelSelector relative
+ to the given namespace(s)) that this pod
+ should be co-located (affinity) or not co-located
+ (anti-affinity) with, where co-located is
+ defined as running on a node whose value
+                                of the label with key topologyKey matches
+ that of any node on which a pod of the set
+ of pods is running
+ properties:
+ labelSelector:
+ description: A label query over a set
+ of resources, in this case pods.
+ properties:
+ matchExpressions:
+ description: matchExpressions is a
+ list of label selector requirements.
+ The requirements are ANDed.
+ items:
+ description: A label selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the label
+ key that the selector applies
+ to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to a
+ set of values. Valid operators
+ are In, NotIn, Exists and
+ DoesNotExist.
+ type: string
+ values:
+ description: values is an array
+ of string values. If the operator
+ is In or NotIn, the values
+ array must be non-empty. If
+ the operator is Exists or
+ DoesNotExist, the values array
+ must be empty. This array
+ is replaced during a strategic
+ merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a map
+ of {key,value} pairs. A single {key,value}
+ in the matchLabels map is equivalent
+ to an element of matchExpressions,
+ whose key field is "key", the operator
+ is "In", and the values array contains
+ only "value". The requirements are
+ ANDed.
+ type: object
+ type: object
+ namespaceSelector:
+ description: A label query over the set
+ of namespaces that the term applies
+ to. The term is applied to the union
+ of the namespaces selected by this field
+ and the ones listed in the namespaces
+ field. null selector and null or empty
+ namespaces list means "this pod's namespace".
+ An empty selector ({}) matches all namespaces.
+ This field is beta-level and is only
+ honored when PodAffinityNamespaceSelector
+ feature is enabled.
+ properties:
+ matchExpressions:
+ description: matchExpressions is a
+ list of label selector requirements.
+ The requirements are ANDed.
+ items:
+ description: A label selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the label
+ key that the selector applies
+ to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to a
+ set of values. Valid operators
+ are In, NotIn, Exists and
+ DoesNotExist.
+ type: string
+ values:
+ description: values is an array
+ of string values. If the operator
+ is In or NotIn, the values
+ array must be non-empty. If
+ the operator is Exists or
+ DoesNotExist, the values array
+ must be empty. This array
+ is replaced during a strategic
+ merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a map
+ of {key,value} pairs. A single {key,value}
+ in the matchLabels map is equivalent
+ to an element of matchExpressions,
+ whose key field is "key", the operator
+ is "In", and the values array contains
+ only "value". The requirements are
+ ANDed.
+ type: object
+ type: object
+ namespaces:
+ description: namespaces specifies a static
+ list of namespace names that the term
+ applies to. The term is applied to the
+ union of the namespaces listed in this
+ field and the ones selected by namespaceSelector.
+ null or empty namespaces list and null
+ namespaceSelector means "this pod's
+ namespace"
+ items:
+ type: string
+ type: array
+ topologyKey:
+ description: This pod should be co-located
+ (affinity) or not co-located (anti-affinity)
+ with the pods matching the labelSelector
+ in the specified namespaces, where co-located
+ is defined as running on a node whose
+ value of the label with key topologyKey
+ matches that of any node on which any
+ of the selected pods is running. Empty
+ topologyKey is not allowed.
+ type: string
+ required:
+ - topologyKey
+ type: object
+ type: array
+ type: object
+ podAntiAffinity:
+ description: Describes pod anti-affinity scheduling
+ rules (e.g. avoid putting this pod in the same
+ node, zone, etc. as some other pod(s)).
+ properties:
+ preferredDuringSchedulingIgnoredDuringExecution:
+ description: The scheduler will prefer to schedule
+ pods to nodes that satisfy the anti-affinity
+ expressions specified by this field, but it
+ may choose a node that violates one or more
+ of the expressions. The node that is most
+ preferred is the one with the greatest sum
+ of weights, i.e. for each node that meets
+ all of the scheduling requirements (resource
+ request, requiredDuringScheduling anti-affinity
+ expressions, etc.), compute a sum by iterating
+ through the elements of this field and adding
+ "weight" to the sum if the node has pods which
+ matches the corresponding podAffinityTerm;
+ the node(s) with the highest sum are the most
+ preferred.
+ items:
+ description: The weights of all of the matched
+ WeightedPodAffinityTerm fields are added
+ per-node to find the most preferred node(s)
+ properties:
+ podAffinityTerm:
+ description: Required. A pod affinity
+ term, associated with the corresponding
+ weight.
+ properties:
+ labelSelector:
+ description: A label query over a
+ set of resources, in this case pods.
+ properties:
+ matchExpressions:
+ description: matchExpressions
+ is a list of label selector
+ requirements. The requirements
+ are ANDed.
+ items:
+ description: A label selector
+ requirement is a selector
+ that contains values, a key,
+ and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the
+ label key that the selector
+ applies to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to
+ a set of values. Valid
+ operators are In, NotIn,
+ Exists and DoesNotExist.
+ type: string
+ values:
+ description: values is an
+ array of string values.
+ If the operator is In
+ or NotIn, the values array
+ must be non-empty. If
+ the operator is Exists
+ or DoesNotExist, the values
+ array must be empty. This
+ array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a
+ map of {key,value} pairs. A
+ single {key,value} in the matchLabels
+ map is equivalent to an element
+ of matchExpressions, whose key
+ field is "key", the operator
+ is "In", and the values array
+ contains only "value". The requirements
+ are ANDed.
+ type: object
+ type: object
+ namespaceSelector:
+ description: A label query over the
+ set of namespaces that the term
+ applies to. The term is applied
+ to the union of the namespaces selected
+ by this field and the ones listed
+ in the namespaces field. null selector
+ and null or empty namespaces list
+ means "this pod's namespace". An
+ empty selector ({}) matches all
+ namespaces. This field is beta-level
+ and is only honored when PodAffinityNamespaceSelector
+ feature is enabled.
+ properties:
+ matchExpressions:
+ description: matchExpressions
+ is a list of label selector
+ requirements. The requirements
+ are ANDed.
+ items:
+ description: A label selector
+ requirement is a selector
+ that contains values, a key,
+ and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the
+ label key that the selector
+ applies to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to
+ a set of values. Valid
+ operators are In, NotIn,
+ Exists and DoesNotExist.
+ type: string
+ values:
+ description: values is an
+ array of string values.
+ If the operator is In
+ or NotIn, the values array
+ must be non-empty. If
+ the operator is Exists
+ or DoesNotExist, the values
+ array must be empty. This
+ array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a
+ map of {key,value} pairs. A
+ single {key,value} in the matchLabels
+ map is equivalent to an element
+ of matchExpressions, whose key
+ field is "key", the operator
+ is "In", and the values array
+ contains only "value". The requirements
+ are ANDed.
+ type: object
+ type: object
+ namespaces:
+ description: namespaces specifies
+ a static list of namespace names
+ that the term applies to. The term
+ is applied to the union of the namespaces
+ listed in this field and the ones
+ selected by namespaceSelector. null
+ or empty namespaces list and null
+ namespaceSelector means "this pod's
+ namespace"
+ items:
+ type: string
+ type: array
+ topologyKey:
+ description: This pod should be co-located
+ (affinity) or not co-located (anti-affinity)
+ with the pods matching the labelSelector
+ in the specified namespaces, where
+ co-located is defined as running
+ on a node whose value of the label
+ with key topologyKey matches that
+ of any node on which any of the
+ selected pods is running. Empty
+ topologyKey is not allowed.
+ type: string
+ required:
+ - topologyKey
+ type: object
+ weight:
+ description: weight associated with matching
+ the corresponding podAffinityTerm, in
+ the range 1-100.
+ format: int32
+ type: integer
+ required:
+ - podAffinityTerm
+ - weight
+ type: object
+ type: array
+ requiredDuringSchedulingIgnoredDuringExecution:
+ description: If the anti-affinity requirements
+ specified by this field are not met at scheduling
+ time, the pod will not be scheduled onto the
+ node. If the anti-affinity requirements specified
+ by this field cease to be met at some point
+ during pod execution (e.g. due to a pod label
+ update), the system may or may not try to
+ eventually evict the pod from its node. When
+ there are multiple elements, the lists of
+ nodes corresponding to each podAffinityTerm
+ are intersected, i.e. all terms must be satisfied.
+ items:
+ description: Defines a set of pods (namely
+ those matching the labelSelector relative
+ to the given namespace(s)) that this pod
+ should be co-located (affinity) or not co-located
+ (anti-affinity) with, where co-located is
+ defined as running on a node whose value
+                                of the label with key topologyKey matches
+ that of any node on which a pod of the set
+ of pods is running
+ properties:
+ labelSelector:
+ description: A label query over a set
+ of resources, in this case pods.
+ properties:
+ matchExpressions:
+ description: matchExpressions is a
+ list of label selector requirements.
+ The requirements are ANDed.
+ items:
+ description: A label selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the label
+ key that the selector applies
+ to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to a
+ set of values. Valid operators
+ are In, NotIn, Exists and
+ DoesNotExist.
+ type: string
+ values:
+ description: values is an array
+ of string values. If the operator
+ is In or NotIn, the values
+ array must be non-empty. If
+ the operator is Exists or
+ DoesNotExist, the values array
+ must be empty. This array
+ is replaced during a strategic
+ merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a map
+ of {key,value} pairs. A single {key,value}
+ in the matchLabels map is equivalent
+ to an element of matchExpressions,
+ whose key field is "key", the operator
+ is "In", and the values array contains
+ only "value". The requirements are
+ ANDed.
+ type: object
+ type: object
+ namespaceSelector:
+ description: A label query over the set
+ of namespaces that the term applies
+ to. The term is applied to the union
+ of the namespaces selected by this field
+ and the ones listed in the namespaces
+ field. null selector and null or empty
+ namespaces list means "this pod's namespace".
+ An empty selector ({}) matches all namespaces.
+ This field is beta-level and is only
+ honored when PodAffinityNamespaceSelector
+ feature is enabled.
+ properties:
+ matchExpressions:
+ description: matchExpressions is a
+ list of label selector requirements.
+ The requirements are ANDed.
+ items:
+ description: A label selector requirement
+ is a selector that contains values,
+ a key, and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the label
+ key that the selector applies
+ to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to a
+ set of values. Valid operators
+ are In, NotIn, Exists and
+ DoesNotExist.
+ type: string
+ values:
+ description: values is an array
+ of string values. If the operator
+ is In or NotIn, the values
+ array must be non-empty. If
+ the operator is Exists or
+ DoesNotExist, the values array
+ must be empty. This array
+ is replaced during a strategic
+ merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a map
+ of {key,value} pairs. A single {key,value}
+ in the matchLabels map is equivalent
+ to an element of matchExpressions,
+ whose key field is "key", the operator
+ is "In", and the values array contains
+ only "value". The requirements are
+ ANDed.
+ type: object
+ type: object
+ namespaces:
+ description: namespaces specifies a static
+ list of namespace names that the term
+ applies to. The term is applied to the
+ union of the namespaces listed in this
+ field and the ones selected by namespaceSelector.
+ null or empty namespaces list and null
+ namespaceSelector means "this pod's
+ namespace"
+ items:
+ type: string
+ type: array
+ topologyKey:
+ description: This pod should be co-located
+ (affinity) or not co-located (anti-affinity)
+ with the pods matching the labelSelector
+ in the specified namespaces, where co-located
+ is defined as running on a node whose
+ value of the label with key topologyKey
+ matches that of any node on which any
+ of the selected pods is running. Empty
+ topologyKey is not allowed.
+ type: string
+ required:
+ - topologyKey
+ type: object
+ type: array
+ type: object
+ type: object
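+                  # Illustrative sketch only (YAML comment, not part of the generated
+                  # schema): how a user of this CRD might express a soft pod
+                  # anti-affinity rule; the label key/value "app: fedlearner-worker"
+                  # and the topology key are assumed placeholders.
+                  #
+                  #   affinity:
+                  #     podAntiAffinity:
+                  #       preferredDuringSchedulingIgnoredDuringExecution:
+                  #       - weight: 100
+                  #         podAffinityTerm:
+                  #           labelSelector:
+                  #             matchLabels:
+                  #               app: fedlearner-worker
+                  #           topologyKey: kubernetes.io/hostname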
+ automountServiceAccountToken:
+ description: AutomountServiceAccountToken indicates
+ whether a service account token should be automatically
+ mounted.
+ type: boolean
+ containers:
+ description: List of containers belonging to the pod.
+ Containers cannot currently be added or removed. There
+ must be at least one container in a Pod. Cannot be
+ updated.
+ items:
+ description: A single application container that you
+ want to run within a pod.
+ properties:
+ args:
+ description: 'Arguments to the entrypoint. The
+ docker image''s CMD is used if this is not provided.
+ Variable references $(VAR_NAME) are expanded
+ using the container''s environment. If a variable
+ cannot be resolved, the reference in the input
+ string will be unchanged. Double $$ are reduced
+ to a single $, which allows for escaping the
+ $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)" will
+ produce the string literal "$(VAR_NAME)". Escaped
+ references will never be expanded, regardless
+ of whether the variable exists or not. Cannot
+ be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell'
+ items:
+ type: string
+ type: array
+ command:
+ description: 'Entrypoint array. Not executed within
+ a shell. The docker image''s ENTRYPOINT is used
+ if this is not provided. Variable references
+ $(VAR_NAME) are expanded using the container''s
+ environment. If a variable cannot be resolved,
+ the reference in the input string will be unchanged.
+ Double $$ are reduced to a single $, which allows
+ for escaping the $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)"
+ will produce the string literal "$(VAR_NAME)".
+ Escaped references will never be expanded, regardless
+ of whether the variable exists or not. Cannot
+ be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell'
+ items:
+ type: string
+ type: array
+ env:
+ description: List of environment variables to
+ set in the container. Cannot be updated.
+ items:
+ description: EnvVar represents an environment
+ variable present in a Container.
+ properties:
+ name:
+ description: Name of the environment variable.
+ Must be a C_IDENTIFIER.
+ type: string
+ value:
+ description: 'Variable references $(VAR_NAME)
+ are expanded using the previously defined
+ environment variables in the container
+ and any service environment variables.
+ If a variable cannot be resolved, the
+ reference in the input string will be
+ unchanged. Double $$ are reduced to a
+ single $, which allows for escaping the
+ $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)"
+ will produce the string literal "$(VAR_NAME)".
+ Escaped references will never be expanded,
+ regardless of whether the variable exists
+ or not. Defaults to "".'
+ type: string
+ valueFrom:
+ description: Source for the environment
+ variable's value. Cannot be used if value
+ is not empty.
+ properties:
+ configMapKeyRef:
+ description: Selects a key of a ConfigMap.
+ properties:
+ key:
+ description: The key to select.
+ type: string
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ ConfigMap or its key must be defined
+ type: boolean
+ required:
+ - key
+ type: object
+ fieldRef:
+ description: 'Selects a field of the
+ pod: supports metadata.name, metadata.namespace,
+ `metadata.labels['''']`, `metadata.annotations['''']`,
+ spec.nodeName, spec.serviceAccountName,
+ status.hostIP, status.podIP, status.podIPs.'
+ properties:
+ apiVersion:
+ description: Version of the schema
+ the FieldPath is written in terms
+ of, defaults to "v1".
+ type: string
+ fieldPath:
+ description: Path of the field to
+ select in the specified API version.
+ type: string
+ required:
+ - fieldPath
+ type: object
+ resourceFieldRef:
+ description: 'Selects a resource of
+ the container: only resources limits
+ and requests (limits.cpu, limits.memory,
+ limits.ephemeral-storage, requests.cpu,
+ requests.memory and requests.ephemeral-storage)
+ are currently supported.'
+ properties:
+ containerName:
+ description: 'Container name: required
+ for volumes, optional for env
+ vars'
+ type: string
+ divisor:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Specifies the output
+ format of the exposed resources,
+ defaults to "1"
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ resource:
+ description: 'Required: resource
+ to select'
+ type: string
+ required:
+ - resource
+ type: object
+ secretKeyRef:
+ description: Selects a key of a secret
+ in the pod's namespace
+ properties:
+ key:
+ description: The key of the secret
+ to select from. Must be a valid
+ secret key.
+ type: string
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ Secret or its key must be defined
+ type: boolean
+ required:
+ - key
+ type: object
+ type: object
+ required:
+ - name
+ type: object
+ type: array
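+                        # Illustrative sketch only (YAML comment, not part of the
+                        # schema): typical env entries combining a literal value, a
+                        # downward-API fieldRef and a secretKeyRef; all names are
+                        # assumed placeholders.
+                        #
+                        #   env:
+                        #   - name: ROLE
+                        #     value: worker
+                        #   - name: POD_IP
+                        #     valueFrom:
+                        #       fieldRef:
+                        #         fieldPath: status.podIP
+                        #   - name: DB_PASSWORD
+                        #     valueFrom:
+                        #       secretKeyRef:
+                        #         name: db-credentials
+                        #         key: password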
+ envFrom:
+ description: List of sources to populate environment
+ variables in the container. The keys defined
+ within a source must be a C_IDENTIFIER. All
+ invalid keys will be reported as an event when
+ the container is starting. When a key exists
+ in multiple sources, the value associated with
+ the last source will take precedence. Values
+ defined by an Env with a duplicate key will
+ take precedence. Cannot be updated.
+ items:
+ description: EnvFromSource represents the source
+ of a set of ConfigMaps
+ properties:
+ configMapRef:
+ description: The ConfigMap to select from
+ properties:
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the ConfigMap
+ must be defined
+ type: boolean
+ type: object
+ prefix:
+ description: An optional identifier to prepend
+ to each key in the ConfigMap. Must be
+ a C_IDENTIFIER.
+ type: string
+ secretRef:
+ description: The Secret to select from
+ properties:
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the Secret
+ must be defined
+ type: boolean
+ type: object
+ type: object
+ type: array
+ image:
+ description: 'Docker image name. More info: https://kubernetes.io/docs/concepts/containers/images
+ This field is optional to allow higher level
+ config management to default or override container
+ images in workload controllers like Deployments
+ and StatefulSets.'
+ type: string
+ imagePullPolicy:
+ description: 'Image pull policy. One of Always,
+ Never, IfNotPresent. Defaults to Always if :latest
+ tag is specified, or IfNotPresent otherwise.
+ Cannot be updated. More info: https://kubernetes.io/docs/concepts/containers/images#updating-images'
+ type: string
+ lifecycle:
+ description: Actions that the management system
+ should take in response to container lifecycle
+ events. Cannot be updated.
+ properties:
+ postStart:
+ description: 'PostStart is called immediately
+ after a container is created. If the handler
+ fails, the container is terminated and restarted
+ according to its restart policy. Other management
+ of the container blocks until the hook completes.
+ More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies
+ the action to take.
+ properties:
+ command:
+ description: Command is the command
+ line to execute inside the container,
+ the working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it
+ is not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to
+ explicitly call out to that shell.
+ Exit status of 0 is treated as live/healthy
+ and non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ httpGet:
+ description: HTTPGet specifies the http
+ request to perform.
+ properties:
+ host:
+ description: Host name to connect
+ to, defaults to the pod IP. You
+ probably want to set "Host" in httpHeaders
+ instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set
+ in the request. HTTP allows repeated
+ headers.
+ items:
+ description: HTTPHeader describes
+ a custom header to be used in
+ HTTP probes
+ properties:
+ name:
+ description: The header field
+ name
+ type: string
+ value:
+ description: The header field
+ value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the
+ HTTP server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not
+ yet supported TODO: implement a realistic
+ TCP lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name
+ to connect to, defaults to the pod
+ IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ type: object
+ preStop:
+ description: 'PreStop is called immediately
+ before a container is terminated due to
+ an API request or management event such
+ as liveness/startup probe failure, preemption,
+ resource contention, etc. The handler is
+ not called if the container crashes or exits.
+ The reason for termination is passed to
+ the handler. The Pod''s termination grace
+ period countdown begins before the PreStop
+                                hook is executed. Regardless of the outcome
+ of the handler, the container will eventually
+ terminate within the Pod''s termination
+ grace period. Other management of the container
+ blocks until the hook completes or until
+ the termination grace period is reached.
+ More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies
+ the action to take.
+ properties:
+ command:
+ description: Command is the command
+ line to execute inside the container,
+ the working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it
+ is not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to
+ explicitly call out to that shell.
+ Exit status of 0 is treated as live/healthy
+ and non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ httpGet:
+ description: HTTPGet specifies the http
+ request to perform.
+ properties:
+ host:
+ description: Host name to connect
+ to, defaults to the pod IP. You
+ probably want to set "Host" in httpHeaders
+ instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set
+ in the request. HTTP allows repeated
+ headers.
+ items:
+ description: HTTPHeader describes
+ a custom header to be used in
+ HTTP probes
+ properties:
+ name:
+ description: The header field
+ name
+ type: string
+ value:
+ description: The header field
+ value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the
+ HTTP server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not
+ yet supported TODO: implement a realistic
+ TCP lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name
+ to connect to, defaults to the pod
+ IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ type: object
+ type: object
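+                        # Illustrative sketch only: a preStop hook that lets the
+                        # process drain connections before the termination grace
+                        # period elapses; the sleep duration is an assumed placeholder.
+                        #
+                        #   lifecycle:
+                        #     preStop:
+                        #       exec:
+                        #         command: ["/bin/sh", "-c", "sleep 10"]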
+ livenessProbe:
+ description: 'Periodic probe of container liveness.
+ Container will be restarted if the probe fails.
+ Cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+                                the probe. Defaults to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
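+                        # Illustrative sketch only: an httpGet liveness probe with the
+                        # documented defaults written out; path and port are assumed
+                        # placeholders.
+                        #
+                        #   livenessProbe:
+                        #     httpGet:
+                        #       path: /healthz
+                        #       port: 8080
+                        #     initialDelaySeconds: 15
+                        #     periodSeconds: 10
+                        #     failureThreshold: 3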
+ name:
+ description: Name of the container specified as
+ a DNS_LABEL. Each container in a pod must have
+ a unique name (DNS_LABEL). Cannot be updated.
+ type: string
+ ports:
+ description: List of ports to expose from the
+ container. Exposing a port here gives the system
+ additional information about the network connections
+ a container uses, but is primarily informational.
+ Not specifying a port here DOES NOT prevent
+ that port from being exposed. Any port which
+ is listening on the default "0.0.0.0" address
+ inside a container will be accessible from the
+ network. Cannot be updated.
+ items:
+ description: ContainerPort represents a network
+ port in a single container.
+ properties:
+ containerPort:
+ description: Number of port to expose on
+ the pod's IP address. This must be a valid
+ port number, 0 < x < 65536.
+ format: int32
+ type: integer
+ hostIP:
+ description: What host IP to bind the external
+ port to.
+ type: string
+ hostPort:
+ description: Number of port to expose on
+ the host. If specified, this must be a
+ valid port number, 0 < x < 65536. If HostNetwork
+ is specified, this must match ContainerPort.
+ Most containers do not need this.
+ format: int32
+ type: integer
+ name:
+ description: If specified, this must be
+ an IANA_SVC_NAME and unique within the
+ pod. Each named port in a pod must have
+ a unique name. Name for the port that
+ can be referred to by services.
+ type: string
+ protocol:
+ default: TCP
+ description: Protocol for port. Must be
+ UDP, TCP, or SCTP. Defaults to "TCP".
+ type: string
+ required:
+ - containerPort
+ type: object
+ type: array
+ x-kubernetes-list-map-keys:
+ - containerPort
+ - protocol
+ x-kubernetes-list-type: map
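+                        # Illustrative sketch only: a named container port; the
+                        # list-map keys above (containerPort, protocol) keep entries
+                        # unique. The port number and name are assumed placeholders.
+                        #
+                        #   ports:
+                        #   - name: grpc
+                        #     containerPort: 50051
+                        #     protocol: TCP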
+ readinessProbe:
+ description: 'Periodic probe of container service
+ readiness. Container will be removed from service
+ endpoints if the probe fails. Cannot be updated.
+ More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+                                the probe. Defaults to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
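+                        # Illustrative sketch only: a tcpSocket readiness probe, which
+                        # gates service endpoints rather than restarting the container;
+                        # the port is an assumed placeholder.
+                        #
+                        #   readinessProbe:
+                        #     tcpSocket:
+                        #       port: 50051
+                        #     periodSeconds: 5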
+ resources:
+ description: 'Compute Resources required by this
+ container. Cannot be updated. More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ properties:
+ limits:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Limits describes the maximum
+ amount of compute resources allowed. More
+ info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ requests:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Requests describes the minimum
+ amount of compute resources required. If
+ Requests is omitted for a container, it
+ defaults to Limits if that is explicitly
+ specified, otherwise to an implementation-defined
+ value. More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ type: object
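+                        # Illustrative sketch only: requests/limits using the quantity
+                        # syntax accepted by the pattern above; the amounts are assumed
+                        # placeholders.
+                        #
+                        #   resources:
+                        #     requests:
+                        #       cpu: 500m
+                        #       memory: 1Gi
+                        #     limits:
+                        #       cpu: "2"
+                        #       memory: 4Gi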
+ securityContext:
+ description: 'SecurityContext defines the security
+ options the container should be run with. If
+ set, the fields of SecurityContext override
+ the equivalent fields of PodSecurityContext.
+ More info: https://kubernetes.io/docs/tasks/configure-pod-container/security-context/'
+ properties:
+ allowPrivilegeEscalation:
+ description: 'AllowPrivilegeEscalation controls
+ whether a process can gain more privileges
+ than its parent process. This bool directly
+ controls if the no_new_privs flag will be
+ set on the container process. AllowPrivilegeEscalation
+ is true always when the container is: 1)
+ run as Privileged 2) has CAP_SYS_ADMIN'
+ type: boolean
+ capabilities:
+ description: The capabilities to add/drop
+ when running containers. Defaults to the
+ default set of capabilities granted by the
+ container runtime.
+ properties:
+ add:
+ description: Added capabilities
+ items:
+ description: Capability represent POSIX
+ capabilities type
+ type: string
+ type: array
+ drop:
+ description: Removed capabilities
+ items:
+ description: Capability represent POSIX
+ capabilities type
+ type: string
+ type: array
+ type: object
+ privileged:
+ description: Run container in privileged mode.
+ Processes in privileged containers are essentially
+ equivalent to root on the host. Defaults
+ to false.
+ type: boolean
+ procMount:
+ description: procMount denotes the type of
+ proc mount to use for the containers. The
+ default is DefaultProcMount which uses the
+ container runtime defaults for readonly
+ paths and masked paths. This requires the
+ ProcMountType feature flag to be enabled.
+ type: string
+ readOnlyRootFilesystem:
+ description: Whether this container has a
+ read-only root filesystem. Default is false.
+ type: boolean
+ runAsGroup:
+ description: The GID to run the entrypoint
+ of the container process. Uses runtime default
+ if unset. May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ format: int64
+ type: integer
+ runAsNonRoot:
+ description: Indicates that the container
+ must run as a non-root user. If true, the
+ Kubelet will validate the image at runtime
+ to ensure that it does not run as UID 0
+ (root) and fail to start the container if
+ it does. If unset or false, no such validation
+ will be performed. May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ type: boolean
+ runAsUser:
+ description: The UID to run the entrypoint
+ of the container process. Defaults to user
+ specified in image metadata if unspecified.
+ May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ format: int64
+ type: integer
+ seLinuxOptions:
+ description: The SELinux context to be applied
+ to the container. If unspecified, the container
+ runtime will allocate a random SELinux context
+ for each container. May also be set in
+ PodSecurityContext. If set in both SecurityContext
+ and PodSecurityContext, the value specified
+ in SecurityContext takes precedence.
+ properties:
+ level:
+ description: Level is SELinux level label
+ that applies to the container.
+ type: string
+ role:
+ description: Role is a SELinux role label
+ that applies to the container.
+ type: string
+ type:
+ description: Type is a SELinux type label
+ that applies to the container.
+ type: string
+ user:
+ description: User is a SELinux user label
+ that applies to the container.
+ type: string
+ type: object
+ seccompProfile:
+ description: The seccomp options to use by
+ this container. If seccomp options are provided
+ at both the pod & container level, the container
+ options override the pod options.
+ properties:
+ localhostProfile:
+ description: localhostProfile indicates
+ a profile defined in a file on the node
+ should be used. The profile must be
+ preconfigured on the node to work. Must
+ be a descending path, relative to the
+ kubelet's configured seccomp profile
+ location. Must only be set if type is
+ "Localhost".
+ type: string
+ type:
+ description: "type indicates which kind
+ of seccomp profile will be applied.
+ Valid options are: \n Localhost - a
+ profile defined in a file on the node
+ should be used. RuntimeDefault - the
+ container runtime default profile should
+ be used. Unconfined - no profile should
+ be applied."
+ type: string
+ required:
+ - type
+ type: object
+ windowsOptions:
+ description: The Windows specific settings
+ applied to all containers. If unspecified,
+ the options from the PodSecurityContext
+ will be used. If set in both SecurityContext
+ and PodSecurityContext, the value specified
+ in SecurityContext takes precedence.
+ properties:
+ gmsaCredentialSpec:
+ description: GMSACredentialSpec is where
+ the GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa)
+ inlines the contents of the GMSA credential
+ spec named by the GMSACredentialSpecName
+ field.
+ type: string
+ gmsaCredentialSpecName:
+ description: GMSACredentialSpecName is
+ the name of the GMSA credential spec
+ to use.
+ type: string
+ hostProcess:
+ description: HostProcess determines if
+ a container should be run as a 'Host
+ Process' container. This field is alpha-level
+ and will only be honored by components
+ that enable the WindowsHostProcessContainers
+ feature flag. Setting this field without
+ the feature flag will result in errors
+ when validating the Pod. All of a Pod's
+ containers must have the same effective
+ HostProcess value (it is not allowed
+ to have a mix of HostProcess containers
+ and non-HostProcess containers). In
+ addition, if HostProcess is true then
+ HostNetwork must also be set to true.
+ type: boolean
+ runAsUserName:
+ description: The UserName in Windows to
+ run the entrypoint of the container
+ process. Defaults to the user specified
+ in image metadata if unspecified. May
+ also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext
+ takes precedence.
+ type: string
+ type: object
+ type: object
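+                        # Illustrative sketch only: a restrictive container security
+                        # context; the UID and dropped capabilities are assumed
+                        # placeholders.
+                        #
+                        #   securityContext:
+                        #     runAsNonRoot: true
+                        #     runAsUser: 1000
+                        #     allowPrivilegeEscalation: false
+                        #     readOnlyRootFilesystem: true
+                        #     capabilities:
+                        #       drop: ["ALL"]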
+ startupProbe:
+ description: 'StartupProbe indicates that the
+ Pod has successfully initialized. If specified,
+ no other probes are executed until this completes
+ successfully. If this probe fails, the Pod will
+ be restarted, just as if the livenessProbe failed.
+ This can be used to provide different probe
+ parameters at the beginning of a Pod''s lifecycle,
+ when it might take a long time to load data
+ or warm a cache, than during steady-state operation.
+ This cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+                                the probe. Defaults to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
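+                        # Illustrative sketch only: a startupProbe that allows up to
+                        # 30 * 10 = 300 seconds of initialization before the liveness
+                        # probe takes over; all values are assumed placeholders.
+                        #
+                        #   startupProbe:
+                        #     httpGet:
+                        #       path: /healthz
+                        #       port: 8080
+                        #     failureThreshold: 30
+                        #     periodSeconds: 10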
+ stdin:
+ description: Whether this container should allocate
+ a buffer for stdin in the container runtime.
+ If this is not set, reads from stdin in the
+ container will always result in EOF. Default
+ is false.
+ type: boolean
+ stdinOnce:
+ description: Whether the container runtime should
+ close the stdin channel after it has been opened
+ by a single attach. When stdin is true the stdin
+ stream will remain open across multiple attach
+ sessions. If stdinOnce is set to true, stdin
+ is opened on container start, is empty until
+ the first client attaches to stdin, and then
+ remains open and accepts data until the client
+ disconnects, at which time stdin is closed and
+ remains closed until the container is restarted.
+                            If this flag is false, a container process
+ that reads from stdin will never receive an
+ EOF. Default is false
+ type: boolean
+ terminationMessagePath:
+ description: 'Optional: Path at which the file
+ to which the container''s termination message
+ will be written is mounted into the container''s
+ filesystem. Message written is intended to be
+ brief final status, such as an assertion failure
+ message. Will be truncated by the node if greater
+ than 4096 bytes. The total message length across
+ all containers will be limited to 12kb. Defaults
+ to /dev/termination-log. Cannot be updated.'
+ type: string
+ terminationMessagePolicy:
+ description: Indicate how the termination message
+ should be populated. File will use the contents
+ of terminationMessagePath to populate the container
+ status message on both success and failure.
+ FallbackToLogsOnError will use the last chunk
+ of container log output if the termination message
+ file is empty and the container exited with
+ an error. The log output is limited to 2048
+ bytes or 80 lines, whichever is smaller. Defaults
+ to File. Cannot be updated.
+ type: string
+ tty:
+ description: Whether this container should allocate
+ a TTY for itself, also requires 'stdin' to be
+ true. Default is false.
+ type: boolean
+ volumeDevices:
+ description: volumeDevices is the list of block
+ devices to be used by the container.
+ items:
+ description: volumeDevice describes a mapping
+ of a raw block device within a container.
+ properties:
+ devicePath:
+ description: devicePath is the path inside
+ of the container that the device will
+ be mapped to.
+ type: string
+ name:
+ description: name must match the name of
+ a persistentVolumeClaim in the pod
+ type: string
+ required:
+ - devicePath
+ - name
+ type: object
+ type: array
+ volumeMounts:
+ description: Pod volumes to mount into the container's
+ filesystem. Cannot be updated.
+ items:
+ description: VolumeMount describes a mounting
+ of a Volume within a container.
+ properties:
+ mountPath:
+ description: Path within the container at
+ which the volume should be mounted. Must
+ not contain ':'.
+ type: string
+ mountPropagation:
+ description: mountPropagation determines
+ how mounts are propagated from the host
+ to container and the other way around.
+ When not set, MountPropagationNone is
+ used. This field is beta in 1.10.
+ type: string
+ name:
+ description: This must match the Name of
+ a Volume.
+ type: string
+ readOnly:
+ description: Mounted read-only if true,
+ read-write otherwise (false or unspecified).
+ Defaults to false.
+ type: boolean
+ subPath:
+ description: Path within the volume from
+ which the container's volume should be
+ mounted. Defaults to "" (volume's root).
+ type: string
+ subPathExpr:
+ description: Expanded path within the volume
+ from which the container's volume should
+ be mounted. Behaves similarly to SubPath
+ but environment variable references $(VAR_NAME)
+ are expanded using the container's environment.
+ Defaults to "" (volume's root). SubPathExpr
+ and SubPath are mutually exclusive.
+ type: string
+ required:
+ - mountPath
+ - name
+ type: object
+ type: array
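+                        # Illustrative sketch only: mounting a volume read-only at a
+                        # subPath; the volume name and paths are assumed placeholders
+                        # (name must match a volume declared elsewhere in the pod spec).
+                        #
+                        #   volumeMounts:
+                        #   - name: data
+                        #     mountPath: /app/data
+                        #     subPath: worker-0
+                        #     readOnly: true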
+ workingDir:
+ description: Container's working directory. If
+ not specified, the container runtime's default
+ will be used, which might be configured in the
+ container image. Cannot be updated.
+ type: string
+ required:
+ - name
+ type: object
+ type: array
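+ # Illustrative example (comment only, not part of the generated schema): a
+ # minimal container entry exercising the fields documented above — volumeMounts,
+ # terminationMessagePolicy and workingDir. Names such as "worker" and
+ # "data-vol" are placeholders for this sketch, not values defined by this CRD.
+ #   containers:
+ #     - name: worker
+ #       image: busybox:1.36
+ #       workingDir: /app
+ #       terminationMessagePolicy: FallbackToLogsOnError
+ #       volumeMounts:
+ #         - name: data-vol
+ #           mountPath: /data
+ #           readOnly: true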
+ dnsConfig:
+ description: Specifies the DNS parameters of a pod.
+ Parameters specified here will be merged to the generated
+ DNS configuration based on DNSPolicy.
+ properties:
+ nameservers:
+ description: A list of DNS name server IP addresses.
+ This will be appended to the base nameservers
+ generated from DNSPolicy. Duplicated nameservers
+ will be removed.
+ items:
+ type: string
+ type: array
+ options:
+ description: A list of DNS resolver options. This
+ will be merged with the base options generated
+ from DNSPolicy. Duplicated entries will be removed.
+ Resolution options given in Options will override
+ those that appear in the base DNSPolicy.
+ items:
+ description: PodDNSConfigOption defines DNS resolver
+ options of a pod.
+ properties:
+ name:
+ description: Required.
+ type: string
+ value:
+ type: string
+ type: object
+ type: array
+ searches:
+ description: A list of DNS search domains for host-name
+ lookup. This will be appended to the base search
+ paths generated from DNSPolicy. Duplicated search
+ paths will be removed.
+ items:
+ type: string
+ type: array
+ type: object
+ dnsPolicy:
+ description: Set DNS policy for the pod. Defaults to
+ "ClusterFirst". Valid values are 'ClusterFirstWithHostNet',
+ 'ClusterFirst', 'Default' or 'None'. DNS parameters
+ given in DNSConfig will be merged with the policy
+ selected with DNSPolicy. To have DNS options set along
+ with hostNetwork, you have to specify DNS policy explicitly
+ to 'ClusterFirstWithHostNet'.
+ type: string
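+ # Illustrative example (comment only): a pod-spec fragment combining dnsPolicy
+ # and dnsConfig as documented above. The nameserver, search domain and resolver
+ # option are placeholder values, not defaults of this CRD.
+ #   dnsPolicy: "None"
+ #   dnsConfig:
+ #     nameservers:
+ #       - 10.0.0.10
+ #     searches:
+ #       - my-ns.svc.cluster.local
+ #     options:
+ #       - name: ndots
+ #         value: "2"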
+ enableServiceLinks:
+ description: 'EnableServiceLinks indicates whether information
+ about services should be injected into pod''s environment
+ variables, matching the syntax of Docker links. Optional:
+ Defaults to true.'
+ type: boolean
+ ephemeralContainers:
+ description: List of ephemeral containers run in this
+ pod. Ephemeral containers may be run in an existing
+ pod to perform user-initiated actions such as debugging.
+ This list cannot be specified when creating a pod,
+ and it cannot be modified by updating the pod spec.
+ In order to add an ephemeral container to an existing
+ pod, use the pod's ephemeralcontainers subresource.
+ This field is alpha-level and is only honored by servers
+ that enable the EphemeralContainers feature.
+ items:
+ description: An EphemeralContainer is a container
+ that may be added temporarily to an existing pod
+ for user-initiated activities such as debugging.
+ Ephemeral containers have no resource or scheduling
+ guarantees, and they will not be restarted when
+ they exit or when a pod is removed or restarted.
+ If an ephemeral container causes a pod to exceed
+ its resource allocation, the pod may be evicted.
+ Ephemeral containers may not be added by directly
+ updating the pod spec. They must be added via the
+ pod's ephemeralcontainers subresource, and they
+ will appear in the pod spec once added. This is
+ an alpha feature enabled by the EphemeralContainers
+ feature flag.
+ properties:
+ args:
+ description: 'Arguments to the entrypoint. The
+ docker image''s CMD is used if this is not provided.
+ Variable references $(VAR_NAME) are expanded
+ using the container''s environment. If a variable
+ cannot be resolved, the reference in the input
+ string will be unchanged. Double $$ are reduced
+ to a single $, which allows for escaping the
+ $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)" will
+ produce the string literal "$(VAR_NAME)". Escaped
+ references will never be expanded, regardless
+ of whether the variable exists or not. Cannot
+ be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell'
+ items:
+ type: string
+ type: array
+ command:
+ description: 'Entrypoint array. Not executed within
+ a shell. The docker image''s ENTRYPOINT is used
+ if this is not provided. Variable references
+ $(VAR_NAME) are expanded using the container''s
+ environment. If a variable cannot be resolved,
+ the reference in the input string will be unchanged.
+ Double $$ are reduced to a single $, which allows
+ for escaping the $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)"
+ will produce the string literal "$(VAR_NAME)".
+ Escaped references will never be expanded, regardless
+ of whether the variable exists or not. Cannot
+ be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell'
+ items:
+ type: string
+ type: array
+ env:
+ description: List of environment variables to
+ set in the container. Cannot be updated.
+ items:
+ description: EnvVar represents an environment
+ variable present in a Container.
+ properties:
+ name:
+ description: Name of the environment variable.
+ Must be a C_IDENTIFIER.
+ type: string
+ value:
+ description: 'Variable references $(VAR_NAME)
+ are expanded using the previously defined
+ environment variables in the container
+ and any service environment variables.
+ If a variable cannot be resolved, the
+ reference in the input string will be
+ unchanged. Double $$ are reduced to a
+ single $, which allows for escaping the
+ $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)"
+ will produce the string literal "$(VAR_NAME)".
+ Escaped references will never be expanded,
+ regardless of whether the variable exists
+ or not. Defaults to "".'
+ type: string
+ valueFrom:
+ description: Source for the environment
+ variable's value. Cannot be used if value
+ is not empty.
+ properties:
+ configMapKeyRef:
+ description: Selects a key of a ConfigMap.
+ properties:
+ key:
+ description: The key to select.
+ type: string
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ ConfigMap or its key must be defined
+ type: boolean
+ required:
+ - key
+ type: object
+ fieldRef:
+ description: 'Selects a field of the
+ pod: supports metadata.name, metadata.namespace,
+ `metadata.labels['''']`, `metadata.annotations['''']`,
+ spec.nodeName, spec.serviceAccountName,
+ status.hostIP, status.podIP, status.podIPs.'
+ properties:
+ apiVersion:
+ description: Version of the schema
+ the FieldPath is written in terms
+ of, defaults to "v1".
+ type: string
+ fieldPath:
+ description: Path of the field to
+ select in the specified API version.
+ type: string
+ required:
+ - fieldPath
+ type: object
+ resourceFieldRef:
+ description: 'Selects a resource of
+ the container: only resources limits
+ and requests (limits.cpu, limits.memory,
+ limits.ephemeral-storage, requests.cpu,
+ requests.memory and requests.ephemeral-storage)
+ are currently supported.'
+ properties:
+ containerName:
+ description: 'Container name: required
+ for volumes, optional for env
+ vars'
+ type: string
+ divisor:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Specifies the output
+ format of the exposed resources,
+ defaults to "1"
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ resource:
+ description: 'Required: resource
+ to select'
+ type: string
+ required:
+ - resource
+ type: object
+ secretKeyRef:
+ description: Selects a key of a secret
+ in the pod's namespace
+ properties:
+ key:
+ description: The key of the secret
+ to select from. Must be a valid
+ secret key.
+ type: string
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ Secret or its key must be defined
+ type: boolean
+ required:
+ - key
+ type: object
+ type: object
+ required:
+ - name
+ type: object
+ type: array
+ envFrom:
+ description: List of sources to populate environment
+ variables in the container. The keys defined
+ within a source must be a C_IDENTIFIER. All
+ invalid keys will be reported as an event when
+ the container is starting. When a key exists
+ in multiple sources, the value associated with
+ the last source will take precedence. Values
+ defined by an Env with a duplicate key will
+ take precedence. Cannot be updated.
+ items:
+ description: EnvFromSource represents the source
+ of a set of ConfigMaps
+ properties:
+ configMapRef:
+ description: The ConfigMap to select from
+ properties:
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the ConfigMap
+ must be defined
+ type: boolean
+ type: object
+ prefix:
+ description: An optional identifier to prepend
+ to each key in the ConfigMap. Must be
+ a C_IDENTIFIER.
+ type: string
+ secretRef:
+ description: The Secret to select from
+ properties:
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the Secret
+ must be defined
+ type: boolean
+ type: object
+ type: object
+ type: array
+ image:
+ description: 'Docker image name. More info: https://kubernetes.io/docs/concepts/containers/images'
+ type: string
+ imagePullPolicy:
+ description: 'Image pull policy. One of Always,
+ Never, IfNotPresent. Defaults to Always if :latest
+ tag is specified, or IfNotPresent otherwise.
+ Cannot be updated. More info: https://kubernetes.io/docs/concepts/containers/images#updating-images'
+ type: string
+ lifecycle:
+ description: Lifecycle is not allowed for ephemeral
+ containers.
+ properties:
+ postStart:
+ description: 'PostStart is called immediately
+ after a container is created. If the handler
+ fails, the container is terminated and restarted
+ according to its restart policy. Other management
+ of the container blocks until the hook completes.
+ More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies
+ the action to take.
+ properties:
+ command:
+ description: Command is the command
+ line to execute inside the container,
+ the working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it
+ is not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to
+ explicitly call out to that shell.
+ Exit status of 0 is treated as live/healthy
+ and non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ httpGet:
+ description: HTTPGet specifies the http
+ request to perform.
+ properties:
+ host:
+ description: Host name to connect
+ to, defaults to the pod IP. You
+ probably want to set "Host" in httpHeaders
+ instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set
+ in the request. HTTP allows repeated
+ headers.
+ items:
+ description: HTTPHeader describes
+ a custom header to be used in
+ HTTP probes
+ properties:
+ name:
+ description: The header field
+ name
+ type: string
+ value:
+ description: The header field
+ value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the
+ HTTP server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not
+ yet supported TODO: implement a realistic
+ TCP lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name
+ to connect to, defaults to the pod
+ IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ type: object
+ preStop:
+ description: 'PreStop is called immediately
+ before a container is terminated due to
+ an API request or management event such
+ as liveness/startup probe failure, preemption,
+ resource contention, etc. The handler is
+ not called if the container crashes or exits.
+ The reason for termination is passed to
+ the handler. The Pod''s termination grace
+ period countdown begins before the PreStop
+ hook is executed. Regardless of the outcome
+ of the handler, the container will eventually
+ terminate within the Pod''s termination
+ grace period. Other management of the container
+ blocks until the hook completes or until
+ the termination grace period is reached.
+ More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies
+ the action to take.
+ properties:
+ command:
+ description: Command is the command
+ line to execute inside the container,
+ the working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it
+ is not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to
+ explicitly call out to that shell.
+ Exit status of 0 is treated as live/healthy
+ and non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ httpGet:
+ description: HTTPGet specifies the http
+ request to perform.
+ properties:
+ host:
+ description: Host name to connect
+ to, defaults to the pod IP. You
+ probably want to set "Host" in httpHeaders
+ instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set
+ in the request. HTTP allows repeated
+ headers.
+ items:
+ description: HTTPHeader describes
+ a custom header to be used in
+ HTTP probes
+ properties:
+ name:
+ description: The header field
+ name
+ type: string
+ value:
+ description: The header field
+ value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the
+ HTTP server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not
+ yet supported TODO: implement a realistic
+ TCP lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name
+ to connect to, defaults to the pod
+ IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ type: object
+ type: object
+ livenessProbe:
+ description: Probes are not allowed for ephemeral
+ containers.
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+ the probe. Defaults to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be a non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
+ name:
+ description: Name of the ephemeral container specified
+ as a DNS_LABEL. This name must be unique among
+ all containers, init containers and ephemeral
+ containers.
+ type: string
+ ports:
+ description: Ports are not allowed for ephemeral
+ containers.
+ items:
+ description: ContainerPort represents a network
+ port in a single container.
+ properties:
+ containerPort:
+ description: Number of port to expose on
+ the pod's IP address. This must be a valid
+ port number, 0 < x < 65536.
+ format: int32
+ type: integer
+ hostIP:
+ description: What host IP to bind the external
+ port to.
+ type: string
+ hostPort:
+ description: Number of port to expose on
+ the host. If specified, this must be a
+ valid port number, 0 < x < 65536. If HostNetwork
+ is specified, this must match ContainerPort.
+ Most containers do not need this.
+ format: int32
+ type: integer
+ name:
+ description: If specified, this must be
+ an IANA_SVC_NAME and unique within the
+ pod. Each named port in a pod must have
+ a unique name. Name for the port that
+ can be referred to by services.
+ type: string
+ protocol:
+ default: TCP
+ description: Protocol for port. Must be
+ UDP, TCP, or SCTP. Defaults to "TCP".
+ type: string
+ required:
+ - containerPort
+ type: object
+ type: array
+ readinessProbe:
+ description: Probes are not allowed for ephemeral
+ containers.
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+ the probe. Defaults to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be a non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
+ resources:
+ description: Resources are not allowed for ephemeral
+ containers. Ephemeral containers use spare resources
+ already allocated to the pod.
+ properties:
+ limits:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Limits describes the maximum
+ amount of compute resources allowed. More
+ info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ requests:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Requests describes the minimum
+ amount of compute resources required. If
+ Requests is omitted for a container, it
+ defaults to Limits if that is explicitly
+ specified, otherwise to an implementation-defined
+ value. More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ type: object
+ securityContext:
+ description: 'Optional: SecurityContext defines
+ the security options the ephemeral container
+ should be run with. If set, the fields of SecurityContext
+ override the equivalent fields of PodSecurityContext.'
+ properties:
+ allowPrivilegeEscalation:
+ description: 'AllowPrivilegeEscalation controls
+ whether a process can gain more privileges
+ than its parent process. This bool directly
+ controls if the no_new_privs flag will be
+ set on the container process. AllowPrivilegeEscalation
+ is true always when the container is: 1)
+ run as Privileged 2) has CAP_SYS_ADMIN'
+ type: boolean
+ capabilities:
+ description: The capabilities to add/drop
+ when running containers. Defaults to the
+ default set of capabilities granted by the
+ container runtime.
+ properties:
+ add:
+ description: Added capabilities
+ items:
+ description: Capability represent POSIX
+ capabilities type
+ type: string
+ type: array
+ drop:
+ description: Removed capabilities
+ items:
+ description: Capability represent POSIX
+ capabilities type
+ type: string
+ type: array
+ type: object
+ privileged:
+ description: Run container in privileged mode.
+ Processes in privileged containers are essentially
+ equivalent to root on the host. Defaults
+ to false.
+ type: boolean
+ procMount:
+ description: procMount denotes the type of
+ proc mount to use for the containers. The
+ default is DefaultProcMount which uses the
+ container runtime defaults for readonly
+ paths and masked paths. This requires the
+ ProcMountType feature flag to be enabled.
+ type: string
+ readOnlyRootFilesystem:
+ description: Whether this container has a
+ read-only root filesystem. Default is false.
+ type: boolean
+ runAsGroup:
+ description: The GID to run the entrypoint
+ of the container process. Uses runtime default
+ if unset. May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ format: int64
+ type: integer
+ runAsNonRoot:
+ description: Indicates that the container
+ must run as a non-root user. If true, the
+ Kubelet will validate the image at runtime
+ to ensure that it does not run as UID 0
+ (root) and fail to start the container if
+ it does. If unset or false, no such validation
+ will be performed. May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ type: boolean
+ runAsUser:
+ description: The UID to run the entrypoint
+ of the container process. Defaults to user
+ specified in image metadata if unspecified.
+ May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ format: int64
+ type: integer
+ seLinuxOptions:
+ description: The SELinux context to be applied
+ to the container. If unspecified, the container
+ runtime will allocate a random SELinux context
+ for each container. May also be set in
+ PodSecurityContext. If set in both SecurityContext
+ and PodSecurityContext, the value specified
+ in SecurityContext takes precedence.
+ properties:
+ level:
+ description: Level is SELinux level label
+ that applies to the container.
+ type: string
+ role:
+ description: Role is a SELinux role label
+ that applies to the container.
+ type: string
+ type:
+ description: Type is a SELinux type label
+ that applies to the container.
+ type: string
+ user:
+ description: User is a SELinux user label
+ that applies to the container.
+ type: string
+ type: object
+ seccompProfile:
+ description: The seccomp options to use by
+ this container. If seccomp options are provided
+ at both the pod & container level, the container
+ options override the pod options.
+ properties:
+ localhostProfile:
+ description: localhostProfile indicates
+ a profile defined in a file on the node
+ should be used. The profile must be
+ preconfigured on the node to work. Must
+ be a descending path, relative to the
+ kubelet's configured seccomp profile
+ location. Must only be set if type is
+ "Localhost".
+ type: string
+ type:
+ description: "type indicates which kind
+ of seccomp profile will be applied.
+ Valid options are: \n Localhost - a
+ profile defined in a file on the node
+ should be used. RuntimeDefault - the
+ container runtime default profile should
+ be used. Unconfined - no profile should
+ be applied."
+ type: string
+ required:
+ - type
+ type: object
+ windowsOptions:
+ description: The Windows specific settings
+ applied to all containers. If unspecified,
+ the options from the PodSecurityContext
+ will be used. If set in both SecurityContext
+ and PodSecurityContext, the value specified
+ in SecurityContext takes precedence.
+ properties:
+ gmsaCredentialSpec:
+ description: GMSACredentialSpec is where
+ the GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa)
+ inlines the contents of the GMSA credential
+ spec named by the GMSACredentialSpecName
+ field.
+ type: string
+ gmsaCredentialSpecName:
+ description: GMSACredentialSpecName is
+ the name of the GMSA credential spec
+ to use.
+ type: string
+ hostProcess:
+ description: HostProcess determines if
+ a container should be run as a 'Host
+ Process' container. This field is alpha-level
+ and will only be honored by components
+ that enable the WindowsHostProcessContainers
+ feature flag. Setting this field without
+ the feature flag will result in errors
+ when validating the Pod. All of a Pod's
+ containers must have the same effective
+ HostProcess value (it is not allowed
+ to have a mix of HostProcess containers
+ and non-HostProcess containers). In
+ addition, if HostProcess is true then
+ HostNetwork must also be set to true.
+ type: boolean
+ runAsUserName:
+ description: The UserName in Windows to
+ run the entrypoint of the container
+ process. Defaults to the user specified
+ in image metadata if unspecified. May
+ also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext
+ takes precedence.
+ type: string
+ type: object
+ type: object
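+ # Illustrative example (comment only): a typical securityContext for a
+ # debugging ephemeral container, sketched from the fields above. The specific
+ # values are assumptions for the example, not recommendations baked into this CRD.
+ #   securityContext:
+ #     runAsNonRoot: true
+ #     runAsUser: 1000
+ #     allowPrivilegeEscalation: false
+ #     readOnlyRootFilesystem: true
+ #     capabilities:
+ #       drop:
+ #         - ALL
+ #     seccompProfile:
+ #       type: RuntimeDefault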
+ startupProbe:
+ description: Probes are not allowed for ephemeral
+ containers.
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+ the probe. Defaults to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be a non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
+ stdin:
+ description: Whether this container should allocate
+ a buffer for stdin in the container runtime.
+ If this is not set, reads from stdin in the
+ container will always result in EOF. Default
+ is false.
+ type: boolean
+ stdinOnce:
+ description: Whether the container runtime should
+ close the stdin channel after it has been opened
+ by a single attach. When stdin is true the stdin
+ stream will remain open across multiple attach
+ sessions. If stdinOnce is set to true, stdin
+ is opened on container start, is empty until
+ the first client attaches to stdin, and then
+ remains open and accepts data until the client
+ disconnects, at which time stdin is closed and
+ remains closed until the container is restarted.
+ If this flag is false, a container process
+ that reads from stdin will never receive an
+ EOF. Default is false.
+ type: boolean
+ targetContainerName:
+ description: If set, the name of the container
+ from PodSpec that this ephemeral container targets.
+ The ephemeral container will be run in the namespaces
+ (IPC, PID, etc) of this container. If not set
+ then the ephemeral container is run in whatever
+ namespaces are shared for the pod. Note that
+ the container runtime must support this feature.
+ type: string
+ terminationMessagePath:
+ description: 'Optional: Path at which the file
+ to which the container''s termination message
+ will be written is mounted into the container''s
+ filesystem. Message written is intended to be
+ brief final status, such as an assertion failure
+ message. Will be truncated by the node if greater
+ than 4096 bytes. The total message length across
+ all containers will be limited to 12kb. Defaults
+ to /dev/termination-log. Cannot be updated.'
+ type: string
+ terminationMessagePolicy:
+ description: Indicates how the termination message
+ should be populated. File will use the contents
+ of terminationMessagePath to populate the container
+ status message on both success and failure.
+ FallbackToLogsOnError will use the last chunk
+ of container log output if the termination message
+ file is empty and the container exited with
+ an error. The log output is limited to 2048
+ bytes or 80 lines, whichever is smaller. Defaults
+ to File. Cannot be updated.
+ type: string
+ tty:
+ description: Whether this container should allocate
+ a TTY for itself, also requires 'stdin' to be
+ true. Default is false.
+ type: boolean
+ volumeDevices:
+ description: volumeDevices is the list of block
+ devices to be used by the container.
+ items:
+ description: volumeDevice describes a mapping
+ of a raw block device within a container.
+ properties:
+ devicePath:
+ description: devicePath is the path inside
+ of the container that the device will
+ be mapped to.
+ type: string
+ name:
+ description: name must match the name of
+ a persistentVolumeClaim in the pod
+ type: string
+ required:
+ - devicePath
+ - name
+ type: object
+ type: array
+ volumeMounts:
+ description: Pod volumes to mount into the container's
+ filesystem. Cannot be updated.
+ items:
+ description: VolumeMount describes a mounting
+ of a Volume within a container.
+ properties:
+ mountPath:
+ description: Path within the container at
+ which the volume should be mounted. Must
+ not contain ':'.
+ type: string
+ mountPropagation:
+ description: mountPropagation determines
+ how mounts are propagated from the host
+ to container and the other way around.
+ When not set, MountPropagationNone is
+ used. This field is beta in 1.10.
+ type: string
+ name:
+ description: This must match the Name of
+ a Volume.
+ type: string
+ readOnly:
+ description: Mounted read-only if true,
+ read-write otherwise (false or unspecified).
+ Defaults to false.
+ type: boolean
+ subPath:
+ description: Path within the volume from
+ which the container's volume should be
+ mounted. Defaults to "" (volume's root).
+ type: string
+ subPathExpr:
+ description: Expanded path within the volume
+ from which the container's volume should
+ be mounted. Behaves similarly to SubPath
+ but environment variable references $(VAR_NAME)
+ are expanded using the container's environment.
+ Defaults to "" (volume's root). SubPathExpr
+ and SubPath are mutually exclusive.
+ type: string
+ required:
+ - mountPath
+ - name
+ type: object
+ type: array
+ workingDir:
+ description: Container's working directory. If
+ not specified, the container runtime's default
+ will be used, which might be configured in the
+ container image. Cannot be updated.
+ type: string
+ required:
+ - name
+ type: object
+ type: array
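+ # Illustrative example (comment only): ephemeral containers cannot be added by
+ # editing the pod spec, so they are typically created through the pod's
+ # "ephemeralcontainers" subresource, e.g. with kubectl debug. The pod and
+ # target container names below are placeholders.
+ #   kubectl debug -it my-pod --image=busybox:1.36 --target=worker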
+ hostAliases:
+ description: HostAliases is an optional list of hosts
+ and IPs that will be injected into the pod's hosts
+ file if specified. This is only valid for non-hostNetwork
+ pods.
+ items:
+ description: HostAlias holds the mapping between IP
+ and hostnames that will be injected as an entry
+ in the pod's hosts file.
+ properties:
+ hostnames:
+ description: Hostnames for the above IP address.
+ items:
+ type: string
+ type: array
+ ip:
+ description: IP address of the host file entry.
+ type: string
+ type: object
+ type: array
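+ # Illustrative example (comment only): a hostAliases entry as described above;
+ # the IP address and hostnames are placeholder values.
+ #   hostAliases:
+ #     - ip: "10.0.0.5"
+ #       hostnames:
+ #         - "foo.local"
+ #         - "bar.local"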
+ hostIPC:
+ description: 'Use the host''s ipc namespace. Optional:
+ Defaults to false.'
+ type: boolean
+ hostNetwork:
+ description: Host networking requested for this pod.
+ Use the host's network namespace. If this option is
+ set, the ports that will be used must be specified.
+ Defaults to false.
+ type: boolean
+ hostPID:
+ description: 'Use the host''s pid namespace. Optional:
+ Defaults to false.'
+ type: boolean
+ hostname:
+ description: Specifies the hostname of the Pod. If not
+ specified, the pod's hostname will be set to a system-defined
+ value.
+ type: string
+ imagePullSecrets:
+ description: 'ImagePullSecrets is an optional list of
+ references to secrets in the same namespace to use
+ for pulling any of the images used by this PodSpec.
+ If specified, these secrets will be passed to individual
+ puller implementations for them to use. For example,
+ in the case of docker, only DockerConfig type secrets
+ are honored. More info: https://kubernetes.io/docs/concepts/containers/images#specifying-imagepullsecrets-on-a-pod'
+ items:
+ description: LocalObjectReference contains enough
+ information to let you locate the referenced object
+ inside the same namespace.
+ properties:
+ name:
+ description: 'Name of the referent. More info:
+ https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion, kind,
+ uid?'
+ type: string
+ type: object
+ type: array
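+ # Illustrative example (comment only): referencing a pull secret in the same
+ # namespace, per the imagePullSecrets field above. "regcred" is a placeholder
+ # secret name.
+ #   imagePullSecrets:
+ #     - name: regcred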
+ initContainers:
+ description: 'List of initialization containers belonging
+ to the pod. Init containers are executed in order
+ prior to containers being started. If any init container
+ fails, the pod is considered to have failed and is
+ handled according to its restartPolicy. The name for
+ an init container or normal container must be unique
+ among all containers. Init containers may not have
+ Lifecycle actions, Readiness probes, Liveness probes,
+ or Startup probes. The resourceRequirements of an
+ init container are taken into account during scheduling
+ by finding the highest request/limit for each resource
+ type, and then using the max of that value or the
+ sum of the normal containers. Limits are applied to
+ init containers in a similar fashion. Init containers
+ cannot currently be added or removed. Cannot be updated.
+ More info: https://kubernetes.io/docs/concepts/workloads/pods/init-containers/'
+ items:
+ description: A single application container that you
+ want to run within a pod.
+ properties:
+ args:
+ description: 'Arguments to the entrypoint. The
+ docker image''s CMD is used if this is not provided.
+ Variable references $(VAR_NAME) are expanded
+ using the container''s environment. If a variable
+ cannot be resolved, the reference in the input
+ string will be unchanged. Double $$ are reduced
+ to a single $, which allows for escaping the
+ $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)" will
+ produce the string literal "$(VAR_NAME)". Escaped
+ references will never be expanded, regardless
+ of whether the variable exists or not. Cannot
+ be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell'
+ items:
+ type: string
+ type: array
+ command:
+ description: 'Entrypoint array. Not executed within
+ a shell. The docker image''s ENTRYPOINT is used
+ if this is not provided. Variable references
+ $(VAR_NAME) are expanded using the container''s
+ environment. If a variable cannot be resolved,
+ the reference in the input string will be unchanged.
+ Double $$ are reduced to a single $, which allows
+ for escaping the $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)"
+ will produce the string literal "$(VAR_NAME)".
+ Escaped references will never be expanded, regardless
+ of whether the variable exists or not. Cannot
+ be updated. More info: https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell'
+ items:
+ type: string
+ type: array
+ env:
+ description: List of environment variables to
+ set in the container. Cannot be updated.
+ items:
+ description: EnvVar represents an environment
+ variable present in a Container.
+ properties:
+ name:
+ description: Name of the environment variable.
+ Must be a C_IDENTIFIER.
+ type: string
+ value:
+ description: 'Variable references $(VAR_NAME)
+ are expanded using the previously defined
+ environment variables in the container
+ and any service environment variables.
+ If a variable cannot be resolved, the
+ reference in the input string will be
+ unchanged. Double $$ are reduced to a
+ single $, which allows for escaping the
+ $(VAR_NAME) syntax: i.e. "$$(VAR_NAME)"
+ will produce the string literal "$(VAR_NAME)".
+ Escaped references will never be expanded,
+ regardless of whether the variable exists
+ or not. Defaults to "".'
+ type: string
+ valueFrom:
+ description: Source for the environment
+ variable's value. Cannot be used if value
+ is not empty.
+ properties:
+ configMapKeyRef:
+ description: Selects a key of a ConfigMap.
+ properties:
+ key:
+ description: The key to select.
+ type: string
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ ConfigMap or its key must be defined
+ type: boolean
+ required:
+ - key
+ type: object
+ fieldRef:
+ description: 'Selects a field of the
+ pod: supports metadata.name, metadata.namespace,
+ `metadata.labels['''']`, `metadata.annotations['''']`,
+ spec.nodeName, spec.serviceAccountName,
+ status.hostIP, status.podIP, status.podIPs.'
+ properties:
+ apiVersion:
+ description: Version of the schema
+ the FieldPath is written in terms
+ of, defaults to "v1".
+ type: string
+ fieldPath:
+ description: Path of the field to
+ select in the specified API version.
+ type: string
+ required:
+ - fieldPath
+ type: object
+ resourceFieldRef:
+ description: 'Selects a resource of
+ the container: only resources limits
+ and requests (limits.cpu, limits.memory,
+ limits.ephemeral-storage, requests.cpu,
+ requests.memory and requests.ephemeral-storage)
+ are currently supported.'
+ properties:
+ containerName:
+ description: 'Container name: required
+ for volumes, optional for env
+ vars'
+ type: string
+ divisor:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Specifies the output
+ format of the exposed resources,
+ defaults to "1"
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ resource:
+ description: 'Required: resource
+ to select'
+ type: string
+ required:
+ - resource
+ type: object
+ secretKeyRef:
+ description: Selects a key of a secret
+ in the pod's namespace
+ properties:
+ key:
+ description: The key of the secret
+ to select from. Must be a valid
+ secret key.
+ type: string
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ Secret or its key must be defined
+ type: boolean
+ required:
+ - key
+ type: object
+ type: object
+ required:
+ - name
+ type: object
+ type: array
+ envFrom:
+ description: List of sources to populate environment
+ variables in the container. The keys defined
+ within a source must be a C_IDENTIFIER. All
+ invalid keys will be reported as an event when
+ the container is starting. When a key exists
+ in multiple sources, the value associated with
+ the last source will take precedence. Values
+ defined by an Env with a duplicate key will
+ take precedence. Cannot be updated.
+ items:
+ description: EnvFromSource represents the source
+ of a set of ConfigMaps
+ properties:
+ configMapRef:
+ description: The ConfigMap to select from
+ properties:
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the ConfigMap
+ must be defined
+ type: boolean
+ type: object
+ prefix:
+ description: An optional identifier to prepend
+ to each key in the ConfigMap. Must be
+ a C_IDENTIFIER.
+ type: string
+ secretRef:
+ description: The Secret to select from
+ properties:
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the Secret
+ must be defined
+ type: boolean
+ type: object
+ type: object
+ type: array
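+ # Illustrative example (comment only): env and envFrom for an init container,
+ # using the valueFrom sources documented above. The ConfigMap and Secret names
+ # are placeholders for this sketch.
+ #   env:
+ #     - name: POD_NAME
+ #       valueFrom:
+ #         fieldRef:
+ #           fieldPath: metadata.name
+ #     - name: DB_PASSWORD
+ #       valueFrom:
+ #         secretKeyRef:
+ #           name: db-secret
+ #           key: password
+ #   envFrom:
+ #     - configMapRef:
+ #         name: app-config
+ #       prefix: APP_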
+ image:
+ description: 'Docker image name. More info: https://kubernetes.io/docs/concepts/containers/images
+ This field is optional to allow higher level
+ config management to default or override container
+ images in workload controllers like Deployments
+ and StatefulSets.'
+ type: string
+ imagePullPolicy:
+ description: 'Image pull policy. One of Always,
+ Never, IfNotPresent. Defaults to Always if :latest
+ tag is specified, or IfNotPresent otherwise.
+ Cannot be updated. More info: https://kubernetes.io/docs/concepts/containers/images#updating-images'
+ type: string
+ lifecycle:
+ description: Actions that the management system
+ should take in response to container lifecycle
+ events. Cannot be updated.
+ properties:
+ postStart:
+ description: 'PostStart is called immediately
+ after a container is created. If the handler
+ fails, the container is terminated and restarted
+ according to its restart policy. Other management
+ of the container blocks until the hook completes.
+ More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies
+ the action to take.
+ properties:
+ command:
+ description: Command is the command
+ line to execute inside the container,
+ the working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it
+ is not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to
+ explicitly call out to that shell.
+ Exit status of 0 is treated as live/healthy
+ and non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ httpGet:
+ description: HTTPGet specifies the http
+ request to perform.
+ properties:
+ host:
+ description: Host name to connect
+ to, defaults to the pod IP. You
+ probably want to set "Host" in httpHeaders
+ instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set
+ in the request. HTTP allows repeated
+ headers.
+ items:
+ description: HTTPHeader describes
+ a custom header to be used in
+ HTTP probes
+ properties:
+ name:
+ description: The header field
+ name
+ type: string
+ value:
+ description: The header field
+ value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the
+ HTTP server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not
+ yet supported TODO: implement a realistic
+ TCP lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name
+ to connect to, defaults to the pod
+ IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ type: object
+ preStop:
+ description: 'PreStop is called immediately
+ before a container is terminated due to
+ an API request or management event such
+ as liveness/startup probe failure, preemption,
+ resource contention, etc. The handler is
+ not called if the container crashes or exits.
+ The reason for termination is passed to
+ the handler. The Pod''s termination grace
+ period countdown begins before the PreStop
+ hook is executed. Regardless of the outcome
+ of the handler, the container will eventually
+ terminate within the Pod''s termination
+ grace period. Other management of the container
+ blocks until the hook completes or until
+ the termination grace period is reached.
+ More info: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies
+ the action to take.
+ properties:
+ command:
+ description: Command is the command
+ line to execute inside the container,
+ the working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it
+ is not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to
+ explicitly call out to that shell.
+ Exit status of 0 is treated as live/healthy
+ and non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ httpGet:
+ description: HTTPGet specifies the http
+ request to perform.
+ properties:
+ host:
+ description: Host name to connect
+ to, defaults to the pod IP. You
+ probably want to set "Host" in httpHeaders
+ instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set
+ in the request. HTTP allows repeated
+ headers.
+ items:
+ description: HTTPHeader describes
+ a custom header to be used in
+ HTTP probes
+ properties:
+ name:
+ description: The header field
+ name
+ type: string
+ value:
+ description: The header field
+ value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the
+ HTTP server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not
+ yet supported TODO: implement a realistic
+ TCP lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name
+ to connect to, defaults to the pod
+ IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the
+ port to access on the container.
+ Number must be in the range 1 to
+ 65535. Name must be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ type: object
+ type: object
+ livenessProbe:
+ description: 'Periodic probe of container liveness.
+ Container will be restarted if the probe fails.
+ Cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+ the probe. Default to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
+ name:
+ description: Name of the container specified as
+ a DNS_LABEL. Each container in a pod must have
+ a unique name (DNS_LABEL). Cannot be updated.
+ type: string
+ ports:
+ description: List of ports to expose from the
+ container. Exposing a port here gives the system
+ additional information about the network connections
+ a container uses, but is primarily informational.
+ Not specifying a port here DOES NOT prevent
+ that port from being exposed. Any port which
+ is listening on the default "0.0.0.0" address
+ inside a container will be accessible from the
+ network. Cannot be updated.
+ items:
+ description: ContainerPort represents a network
+ port in a single container.
+ properties:
+ containerPort:
+ description: Number of port to expose on
+ the pod's IP address. This must be a valid
+ port number, 0 < x < 65536.
+ format: int32
+ type: integer
+ hostIP:
+ description: What host IP to bind the external
+ port to.
+ type: string
+ hostPort:
+ description: Number of port to expose on
+ the host. If specified, this must be a
+ valid port number, 0 < x < 65536. If HostNetwork
+ is specified, this must match ContainerPort.
+ Most containers do not need this.
+ format: int32
+ type: integer
+ name:
+ description: If specified, this must be
+ an IANA_SVC_NAME and unique within the
+ pod. Each named port in a pod must have
+ a unique name. Name for the port that
+ can be referred to by services.
+ type: string
+ protocol:
+ default: TCP
+ description: Protocol for port. Must be
+ UDP, TCP, or SCTP. Defaults to "TCP".
+ type: string
+ required:
+ - containerPort
+ type: object
+ type: array
+ x-kubernetes-list-map-keys:
+ - containerPort
+ - protocol
+ x-kubernetes-list-type: map
+ readinessProbe:
+ description: 'Periodic probe of container service
+ readiness. Container will be removed from service
+ endpoints if the probe fails. Cannot be updated.
+ More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+ the probe. Default to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
+ resources:
+ description: 'Compute Resources required by this
+ container. Cannot be updated. More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ properties:
+ limits:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Limits describes the maximum
+ amount of compute resources allowed. More
+ info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ requests:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Requests describes the minimum
+ amount of compute resources required. If
+ Requests is omitted for a container, it
+ defaults to Limits if that is explicitly
+ specified, otherwise to an implementation-defined
+ value. More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ type: object
+ securityContext:
+ description: 'SecurityContext defines the security
+ options the container should be run with. If
+ set, the fields of SecurityContext override
+ the equivalent fields of PodSecurityContext.
+ More info: https://kubernetes.io/docs/tasks/configure-pod-container/security-context/'
+ properties:
+ allowPrivilegeEscalation:
+ description: 'AllowPrivilegeEscalation controls
+ whether a process can gain more privileges
+ than its parent process. This bool directly
+ controls if the no_new_privs flag will be
+ set on the container process. AllowPrivilegeEscalation
+ is true always when the container is: 1)
+ run as Privileged 2) has CAP_SYS_ADMIN'
+ type: boolean
+ capabilities:
+ description: The capabilities to add/drop
+ when running containers. Defaults to the
+ default set of capabilities granted by the
+ container runtime.
+ properties:
+ add:
+ description: Added capabilities
+ items:
+ description: Capability represent POSIX
+ capabilities type
+ type: string
+ type: array
+ drop:
+ description: Removed capabilities
+ items:
+ description: Capability represent POSIX
+ capabilities type
+ type: string
+ type: array
+ type: object
+ privileged:
+ description: Run container in privileged mode.
+ Processes in privileged containers are essentially
+ equivalent to root on the host. Defaults
+ to false.
+ type: boolean
+ procMount:
+ description: procMount denotes the type of
+ proc mount to use for the containers. The
+ default is DefaultProcMount which uses the
+ container runtime defaults for readonly
+ paths and masked paths. This requires the
+ ProcMountType feature flag to be enabled.
+ type: string
+ readOnlyRootFilesystem:
+ description: Whether this container has a
+ read-only root filesystem. Default is false.
+ type: boolean
+ runAsGroup:
+ description: The GID to run the entrypoint
+ of the container process. Uses runtime default
+ if unset. May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ format: int64
+ type: integer
+ runAsNonRoot:
+ description: Indicates that the container
+ must run as a non-root user. If true, the
+ Kubelet will validate the image at runtime
+ to ensure that it does not run as UID 0
+ (root) and fail to start the container if
+ it does. If unset or false, no such validation
+ will be performed. May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ type: boolean
+ runAsUser:
+ description: The UID to run the entrypoint
+ of the container process. Defaults to user
+ specified in image metadata if unspecified.
+ May also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ format: int64
+ type: integer
+ seLinuxOptions:
+ description: The SELinux context to be applied
+ to the container. If unspecified, the container
+ runtime will allocate a random SELinux context
+ for each container. May also be set in
+ PodSecurityContext. If set in both SecurityContext
+ and PodSecurityContext, the value specified
+ in SecurityContext takes precedence.
+ properties:
+ level:
+ description: Level is SELinux level label
+ that applies to the container.
+ type: string
+ role:
+ description: Role is a SELinux role label
+ that applies to the container.
+ type: string
+ type:
+ description: Type is a SELinux type label
+ that applies to the container.
+ type: string
+ user:
+ description: User is a SELinux user label
+ that applies to the container.
+ type: string
+ type: object
+ seccompProfile:
+ description: The seccomp options to use by
+ this container. If seccomp options are provided
+ at both the pod & container level, the container
+ options override the pod options.
+ properties:
+ localhostProfile:
+ description: localhostProfile indicates
+ a profile defined in a file on the node
+ should be used. The profile must be
+ preconfigured on the node to work. Must
+ be a descending path, relative to the
+ kubelet's configured seccomp profile
+ location. Must only be set if type is
+ "Localhost".
+ type: string
+ type:
+ description: "type indicates which kind
+ of seccomp profile will be applied.
+ Valid options are: \n Localhost - a
+ profile defined in a file on the node
+ should be used. RuntimeDefault - the
+ container runtime default profile should
+ be used. Unconfined - no profile should
+ be applied."
+ type: string
+ required:
+ - type
+ type: object
+ windowsOptions:
+ description: The Windows specific settings
+ applied to all containers. If unspecified,
+ the options from the PodSecurityContext
+ will be used. If set in both SecurityContext
+ and PodSecurityContext, the value specified
+ in SecurityContext takes precedence.
+ properties:
+ gmsaCredentialSpec:
+ description: GMSACredentialSpec is where
+ the GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa)
+ inlines the contents of the GMSA credential
+ spec named by the GMSACredentialSpecName
+ field.
+ type: string
+ gmsaCredentialSpecName:
+ description: GMSACredentialSpecName is
+ the name of the GMSA credential spec
+ to use.
+ type: string
+ hostProcess:
+ description: HostProcess determines if
+ a container should be run as a 'Host
+ Process' container. This field is alpha-level
+ and will only be honored by components
+ that enable the WindowsHostProcessContainers
+ feature flag. Setting this field without
+ the feature flag will result in errors
+ when validating the Pod. All of a Pod's
+ containers must have the same effective
+ HostProcess value (it is not allowed
+ to have a mix of HostProcess containers
+ and non-HostProcess containers). In
+ addition, if HostProcess is true then
+ HostNetwork must also be set to true.
+ type: boolean
+ runAsUserName:
+ description: The UserName in Windows to
+ run the entrypoint of the container
+ process. Defaults to the user specified
+ in image metadata if unspecified. May
+ also be set in PodSecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext
+ takes precedence.
+ type: string
+ type: object
+ type: object
+ startupProbe:
+ description: 'StartupProbe indicates that the
+ Pod has successfully initialized. If specified,
+ no other probes are executed until this completes
+ successfully. If this probe fails, the Pod will
+ be restarted, just as if the livenessProbe failed.
+ This can be used to provide different probe
+ parameters at the beginning of a Pod''s lifecycle,
+ when it might take a long time to load data
+ or warm a cache, than during steady-state operation.
+ This cannot be updated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ properties:
+ exec:
+ description: One and only one of the following
+ should be specified. Exec specifies the
+ action to take.
+ properties:
+ command:
+ description: Command is the command line
+ to execute inside the container, the
+ working directory for the command is
+ root ('/') in the container's filesystem.
+ The command is simply exec'd, it is
+ not run inside a shell, so traditional
+ shell instructions ('|', etc) won't
+ work. To use a shell, you need to explicitly
+ call out to that shell. Exit status
+ of 0 is treated as live/healthy and
+ non-zero is unhealthy.
+ items:
+ type: string
+ type: array
+ type: object
+ failureThreshold:
+ description: Minimum consecutive failures
+ for the probe to be considered failed after
+ having succeeded. Defaults to 3. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ httpGet:
+ description: HTTPGet specifies the http request
+ to perform.
+ properties:
+ host:
+ description: Host name to connect to,
+ defaults to the pod IP. You probably
+ want to set "Host" in httpHeaders instead.
+ type: string
+ httpHeaders:
+ description: Custom headers to set in
+ the request. HTTP allows repeated headers.
+ items:
+ description: HTTPHeader describes a
+ custom header to be used in HTTP probes
+ properties:
+ name:
+ description: The header field name
+ type: string
+ value:
+ description: The header field value
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ path:
+ description: Path to access on the HTTP
+ server.
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Name or number of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ scheme:
+ description: Scheme to use for connecting
+ to the host. Defaults to HTTP.
+ type: string
+ required:
+ - port
+ type: object
+ initialDelaySeconds:
+ description: 'Number of seconds after the
+ container has started before liveness probes
+ are initiated. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ periodSeconds:
+ description: How often (in seconds) to perform
+ the probe. Default to 10 seconds. Minimum
+ value is 1.
+ format: int32
+ type: integer
+ successThreshold:
+ description: Minimum consecutive successes
+ for the probe to be considered successful
+ after having failed. Defaults to 1. Must
+ be 1 for liveness and startup. Minimum value
+ is 1.
+ format: int32
+ type: integer
+ tcpSocket:
+ description: 'TCPSocket specifies an action
+ involving a TCP port. TCP hooks not yet
+ supported TODO: implement a realistic TCP
+ lifecycle hook'
+ properties:
+ host:
+ description: 'Optional: Host name to connect
+ to, defaults to the pod IP.'
+ type: string
+ port:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Number or name of the port
+ to access on the container. Number must
+ be in the range 1 to 65535. Name must
+ be an IANA_SVC_NAME.
+ x-kubernetes-int-or-string: true
+ required:
+ - port
+ type: object
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds
+ the pod needs to terminate gracefully upon
+ probe failure. The grace period is the duration
+ in seconds after the processes running in
+ the pod are sent a termination signal and
+ the time when the processes are forcibly
+ halted with a kill signal. Set this value
+ longer than the expected cleanup time for
+ your process. If this value is nil, the
+ pod's terminationGracePeriodSeconds will
+ be used. Otherwise, this value overrides
+ the value provided by the pod spec. Value
+ must be non-negative integer. The value
+ zero indicates stop immediately via the
+ kill signal (no opportunity to shut down).
+ This is a beta field and requires enabling
+ ProbeTerminationGracePeriod feature gate.
+ Minimum value is 1. spec.terminationGracePeriodSeconds
+ is used if unset.
+ format: int64
+ type: integer
+ timeoutSeconds:
+ description: 'Number of seconds after which
+ the probe times out. Defaults to 1 second.
+ Minimum value is 1. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle#container-probes'
+ format: int32
+ type: integer
+ type: object
+ stdin:
+ description: Whether this container should allocate
+ a buffer for stdin in the container runtime.
+ If this is not set, reads from stdin in the
+ container will always result in EOF. Default
+ is false.
+ type: boolean
+ stdinOnce:
+ description: Whether the container runtime should
+ close the stdin channel after it has been opened
+ by a single attach. When stdin is true the stdin
+ stream will remain open across multiple attach
+ sessions. If stdinOnce is set to true, stdin
+ is opened on container start, is empty until
+ the first client attaches to stdin, and then
+ remains open and accepts data until the client
+ disconnects, at which time stdin is closed and
+ remains closed until the container is restarted.
+ If this flag is false, a container process
+ that reads from stdin will never receive an
+ EOF. Default is false.
+ type: boolean
+ terminationMessagePath:
+ description: 'Optional: Path at which the file
+ to which the container''s termination message
+ will be written is mounted into the container''s
+ filesystem. Message written is intended to be
+ brief final status, such as an assertion failure
+ message. Will be truncated by the node if greater
+ than 4096 bytes. The total message length across
+ all containers will be limited to 12kb. Defaults
+ to /dev/termination-log. Cannot be updated.'
+ type: string
+ terminationMessagePolicy:
+ description: Indicate how the termination message
+ should be populated. File will use the contents
+ of terminationMessagePath to populate the container
+ status message on both success and failure.
+ FallbackToLogsOnError will use the last chunk
+ of container log output if the termination message
+ file is empty and the container exited with
+ an error. The log output is limited to 2048
+ bytes or 80 lines, whichever is smaller. Defaults
+ to File. Cannot be updated.
+ type: string
+ tty:
+ description: Whether this container should allocate
+ a TTY for itself, also requires 'stdin' to be
+ true. Default is false.
+ type: boolean
+ volumeDevices:
+ description: volumeDevices is the list of block
+ devices to be used by the container.
+ items:
+ description: volumeDevice describes a mapping
+ of a raw block device within a container.
+ properties:
+ devicePath:
+ description: devicePath is the path inside
+ of the container that the device will
+ be mapped to.
+ type: string
+ name:
+ description: name must match the name of
+ a persistentVolumeClaim in the pod
+ type: string
+ required:
+ - devicePath
+ - name
+ type: object
+ type: array
+ volumeMounts:
+ description: Pod volumes to mount into the container's
+ filesystem. Cannot be updated.
+ items:
+ description: VolumeMount describes a mounting
+ of a Volume within a container.
+ properties:
+ mountPath:
+ description: Path within the container at
+ which the volume should be mounted. Must
+ not contain ':'.
+ type: string
+ mountPropagation:
+ description: mountPropagation determines
+ how mounts are propagated from the host
+ to container and the other way around.
+ When not set, MountPropagationNone is
+ used. This field is beta in 1.10.
+ type: string
+ name:
+ description: This must match the Name of
+ a Volume.
+ type: string
+ readOnly:
+ description: Mounted read-only if true,
+ read-write otherwise (false or unspecified).
+ Defaults to false.
+ type: boolean
+ subPath:
+ description: Path within the volume from
+ which the container's volume should be
+ mounted. Defaults to "" (volume's root).
+ type: string
+ subPathExpr:
+ description: Expanded path within the volume
+ from which the container's volume should
+ be mounted. Behaves similarly to SubPath
+ but environment variable references $(VAR_NAME)
+ are expanded using the container's environment.
+ Defaults to "" (volume's root). SubPathExpr
+ and SubPath are mutually exclusive.
+ type: string
+ required:
+ - mountPath
+ - name
+ type: object
+ type: array
+ workingDir:
+ description: Container's working directory. If
+ not specified, the container runtime's default
+ will be used, which might be configured in the
+ container image. Cannot be updated.
+ type: string
+ required:
+ - name
+ type: object
+ type: array
+ nodeName:
+ description: NodeName is a request to schedule this
+ pod onto a specific node. If it is non-empty, the
+ scheduler simply schedules this pod onto that node,
+ assuming that it fits resource requirements.
+ type: string
+ nodeSelector:
+ additionalProperties:
+ type: string
+ description: 'NodeSelector is a selector which must
+ be true for the pod to fit on a node. Selector which
+ must match a node''s labels for the pod to be scheduled
+ on that node. More info: https://kubernetes.io/docs/concepts/configuration/assign-pod-node/'
+ type: object
+ x-kubernetes-map-type: atomic
+ overhead:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Overhead represents the resource overhead
+ associated with running a pod for a given RuntimeClass.
+ This field will be autopopulated at admission time
+ by the RuntimeClass admission controller. If the RuntimeClass
+ admission controller is enabled, overhead must not
+ be set in Pod create requests. The RuntimeClass admission
+ controller will reject Pod create requests which have
+ the overhead already set. If RuntimeClass is configured
+ and selected in the PodSpec, Overhead will be set
+ to the value defined in the corresponding RuntimeClass,
+ otherwise it will remain unset and treated as zero.
+ More info: https://git.k8s.io/enhancements/keps/sig-node/688-pod-overhead/README.md
+ This field is beta-level as of Kubernetes v1.18, and
+ is only honored by servers that enable the PodOverhead
+ feature.'
+ type: object
+ preemptionPolicy:
+ description: PreemptionPolicy is the Policy for preempting
+ pods with lower priority. One of Never, PreemptLowerPriority.
+ Defaults to PreemptLowerPriority if unset. This field
+ is beta-level, gated by the NonPreemptingPriority
+ feature-gate.
+ type: string
+ priority:
+ description: The priority value. Various system components
+ use this field to find the priority of the pod. When
+ Priority Admission Controller is enabled, it prevents
+ users from setting this field. The admission controller
+ populates this field from PriorityClassName. The higher
+ the value, the higher the priority.
+ format: int32
+ type: integer
+ priorityClassName:
+ description: If specified, indicates the pod's priority.
+ "system-node-critical" and "system-cluster-critical"
+ are two special keywords which indicate the highest
+ priorities with the former being the highest priority.
+ Any other name must be defined by creating a PriorityClass
+ object with that name. If not specified, the pod priority
+ will be default or zero if there is no default.
+ type: string
+ readinessGates:
+ description: 'If specified, all readiness gates will
+ be evaluated for pod readiness. A pod is ready when
+ all its containers are ready AND all conditions specified
+ in the readiness gates have status equal to "True"
+ More info: https://git.k8s.io/enhancements/keps/sig-network/580-pod-readiness-gates'
+ items:
+ description: PodReadinessGate contains the reference
+ to a pod condition
+ properties:
+ conditionType:
+ description: ConditionType refers to a condition
+ in the pod's condition list with matching type.
+ type: string
+ required:
+ - conditionType
+ type: object
+ type: array
+ restartPolicy:
+ description: 'Restart policy for all containers within
+ the pod. One of Always, OnFailure, Never. Default
+ to Always. More info: https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#restart-policy'
+ type: string
+ runtimeClassName:
+ description: 'RuntimeClassName refers to a RuntimeClass
+ object in the node.k8s.io group, which should be used
+ to run this pod. If no RuntimeClass resource matches
+ the named class, the pod will not be run. If unset
+ or empty, the "legacy" RuntimeClass will be used,
+ which is an implicit class with an empty definition
+ that uses the default runtime handler. More info:
+ https://git.k8s.io/enhancements/keps/sig-node/585-runtime-class
+ This is a beta feature as of Kubernetes v1.14.'
+ type: string
+ schedulerName:
+ description: If specified, the pod will be dispatched
+ by specified scheduler. If not specified, the pod
+ will be dispatched by default scheduler.
+ type: string
+ securityContext:
+ description: 'SecurityContext holds pod-level security
+ attributes and common container settings. Optional:
+ Defaults to empty. See type description for default
+ values of each field.'
+ properties:
+ fsGroup:
+ description: "A special supplemental group that
+ applies to all containers in a pod. Some volume
+ types allow the Kubelet to change the ownership
+ of that volume to be owned by the pod: \n 1. The
+ owning GID will be the FSGroup 2. The setgid bit
+ is set (new files created in the volume will be
+ owned by FSGroup) 3. The permission bits are OR'd
+ with rw-rw---- \n If unset, the Kubelet will not
+ modify the ownership and permissions of any volume."
+ format: int64
+ type: integer
+ fsGroupChangePolicy:
+ description: 'fsGroupChangePolicy defines behavior
+ of changing ownership and permission of the volume
+ before being exposed inside Pod. This field will
+ only apply to volume types which support fsGroup
+ based ownership(and permissions). It will have
+ no effect on ephemeral volume types such as: secret,
+ configmaps and emptydir. Valid values are "OnRootMismatch"
+ and "Always". If not specified, "Always" is used.'
+ type: string
+ runAsGroup:
+ description: The GID to run the entrypoint of the
+ container process. Uses runtime default if unset.
+ May also be set in SecurityContext. If set in
+ both SecurityContext and PodSecurityContext, the
+ value specified in SecurityContext takes precedence
+ for that container.
+ format: int64
+ type: integer
+ runAsNonRoot:
+ description: Indicates that the container must run
+ as a non-root user. If true, the Kubelet will
+ validate the image at runtime to ensure that it
+ does not run as UID 0 (root) and fail to start
+ the container if it does. If unset or false, no
+ such validation will be performed. May also be
+ set in SecurityContext. If set in both SecurityContext
+ and PodSecurityContext, the value specified in
+ SecurityContext takes precedence.
+ type: boolean
+ runAsUser:
+ description: The UID to run the entrypoint of the
+ container process. Defaults to user specified
+ in image metadata if unspecified. May also be
+ set in SecurityContext. If set in both SecurityContext
+ and PodSecurityContext, the value specified in
+ SecurityContext takes precedence for that container.
+ format: int64
+ type: integer
+ seLinuxOptions:
+ description: The SELinux context to be applied to
+ all containers. If unspecified, the container
+ runtime will allocate a random SELinux context
+ for each container. May also be set in SecurityContext. If
+ set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes precedence
+ for that container.
+ properties:
+ level:
+ description: Level is SELinux level label that
+ applies to the container.
+ type: string
+ role:
+ description: Role is a SELinux role label that
+ applies to the container.
+ type: string
+ type:
+ description: Type is a SELinux type label that
+ applies to the container.
+ type: string
+ user:
+ description: User is a SELinux user label that
+ applies to the container.
+ type: string
+ type: object
+ seccompProfile:
+ description: The seccomp options to use by the containers
+ in this pod.
+ properties:
+ localhostProfile:
+ description: localhostProfile indicates a profile
+ defined in a file on the node should be used.
+ The profile must be preconfigured on the node
+ to work. Must be a descending path, relative
+ to the kubelet's configured seccomp profile
+ location. Must only be set if type is "Localhost".
+ type: string
+ type:
+ description: "type indicates which kind of seccomp
+ profile will be applied. Valid options are:
+ \n Localhost - a profile defined in a file
+ on the node should be used. RuntimeDefault
+ - the container runtime default profile should
+ be used. Unconfined - no profile should be
+ applied."
+ type: string
+ required:
+ - type
+ type: object
+ supplementalGroups:
+ description: A list of groups applied to the first
+ process run in each container, in addition to
+ the container's primary GID. If unspecified,
+ no groups will be added to any container.
+ items:
+ format: int64
+ type: integer
+ type: array
+ sysctls:
+ description: Sysctls hold a list of namespaced sysctls
+ used for the pod. Pods with unsupported sysctls
+ (by the container runtime) might fail to launch.
+ items:
+ description: Sysctl defines a kernel parameter
+ to be set
+ properties:
+ name:
+ description: Name of a property to set
+ type: string
+ value:
+ description: Value of a property to set
+ type: string
+ required:
+ - name
+ - value
+ type: object
+ type: array
+ windowsOptions:
+ description: The Windows specific settings applied
+ to all containers. If unspecified, the options
+ within a container's SecurityContext will be used.
+ If set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes precedence.
+ properties:
+ gmsaCredentialSpec:
+ description: GMSACredentialSpec is where the
+ GMSA admission webhook (https://github.com/kubernetes-sigs/windows-gmsa)
+ inlines the contents of the GMSA credential
+ spec named by the GMSACredentialSpecName field.
+ type: string
+ gmsaCredentialSpecName:
+ description: GMSACredentialSpecName is the name
+ of the GMSA credential spec to use.
+ type: string
+ hostProcess:
+ description: HostProcess determines if a container
+ should be run as a 'Host Process' container.
+ This field is alpha-level and will only be
+ honored by components that enable the WindowsHostProcessContainers
+ feature flag. Setting this field without the
+ feature flag will result in errors when validating
+ the Pod. All of a Pod's containers must have
+ the same effective HostProcess value (it is
+ not allowed to have a mix of HostProcess containers
+ and non-HostProcess containers). In addition,
+ if HostProcess is true then HostNetwork must
+ also be set to true.
+ type: boolean
+ runAsUserName:
+ description: The UserName in Windows to run
+ the entrypoint of the container process. Defaults
+ to the user specified in image metadata if
+ unspecified. May also be set in PodSecurityContext.
+ If set in both SecurityContext and PodSecurityContext,
+ the value specified in SecurityContext takes
+ precedence.
+ type: string
+ type: object
+ type: object
+ serviceAccount:
+ description: 'DeprecatedServiceAccount is a deprecated
+ alias for ServiceAccountName. Deprecated: Use serviceAccountName
+ instead.'
+ type: string
+ serviceAccountName:
+ description: 'ServiceAccountName is the name of the
+ ServiceAccount to use to run this pod. More info:
+ https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/'
+ type: string
+ setHostnameAsFQDN:
+ description: If true the pod's hostname will be configured
+ as the pod's FQDN, rather than the leaf name (the
+ default). In Linux containers, this means setting
+ the FQDN in the hostname field of the kernel (the
+ nodename field of struct utsname). In Windows containers,
+ this means setting the registry value of hostname
+ for the registry key HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters
+ to FQDN. If a pod does not have FQDN, this has no
+ effect. Default to false.
+ type: boolean
+ shareProcessNamespace:
+ description: 'Share a single process namespace between
+ all of the containers in a pod. When this is set containers
+ will be able to view and signal processes from other
+ containers in the same pod, and the first process
+ in each container will not be assigned PID 1. HostPID
+ and ShareProcessNamespace cannot both be set. Optional:
+ Default to false.'
+ type: boolean
+ subdomain:
+ description: If specified, the fully qualified Pod hostname
+ will be "<hostname>.<subdomain>.<pod namespace>.svc.<cluster
+ domain>". If not specified, the pod will not have
+ a domainname at all.
+ type: string
+ terminationGracePeriodSeconds:
+ description: Optional duration in seconds the pod needs
+ to terminate gracefully. May be decreased in delete
+ request. Value must be non-negative integer. The value
+ zero indicates stop immediately via the kill signal
+ (no opportunity to shut down). If this value is nil,
+ the default grace period will be used instead. The
+ grace period is the duration in seconds after the
+ processes running in the pod are sent a termination
+ signal and the time when the processes are forcibly
+ halted with a kill signal. Set this value longer than
+ the expected cleanup time for your process. Defaults
+ to 30 seconds.
+ format: int64
+ type: integer
+ tolerations:
+ description: If specified, the pod's tolerations.
+ items:
+ description: The pod this Toleration is attached to
+ tolerates any taint that matches the triple <key,value,effect>
+ using the matching operator <operator>.
+ properties:
+ effect:
+ description: Effect indicates the taint effect
+ to match. Empty means match all taint effects.
+ When specified, allowed values are NoSchedule,
+ PreferNoSchedule and NoExecute.
+ type: string
+ key:
+ description: Key is the taint key that the toleration
+ applies to. Empty means match all taint keys.
+ If the key is empty, operator must be Exists;
+ this combination means to match all values and
+ all keys.
+ type: string
+ operator:
+ description: Operator represents a key's relationship
+ to the value. Valid operators are Exists and
+ Equal. Defaults to Equal. Exists is equivalent
+ to wildcard for value, so that a pod can tolerate
+ all taints of a particular category.
+ type: string
+ tolerationSeconds:
+ description: TolerationSeconds represents the
+ period of time the toleration (which must be
+ of effect NoExecute, otherwise this field is
+ ignored) tolerates the taint. By default, it
+ is not set, which means tolerate the taint forever
+ (do not evict). Zero and negative values will
+ be treated as 0 (evict immediately) by the system.
+ format: int64
+ type: integer
+ value:
+ description: Value is the taint value the toleration
+ matches to. If the operator is Exists, the value
+ should be empty, otherwise just a regular string.
+ type: string
+ type: object
+ type: array
+ topologySpreadConstraints:
+ description: TopologySpreadConstraints describes how
+ a group of pods ought to spread across topology domains.
+ Scheduler will schedule pods in a way which abides
+ by the constraints. All topologySpreadConstraints
+ are ANDed.
+ items:
+ description: TopologySpreadConstraint specifies how
+ to spread matching pods among the given topology.
+ properties:
+ labelSelector:
+ description: LabelSelector is used to find matching
+ pods. Pods that match this label selector are
+ counted to determine the number of pods in their
+ corresponding topology domain.
+ properties:
+ matchExpressions:
+ description: matchExpressions is a list of
+ label selector requirements. The requirements
+ are ANDed.
+ items:
+ description: A label selector requirement
+ is a selector that contains values, a
+ key, and an operator that relates the
+ key and values.
+ properties:
+ key:
+ description: key is the label key that
+ the selector applies to.
+ type: string
+ operator:
+ description: operator represents a key's
+ relationship to a set of values. Valid
+ operators are In, NotIn, Exists and
+ DoesNotExist.
+ type: string
+ values:
+ description: values is an array of string
+ values. If the operator is In or NotIn,
+ the values array must be non-empty.
+ If the operator is Exists or DoesNotExist,
+ the values array must be empty. This
+ array is replaced during a strategic
+ merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a map of {key,value}
+ pairs. A single {key,value} in the matchLabels
+ map is equivalent to an element of matchExpressions,
+ whose key field is "key", the operator is
+ "In", and the values array contains only
+ "value". The requirements are ANDed.
+ type: object
+ type: object
+ maxSkew:
+ description: 'MaxSkew describes the degree to
+ which pods may be unevenly distributed. When
+ `whenUnsatisfiable=DoNotSchedule`, it is the
+ maximum permitted difference between the number
+ of matching pods in the target topology and
+ the global minimum. For example, in a 3-zone
+ cluster, MaxSkew is set to 1, and pods with
+ the same labelSelector spread as 1/1/0: | zone1
+ | zone2 | zone3 | | P | P | |
+ - if MaxSkew is 1, incoming pod can only be
+ scheduled to zone3 to become 1/1/1; scheduling
+ it onto zone1(zone2) would make the ActualSkew(2-0)
+ on zone1(zone2) violate MaxSkew(1). - if MaxSkew
+ is 2, incoming pod can be scheduled onto any
+ zone. When `whenUnsatisfiable=ScheduleAnyway`,
+ it is used to give higher precedence to topologies
+ that satisfy it. It''s a required field. Default
+ value is 1 and 0 is not allowed.'
+ format: int32
+ type: integer
+ topologyKey:
+ description: TopologyKey is the key of node labels.
+ Nodes that have a label with this key and identical
+ values are considered to be in the same topology.
+ We consider each <key, value> as a "bucket",
+ and try to put balanced number of pods into
+ each bucket. It's a required field.
+ type: string
+ whenUnsatisfiable:
+ description: 'WhenUnsatisfiable indicates how
+ to deal with a pod if it doesn''t satisfy the
+ spread constraint. - DoNotSchedule (default)
+ tells the scheduler not to schedule it. - ScheduleAnyway
+ tells the scheduler to schedule the pod in any
+ location, but giving higher precedence to
+ topologies that would help reduce the skew.
+ A constraint is considered "Unsatisfiable" for
+ an incoming pod if and only if every possible
+ node assignment for that pod would violate "MaxSkew"
+ on some topology. For example, in a 3-zone cluster,
+ MaxSkew is set to 1, and pods with the same
+ labelSelector spread as 3/1/1: | zone1 | zone2
+ | zone3 | | P P P | P | P | If WhenUnsatisfiable
+ is set to DoNotSchedule, incoming pod can only
+ be scheduled to zone2(zone3) to become 3/2/1(3/1/2)
+ as ActualSkew(2-1) on zone2(zone3) satisfies
+ MaxSkew(1). In other words, the cluster can
+ still be imbalanced, but scheduler won''t make
+ it *more* imbalanced. It''s a required field.'
+ type: string
+ required:
+ - maxSkew
+ - topologyKey
+ - whenUnsatisfiable
+ type: object
+ type: array
+ x-kubernetes-list-map-keys:
+ - topologyKey
+ - whenUnsatisfiable
+ x-kubernetes-list-type: map
+ volumes:
+ description: 'List of volumes that can be mounted by
+ containers belonging to the pod. More info: https://kubernetes.io/docs/concepts/storage/volumes'
+ items:
+ description: Volume represents a named volume in a
+ pod that may be accessed by any container in the
+ pod.
+ properties:
+ awsElasticBlockStore:
+ description: 'AWSElasticBlockStore represents
+ an AWS Disk resource that is attached to a kubelet''s
+ host machine and then exposed to the pod. More
+ info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore'
+ properties:
+ fsType:
+ description: 'Filesystem type of the volume
+ that you want to mount. Tip: Ensure that
+ the filesystem type is supported by the
+ host operating system. Examples: "ext4",
+ "xfs", "ntfs". Implicitly inferred to be
+ "ext4" if unspecified. More info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore
+ TODO: how do we prevent errors in the filesystem
+ from compromising the machine'
+ type: string
+ partition:
+ description: 'The partition in the volume
+ that you want to mount. If omitted, the
+ default is to mount by volume name. Examples:
+ For volume /dev/sda1, you specify the partition
+ as "1". Similarly, the volume partition
+ for /dev/sda is "0" (or you can leave the
+ property empty).'
+ format: int32
+ type: integer
+ readOnly:
+ description: 'Specify "true" to force and
+ set the ReadOnly property in VolumeMounts
+ to "true". If omitted, the default is "false".
+ More info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore'
+ type: boolean
+ volumeID:
+ description: 'Unique ID of the persistent
+ disk resource in AWS (Amazon EBS volume).
+ More info: https://kubernetes.io/docs/concepts/storage/volumes#awselasticblockstore'
+ type: string
+ required:
+ - volumeID
+ type: object
+ azureDisk:
+ description: AzureDisk represents an Azure Data
+ Disk mount on the host and bind mount to the
+ pod.
+ properties:
+ cachingMode:
+ description: 'Host Caching mode: None, Read
+ Only, Read Write.'
+ type: string
+ diskName:
+ description: The Name of the data disk in
+ the blob storage
+ type: string
+ diskURI:
+ description: The URI of the data disk in the
+ blob storage
+ type: string
+ fsType:
+ description: Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Ex. "ext4", "xfs", "ntfs".
+ Implicitly inferred to be "ext4" if unspecified.
+ type: string
+ kind:
+ description: 'Expected values Shared: multiple
+ blob disks per storage account Dedicated:
+ single blob disk per storage account Managed:
+ azure managed data disk (only in managed
+ availability set). defaults to shared'
+ type: string
+ readOnly:
+ description: Defaults to false (read/write).
+ ReadOnly here will force the ReadOnly setting
+ in VolumeMounts.
+ type: boolean
+ required:
+ - diskName
+ - diskURI
+ type: object
+ azureFile:
+ description: AzureFile represents an Azure File
+ Service mount on the host and bind mount to
+ the pod.
+ properties:
+ readOnly:
+ description: Defaults to false (read/write).
+ ReadOnly here will force the ReadOnly setting
+ in VolumeMounts.
+ type: boolean
+ secretName:
+ description: the name of secret that contains
+ Azure Storage Account Name and Key
+ type: string
+ shareName:
+ description: Share Name
+ type: string
+ required:
+ - secretName
+ - shareName
+ type: object
+ cephfs:
+ description: CephFS represents a Ceph FS mount
+ on the host that shares a pod's lifetime
+ properties:
+ monitors:
+ description: 'Required: Monitors is a collection
+ of Ceph monitors More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it'
+ items:
+ type: string
+ type: array
+ path:
+ description: 'Optional: Used as the mounted
+ root, rather than the full Ceph tree, default
+ is /'
+ type: string
+ readOnly:
+ description: 'Optional: Defaults to false
+ (read/write). ReadOnly here will force the
+ ReadOnly setting in VolumeMounts. More info:
+ https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it'
+ type: boolean
+ secretFile:
+ description: 'Optional: SecretFile is the
+ path to key ring for User, default is /etc/ceph/user.secret
+ More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it'
+ type: string
+ secretRef:
+ description: 'Optional: SecretRef is reference
+ to the authentication secret for User, default
+ is empty. More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it'
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ user:
+ description: 'Optional: User is the rados
+ user name, default is admin More info: https://examples.k8s.io/volumes/cephfs/README.md#how-to-use-it'
+ type: string
+ required:
+ - monitors
+ type: object
+ cinder:
+ description: 'Cinder represents a cinder volume
+ attached and mounted on kubelets host machine.
+ More info: https://examples.k8s.io/mysql-cinder-pd/README.md'
+ properties:
+ fsType:
+ description: 'Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Examples: "ext4", "xfs",
+ "ntfs". Implicitly inferred to be "ext4"
+ if unspecified. More info: https://examples.k8s.io/mysql-cinder-pd/README.md'
+ type: string
+ readOnly:
+ description: 'Optional: Defaults to false
+ (read/write). ReadOnly here will force the
+ ReadOnly setting in VolumeMounts. More info:
+ https://examples.k8s.io/mysql-cinder-pd/README.md'
+ type: boolean
+ secretRef:
+ description: 'Optional: points to a secret
+ object containing parameters used to connect
+ to OpenStack.'
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ volumeID:
+ description: 'volume id used to identify the
+ volume in cinder. More info: https://examples.k8s.io/mysql-cinder-pd/README.md'
+ type: string
+ required:
+ - volumeID
+ type: object
+ configMap:
+ description: ConfigMap represents a configMap
+ that should populate this volume
+ properties:
+ defaultMode:
+ description: 'Optional: mode bits used to
+ set permissions on created files by default.
+ Must be an octal value between 0000 and
+ 0777 or a decimal value between 0 and 511.
+ YAML accepts both octal and decimal values,
+ JSON requires decimal values for mode bits.
+ Defaults to 0644. Directories within the
+ path are not affected by this setting. This
+ might be in conflict with other options
+ that affect the file mode, like fsGroup,
+ and the result can be other mode bits set.'
+ format: int32
+ type: integer
+ items:
+ description: If unspecified, each key-value
+ pair in the Data field of the referenced
+ ConfigMap will be projected into the volume
+ as a file whose name is the key and content
+ is the value. If specified, the listed keys
+ will be projected into the specified paths,
+ and unlisted keys will not be present. If
+ a key is specified which is not present
+ in the ConfigMap, the volume setup will
+ error unless it is marked optional. Paths
+ must be relative and may not contain the
+ '..' path or start with '..'.
+ items:
+ description: Maps a string key to a path
+ within a volume.
+ properties:
+ key:
+ description: The key to project.
+ type: string
+ mode:
+ description: 'Optional: mode bits used
+ to set permissions on this file. Must
+ be an octal value between 0000 and
+ 0777 or a decimal value between 0
+ and 511. YAML accepts both octal and
+ decimal values, JSON requires decimal
+ values for mode bits. If not specified,
+ the volume defaultMode will be used.
+ This might be in conflict with other
+ options that affect the file mode,
+ like fsGroup, and the result can be
+ other mode bits set.'
+ format: int32
+ type: integer
+ path:
+ description: The relative path of the
+ file to map the key to. May not be
+ an absolute path. May not contain
+ the path element '..'. May not start
+ with the string '..'.
+ type: string
+ required:
+ - key
+ - path
+ type: object
+ type: array
+ name:
+ description: 'Name of the referent. More info:
+ https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the ConfigMap
+ or its keys must be defined
+ type: boolean
+ type: object
+ csi:
+ description: CSI (Container Storage Interface)
+ represents ephemeral storage that is handled
+ by certain external CSI drivers (Beta feature).
+ properties:
+ driver:
+ description: Driver is the name of the CSI
+ driver that handles this volume. Consult
+ with your admin for the correct name as
+ registered in the cluster.
+ type: string
+ fsType:
+ description: Filesystem type to mount. Ex.
+ "ext4", "xfs", "ntfs". If not provided,
+ the empty value is passed to the associated
+ CSI driver which will determine the default
+ filesystem to apply.
+ type: string
+ nodePublishSecretRef:
+ description: NodePublishSecretRef is a reference
+ to the secret object containing sensitive
+ information to pass to the CSI driver to
+ complete the CSI NodePublishVolume and NodeUnpublishVolume
+ calls. This field is optional, and may
+ be empty if no secret is required. If the
+ secret object contains more than one secret,
+ all secret references are passed.
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ readOnly:
+ description: Specifies a read-only configuration
+ for the volume. Defaults to false (read/write).
+ type: boolean
+ volumeAttributes:
+ additionalProperties:
+ type: string
+ description: VolumeAttributes stores driver-specific
+ properties that are passed to the CSI driver.
+ Consult your driver's documentation for
+ supported values.
+ type: object
+ required:
+ - driver
+ type: object
+ downwardAPI:
+ description: DownwardAPI represents downward API
+ about the pod that should populate this volume
+ properties:
+ defaultMode:
+ description: 'Optional: mode bits used to set
+ permissions on created files by default.
+ Must be an octal value
+ between 0000 and 0777 or a decimal value
+ between 0 and 511. YAML accepts both octal
+ and decimal values, JSON requires decimal
+ values for mode bits. Defaults to 0644.
+ Directories within the path are not affected
+ by this setting. This might be in conflict
+ with other options that affect the file
+ mode, like fsGroup, and the result can be
+ other mode bits set.'
+ format: int32
+ type: integer
+ items:
+ description: Items is a list of downward API
+ volume file
+ items:
+ description: DownwardAPIVolumeFile represents
+ information to create the file containing
+ the pod field
+ properties:
+ fieldRef:
+ description: 'Required: Selects a field
+ of the pod: only annotations, labels,
+ name and namespace are supported.'
+ properties:
+ apiVersion:
+ description: Version of the schema
+ the FieldPath is written in terms
+ of, defaults to "v1".
+ type: string
+ fieldPath:
+ description: Path of the field to
+ select in the specified API version.
+ type: string
+ required:
+ - fieldPath
+ type: object
+ mode:
+ description: 'Optional: mode bits used
+ to set permissions on this file, must
+ be an octal value between 0000 and
+ 0777 or a decimal value between 0
+ and 511. YAML accepts both octal and
+ decimal values, JSON requires decimal
+ values for mode bits. If not specified,
+ the volume defaultMode will be used.
+ This might be in conflict with other
+ options that affect the file mode,
+ like fsGroup, and the result can be
+ other mode bits set.'
+ format: int32
+ type: integer
+ path:
+ description: 'Required: Path is the
+ relative path name of the file to
+ be created. Must not be absolute or
+ contain the ''..'' path. Must be utf-8
+ encoded. The first item of the relative
+ path must not start with ''..'''
+ type: string
+ resourceFieldRef:
+ description: 'Selects a resource of
+ the container: only resources limits
+ and requests (limits.cpu, limits.memory,
+ requests.cpu and requests.memory)
+ are currently supported.'
+ properties:
+ containerName:
+ description: 'Container name: required
+ for volumes, optional for env
+ vars'
+ type: string
+ divisor:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Specifies the output
+ format of the exposed resources,
+ defaults to "1"
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ resource:
+ description: 'Required: resource
+ to select'
+ type: string
+ required:
+ - resource
+ type: object
+ required:
+ - path
+ type: object
+ type: array
+ type: object
+ emptyDir:
+ description: 'EmptyDir represents a temporary
+ directory that shares a pod''s lifetime. More
+ info: https://kubernetes.io/docs/concepts/storage/volumes#emptydir'
+ properties:
+ medium:
+ description: 'What type of storage medium
+ should back this directory. The default
+ is "" which means to use the node''s default
+ medium. Must be an empty string (default)
+ or Memory. More info: https://kubernetes.io/docs/concepts/storage/volumes#emptydir'
+ type: string
+ sizeLimit:
+ anyOf:
+ - type: integer
+ - type: string
+ description: 'Total amount of local storage
+ required for this EmptyDir volume. The size
+ limit is also applicable for memory medium.
+ The maximum usage on memory medium EmptyDir
+ would be the minimum value between the SizeLimit
+ specified here and the sum of memory limits
+ of all containers in a pod. The default
+ is nil which means that the limit is undefined.
+ More info: http://kubernetes.io/docs/user-guide/volumes#emptydir'
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ type: object
+ ephemeral:
+ description: "Ephemeral represents a volume that
+ is handled by a cluster storage driver. The
+ volume's lifecycle is tied to the pod that defines
+ it - it will be created before the pod starts,
+ and deleted when the pod is removed. \n Use
+ this if: a) the volume is only needed while
+ the pod runs, b) features of normal volumes
+ like restoring from snapshot or capacity tracking
+ are needed, c) the storage driver is specified
+ through a storage class, and d) the storage
+ driver supports dynamic volume provisioning
+ through a PersistentVolumeClaim (see EphemeralVolumeSource
+ for more information on the connection between
+ this volume type and PersistentVolumeClaim).
+ \n Use PersistentVolumeClaim or one of the vendor-specific
+ APIs for volumes that persist for longer than
+ the lifecycle of an individual pod. \n Use CSI
+ for light-weight local ephemeral volumes if
+ the CSI driver is meant to be used that way
+ - see the documentation of the driver for more
+ information. \n A pod can use both types of
+ ephemeral volumes and persistent volumes at
+ the same time. \n This is a beta feature and
+ only available when the GenericEphemeralVolume
+ feature gate is enabled."
+ properties:
+ volumeClaimTemplate:
+ description: "Will be used to create a stand-alone
+ PVC to provision the volume. The pod in
+ which this EphemeralVolumeSource is embedded
+ will be the owner of the PVC, i.e. the PVC
+ will be deleted together with the pod. The
+ name of the PVC will be `<pod name>-<volume name>` where
+ `<volume name>` is the name from the `PodSpec.Volumes` array entry.
+ Pod validation will reject the pod if the
+ concatenated name is not valid for a PVC
+ (for example, too long). \n An existing
+ PVC with that name that is not owned by
+ the pod will *not* be used for the pod to
+ avoid using an unrelated volume by mistake.
+ Starting the pod is then blocked until the
+ unrelated PVC is removed. If such a pre-created
+ PVC is meant to be used by the pod, the
+ PVC has to be updated with an owner reference
+ to the pod once the pod exists. Normally
+ this should not be necessary, but it may
+ be useful when manually reconstructing a
+ broken cluster. \n This field is read-only
+ and no changes will be made by Kubernetes
+ to the PVC after it has been created. \n
+ Required, must not be nil."
+ properties:
+ metadata:
+ description: May contain labels and annotations
+ that will be copied into the PVC when
+ creating it. No other fields are allowed
+ and will be rejected during validation.
+ type: object
+ spec:
+ description: The specification for the
+ PersistentVolumeClaim. The entire content
+ is copied unchanged into the PVC that
+ gets created from this template. The
+ same fields as in a PersistentVolumeClaim
+ are also valid here.
+ properties:
+ accessModes:
+ description: 'AccessModes contains
+ the desired access modes the volume
+ should have. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#access-modes-1'
+ items:
+ type: string
+ type: array
+ dataSource:
+ description: 'This field can be used
+ to specify either: * An existing
+ VolumeSnapshot object (snapshot.storage.k8s.io/VolumeSnapshot)
+ * An existing PVC (PersistentVolumeClaim)
+ If the provisioner or an external
+ controller can support the specified
+ data source, it will create a new
+ volume based on the contents of
+ the specified data source. If the
+ AnyVolumeDataSource feature gate
+ is enabled, this field will always
+ have the same contents as the DataSourceRef
+ field.'
+ properties:
+ apiGroup:
+ description: APIGroup is the group
+ for the resource being referenced.
+ If APIGroup is not specified,
+ the specified Kind must be in
+ the core API group. For any
+ other third-party types, APIGroup
+ is required.
+ type: string
+ kind:
+ description: Kind is the type
+ of resource being referenced
+ type: string
+ name:
+ description: Name is the name
+ of resource being referenced
+ type: string
+ required:
+ - kind
+ - name
+ type: object
+ dataSourceRef:
+ description: 'Specifies the object
+ from which to populate the volume
+ with data, if a non-empty volume
+ is desired. This may be any local
+ object from a non-empty API group
+ (non core object) or a PersistentVolumeClaim
+ object. When this field is specified,
+ volume binding will only succeed
+ if the type of the specified object
+ matches some installed volume populator
+ or dynamic provisioner. This field
+ will replace the functionality of
+ the DataSource field and as such
+ if both fields are non-empty, they
+ must have the same value. For backwards
+ compatibility, both fields (DataSource
+ and DataSourceRef) will be set to
+ the same value automatically if
+ one of them is empty and the other
+ is non-empty. There are two important
+ differences between DataSource and
+ DataSourceRef: * While DataSource
+ only allows two specific types of
+ objects, DataSourceRef allows
+ any non-core object, as well as
+ PersistentVolumeClaim objects. *
+ While DataSource ignores disallowed
+ values (dropping them), DataSourceRef preserves
+ all values, and generates an error
+ if a disallowed value is specified.
+ (Alpha) Using this field requires
+ the AnyVolumeDataSource feature
+ gate to be enabled.'
+ properties:
+ apiGroup:
+ description: APIGroup is the group
+ for the resource being referenced.
+ If APIGroup is not specified,
+ the specified Kind must be in
+ the core API group. For any
+ other third-party types, APIGroup
+ is required.
+ type: string
+ kind:
+ description: Kind is the type
+ of resource being referenced
+ type: string
+ name:
+ description: Name is the name
+ of resource being referenced
+ type: string
+ required:
+ - kind
+ - name
+ type: object
+ resources:
+ description: 'Resources represents
+ the minimum resources the volume
+ should have. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#resources'
+ properties:
+ limits:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Limits describes
+ the maximum amount of compute
+ resources allowed. More info:
+ https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ requests:
+ additionalProperties:
+ anyOf:
+ - type: integer
+ - type: string
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ description: 'Requests describes
+ the minimum amount of compute
+ resources required. If Requests
+ is omitted for a container,
+ it defaults to Limits if that
+ is explicitly specified, otherwise
+ to an implementation-defined
+ value. More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/'
+ type: object
+ type: object
+ selector:
+ description: A label query over volumes
+ to consider for binding.
+ properties:
+ matchExpressions:
+ description: matchExpressions
+ is a list of label selector
+ requirements. The requirements
+ are ANDed.
+ items:
+ description: A label selector
+ requirement is a selector
+ that contains values, a key,
+ and an operator that relates
+ the key and values.
+ properties:
+ key:
+ description: key is the
+ label key that the selector
+ applies to.
+ type: string
+ operator:
+ description: operator represents
+ a key's relationship to
+ a set of values. Valid
+ operators are In, NotIn,
+ Exists and DoesNotExist.
+ type: string
+ values:
+ description: values is an
+ array of string values.
+ If the operator is In
+ or NotIn, the values array
+ must be non-empty. If
+ the operator is Exists
+ or DoesNotExist, the values
+ array must be empty. This
+ array is replaced during
+ a strategic merge patch.
+ items:
+ type: string
+ type: array
+ required:
+ - key
+ - operator
+ type: object
+ type: array
+ matchLabels:
+ additionalProperties:
+ type: string
+ description: matchLabels is a
+ map of {key,value} pairs. A
+ single {key,value} in the matchLabels
+ map is equivalent to an element
+ of matchExpressions, whose key
+ field is "key", the operator
+ is "In", and the values array
+ contains only "value". The requirements
+ are ANDed.
+ type: object
+ type: object
+ storageClassName:
+ description: 'Name of the StorageClass
+ required by the claim. More info:
+ https://kubernetes.io/docs/concepts/storage/persistent-volumes#class-1'
+ type: string
+ volumeMode:
+ description: volumeMode defines what
+ type of volume is required by the
+ claim. Value of Filesystem is implied
+ when not included in claim spec.
+ type: string
+ volumeName:
+ description: VolumeName is the binding
+ reference to the PersistentVolume
+ backing this claim.
+ type: string
+ type: object
+ required:
+ - spec
+ type: object
+ type: object
+ fc:
+ description: FC represents a Fibre Channel resource
+ that is attached to a kubelet's host machine
+ and then exposed to the pod.
+ properties:
+ fsType:
+ description: 'Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Ex. "ext4", "xfs", "ntfs".
+ Implicitly inferred to be "ext4" if unspecified.
+ TODO: how do we prevent errors in the filesystem
+ from compromising the machine'
+ type: string
+ lun:
+ description: 'Optional: FC target lun number'
+ format: int32
+ type: integer
+ readOnly:
+ description: 'Optional: Defaults to false
+ (read/write). ReadOnly here will force the
+ ReadOnly setting in VolumeMounts.'
+ type: boolean
+ targetWWNs:
+ description: 'Optional: FC target worldwide
+ names (WWNs)'
+ items:
+ type: string
+ type: array
+ wwids:
+ description: 'Optional: FC volume world wide
+ identifiers (wwids) Either wwids or combination
+ of targetWWNs and lun must be set, but not
+ both simultaneously.'
+ items:
+ type: string
+ type: array
+ type: object
+ flexVolume:
+ description: FlexVolume represents a generic volume
+ resource that is provisioned/attached using
+ an exec based plugin.
+ properties:
+ driver:
+ description: Driver is the name of the driver
+ to use for this volume.
+ type: string
+ fsType:
+ description: Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Ex. "ext4", "xfs", "ntfs".
+ The default filesystem depends on FlexVolume
+ script.
+ type: string
+ options:
+ additionalProperties:
+ type: string
+ description: 'Optional: Extra command options
+ if any.'
+ type: object
+ readOnly:
+ description: 'Optional: Defaults to false
+ (read/write). ReadOnly here will force the
+ ReadOnly setting in VolumeMounts.'
+ type: boolean
+ secretRef:
+ description: 'Optional: SecretRef is reference
+ to the secret object containing sensitive
+ information to pass to the plugin scripts.
+ This may be empty if no secret object is
+ specified. If the secret object contains
+ more than one secret, all secrets are passed
+ to the plugin scripts.'
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ required:
+ - driver
+ type: object
+ flocker:
+ description: Flocker represents a Flocker volume
+ attached to a kubelet's host machine. This depends
+ on the Flocker control service being running
+ properties:
+ datasetName:
+ description: Name of the dataset stored as
+ metadata -> name on the dataset for Flocker
+ should be considered as deprecated
+ type: string
+ datasetUUID:
+ description: UUID of the dataset. This is
+ unique identifier of a Flocker dataset
+ type: string
+ type: object
+ gcePersistentDisk:
+ description: 'GCEPersistentDisk represents a GCE
+ Disk resource that is attached to a kubelet''s
+ host machine and then exposed to the pod. More
+ info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk'
+ properties:
+ fsType:
+ description: 'Filesystem type of the volume
+ that you want to mount. Tip: Ensure that
+ the filesystem type is supported by the
+ host operating system. Examples: "ext4",
+ "xfs", "ntfs". Implicitly inferred to be
+ "ext4" if unspecified. More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk
+ TODO: how do we prevent errors in the filesystem
+ from compromising the machine'
+ type: string
+ partition:
+ description: 'The partition in the volume
+ that you want to mount. If omitted, the
+ default is to mount by volume name. Examples:
+ For volume /dev/sda1, you specify the partition
+ as "1". Similarly, the volume partition
+ for /dev/sda is "0" (or you can leave the
+ property empty). More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk'
+ format: int32
+ type: integer
+ pdName:
+ description: 'Unique name of the PD resource
+ in GCE. Used to identify the disk in GCE.
+ More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk'
+ type: string
+ readOnly:
+ description: 'ReadOnly here will force the
+ ReadOnly setting in VolumeMounts. Defaults
+ to false. More info: https://kubernetes.io/docs/concepts/storage/volumes#gcepersistentdisk'
+ type: boolean
+ required:
+ - pdName
+ type: object
+ gitRepo:
+ description: 'GitRepo represents a git repository
+ at a particular revision. DEPRECATED: GitRepo
+ is deprecated. To provision a container with
+ a git repo, mount an EmptyDir into an InitContainer
+ that clones the repo using git, then mount the
+ EmptyDir into the Pod''s container.'
+ properties:
+ directory:
+ description: Target directory name. Must not
+ contain or start with '..'. If '.' is supplied,
+ the volume directory will be the git repository. Otherwise,
+ if specified, the volume will contain the
+ git repository in the subdirectory with
+ the given name.
+ type: string
+ repository:
+ description: Repository URL
+ type: string
+ revision:
+ description: Commit hash for the specified
+ revision.
+ type: string
+ required:
+ - repository
+ type: object
+ glusterfs:
+ description: 'Glusterfs represents a Glusterfs
+ mount on the host that shares a pod''s lifetime.
+ More info: https://examples.k8s.io/volumes/glusterfs/README.md'
+ properties:
+ endpoints:
+ description: 'EndpointsName is the endpoint
+ name that details Glusterfs topology. More
+ info: https://examples.k8s.io/volumes/glusterfs/README.md#create-a-pod'
+ type: string
+ path:
+ description: 'Path is the Glusterfs volume
+ path. More info: https://examples.k8s.io/volumes/glusterfs/README.md#create-a-pod'
+ type: string
+ readOnly:
+ description: 'ReadOnly here will force the
+ Glusterfs volume to be mounted with read-only
+ permissions. Defaults to false. More info:
+ https://examples.k8s.io/volumes/glusterfs/README.md#create-a-pod'
+ type: boolean
+ required:
+ - endpoints
+ - path
+ type: object
+ hostPath:
+ description: 'HostPath represents a pre-existing
+ file or directory on the host machine that is
+ directly exposed to the container. This is generally
+ used for system agents or other privileged things
+ that are allowed to see the host machine. Most
+ containers will NOT need this. More info: https://kubernetes.io/docs/concepts/storage/volumes#hostpath
+ --- TODO(jonesdl) We need to restrict who can
+ use host directory mounts and who can/can not
+ mount host directories as read/write.'
+ properties:
+ path:
+ description: 'Path of the directory on the
+ host. If the path is a symlink, it will
+ follow the link to the real path. More info:
+ https://kubernetes.io/docs/concepts/storage/volumes#hostpath'
+ type: string
+ type:
+ description: 'Type for HostPath Volume Defaults
+ to "" More info: https://kubernetes.io/docs/concepts/storage/volumes#hostpath'
+ type: string
+ required:
+ - path
+ type: object
+ iscsi:
+ description: 'ISCSI represents an ISCSI Disk resource
+ that is attached to a kubelet''s host machine
+ and then exposed to the pod. More info: https://examples.k8s.io/volumes/iscsi/README.md'
+ properties:
+ chapAuthDiscovery:
+ description: whether support iSCSI Discovery
+ CHAP authentication
+ type: boolean
+ chapAuthSession:
+ description: whether support iSCSI Session
+ CHAP authentication
+ type: boolean
+ fsType:
+ description: 'Filesystem type of the volume
+ that you want to mount. Tip: Ensure that
+ the filesystem type is supported by the
+ host operating system. Examples: "ext4",
+ "xfs", "ntfs". Implicitly inferred to be
+ "ext4" if unspecified. More info: https://kubernetes.io/docs/concepts/storage/volumes#iscsi
+ TODO: how do we prevent errors in the filesystem
+ from compromising the machine'
+ type: string
+ initiatorName:
+ description: Custom iSCSI Initiator Name.
+ If initiatorName is specified with iscsiInterface
+ simultaneously, new iSCSI interface <target portal>:<volume name>
+ will be created for the connection.
+ type: string
+ iqn:
+ description: Target iSCSI Qualified Name.
+ type: string
+ iscsiInterface:
+ description: iSCSI Interface Name that uses
+ an iSCSI transport. Defaults to 'default'
+ (tcp).
+ type: string
+ lun:
+ description: iSCSI Target Lun number.
+ format: int32
+ type: integer
+ portals:
+ description: iSCSI Target Portal List. The
+ portal is either an IP or ip_addr:port if
+ the port is other than default (typically
+ TCP ports 860 and 3260).
+ items:
+ type: string
+ type: array
+ readOnly:
+ description: ReadOnly here will force the
+ ReadOnly setting in VolumeMounts. Defaults
+ to false.
+ type: boolean
+ secretRef:
+ description: CHAP Secret for iSCSI target
+ and initiator authentication
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ targetPortal:
+ description: iSCSI Target Portal. The Portal
+ is either an IP or ip_addr:port if the port
+ is other than default (typically TCP ports
+ 860 and 3260).
+ type: string
+ required:
+ - iqn
+ - lun
+ - targetPortal
+ type: object
+ name:
+ description: 'Volume''s name. Must be a DNS_LABEL
+ and unique within the pod. More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names'
+ type: string
+ nfs:
+ description: 'NFS represents an NFS mount on the
+ host that shares a pod''s lifetime More info:
+ https://kubernetes.io/docs/concepts/storage/volumes#nfs'
+ properties:
+ path:
+ description: 'Path that is exported by the
+ NFS server. More info: https://kubernetes.io/docs/concepts/storage/volumes#nfs'
+ type: string
+ readOnly:
+ description: 'ReadOnly here will force the
+ NFS export to be mounted with read-only
+ permissions. Defaults to false. More info:
+ https://kubernetes.io/docs/concepts/storage/volumes#nfs'
+ type: boolean
+ server:
+ description: 'Server is the hostname or IP
+ address of the NFS server. More info: https://kubernetes.io/docs/concepts/storage/volumes#nfs'
+ type: string
+ required:
+ - path
+ - server
+ type: object
+ persistentVolumeClaim:
+ description: 'PersistentVolumeClaimVolumeSource
+ represents a reference to a PersistentVolumeClaim
+ in the same namespace. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#persistentvolumeclaims'
+ properties:
+ claimName:
+ description: 'ClaimName is the name of a PersistentVolumeClaim
+ in the same namespace as the pod using this
+ volume. More info: https://kubernetes.io/docs/concepts/storage/persistent-volumes#persistentvolumeclaims'
+ type: string
+ readOnly:
+ description: Will force the ReadOnly setting
+ in VolumeMounts. Default false.
+ type: boolean
+ required:
+ - claimName
+ type: object
+ photonPersistentDisk:
+ description: PhotonPersistentDisk represents a
+ PhotonController persistent disk attached and
+ mounted on kubelets host machine
+ properties:
+ fsType:
+ description: Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Ex. "ext4", "xfs", "ntfs".
+ Implicitly inferred to be "ext4" if unspecified.
+ type: string
+ pdID:
+ description: ID that identifies Photon Controller
+ persistent disk
+ type: string
+ required:
+ - pdID
+ type: object
+ portworxVolume:
+ description: PortworxVolume represents a portworx
+ volume attached and mounted on kubelets host
+ machine
+ properties:
+ fsType:
+ description: FSType represents the filesystem
+ type to mount Must be a filesystem type
+ supported by the host operating system.
+ Ex. "ext4", "xfs". Implicitly inferred to
+ be "ext4" if unspecified.
+ type: string
+ readOnly:
+ description: Defaults to false (read/write).
+ ReadOnly here will force the ReadOnly setting
+ in VolumeMounts.
+ type: boolean
+ volumeID:
+ description: VolumeID uniquely identifies
+ a Portworx volume
+ type: string
+ required:
+ - volumeID
+ type: object
+ projected:
+ description: Items for all in one resources secrets,
+ configmaps, and downward API
+ properties:
+ defaultMode:
+ description: Mode bits used to set permissions
+ on created files by default. Must be an
+ octal value between 0000 and 0777 or a decimal
+ value between 0 and 511. YAML accepts both
+ octal and decimal values, JSON requires
+ decimal values for mode bits. Directories
+ within the path are not affected by this
+ setting. This might be in conflict with
+ other options that affect the file mode,
+ like fsGroup, and the result can be other
+ mode bits set.
+ format: int32
+ type: integer
+ sources:
+ description: list of volume projections
+ items:
+ description: Projection that may be projected
+ along with other supported volume types
+ properties:
+ configMap:
+ description: information about the configMap
+ data to project
+ properties:
+ items:
+ description: If unspecified, each
+ key-value pair in the Data field
+ of the referenced ConfigMap will
+ be projected into the volume as
+ a file whose name is the key and
+ content is the value. If specified,
+ the listed keys will be projected
+ into the specified paths, and
+ unlisted keys will not be present.
+ If a key is specified which is
+ not present in the ConfigMap,
+ the volume setup will error unless
+ it is marked optional. Paths must
+ be relative and may not contain
+ the '..' path or start with '..'.
+ items:
+ description: Maps a string key
+ to a path within a volume.
+ properties:
+ key:
+ description: The key to project.
+ type: string
+ mode:
+ description: 'Optional: mode
+ bits used to set permissions
+ on this file. Must be an
+ octal value between 0000
+ and 0777 or a decimal value
+ between 0 and 511. YAML
+ accepts both octal and decimal
+ values, JSON requires decimal
+ values for mode bits. If
+ not specified, the volume
+ defaultMode will be used.
+ This might be in conflict
+ with other options that
+ affect the file mode, like
+ fsGroup, and the result
+ can be other mode bits set.'
+ format: int32
+ type: integer
+ path:
+ description: The relative
+ path of the file to map
+ the key to. May not be an
+ absolute path. May not contain
+ the path element '..'. May
+ not start with the string
+ '..'.
+ type: string
+ required:
+ - key
+ - path
+ type: object
+ type: array
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ ConfigMap or its keys must be
+ defined
+ type: boolean
+ type: object
+ downwardAPI:
+ description: information about the downwardAPI
+ data to project
+ properties:
+ items:
+ description: Items is a list of
+ DownwardAPIVolume file
+ items:
+ description: DownwardAPIVolumeFile
+ represents information to create
+ the file containing the pod
+ field
+ properties:
+ fieldRef:
+ description: 'Required: Selects
+ a field of the pod: only
+ annotations, labels, name
+ and namespace are supported.'
+ properties:
+ apiVersion:
+ description: Version of
+ the schema the FieldPath
+ is written in terms
+ of, defaults to "v1".
+ type: string
+ fieldPath:
+ description: Path of the
+ field to select in the
+ specified API version.
+ type: string
+ required:
+ - fieldPath
+ type: object
+ mode:
+ description: 'Optional: mode
+ bits used to set permissions
+ on this file, must be an
+ octal value between 0000
+ and 0777 or a decimal value
+ between 0 and 511. YAML
+ accepts both octal and decimal
+ values, JSON requires decimal
+ values for mode bits. If
+ not specified, the volume
+ defaultMode will be used.
+ This might be in conflict
+ with other options that
+ affect the file mode, like
+ fsGroup, and the result
+ can be other mode bits set.'
+ format: int32
+ type: integer
+ path:
+ description: 'Required: Path
+ is the relative path name
+ of the file to be created.
+ Must not be absolute or
+ contain the ''..'' path.
+ Must be utf-8 encoded. The
+ first item of the relative
+ path must not start with
+ ''..'''
+ type: string
+ resourceFieldRef:
+ description: 'Selects a resource
+ of the container: only resources
+ limits and requests (limits.cpu,
+ limits.memory, requests.cpu
+ and requests.memory) are
+ currently supported.'
+ properties:
+ containerName:
+ description: 'Container
+ name: required for volumes,
+ optional for env vars'
+ type: string
+ divisor:
+ anyOf:
+ - type: integer
+ - type: string
+ description: Specifies
+ the output format of
+ the exposed resources,
+ defaults to "1"
+ pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$
+ x-kubernetes-int-or-string: true
+ resource:
+ description: 'Required:
+ resource to select'
+ type: string
+ required:
+ - resource
+ type: object
+ required:
+ - path
+ type: object
+ type: array
+ type: object
+ secret:
+ description: information about the secret
+ data to project
+ properties:
+ items:
+ description: If unspecified, each
+ key-value pair in the Data field
+ of the referenced Secret will
+ be projected into the volume as
+ a file whose name is the key and
+ content is the value. If specified,
+ the listed keys will be projected
+ into the specified paths, and
+ unlisted keys will not be present.
+ If a key is specified which is
+ not present in the Secret, the
+ volume setup will error unless
+ it is marked optional. Paths must
+ be relative and may not contain
+ the '..' path or start with '..'.
+ items:
+ description: Maps a string key
+ to a path within a volume.
+ properties:
+ key:
+ description: The key to project.
+ type: string
+ mode:
+ description: 'Optional: mode
+ bits used to set permissions
+ on this file. Must be an
+ octal value between 0000
+ and 0777 or a decimal value
+ between 0 and 511. YAML
+ accepts both octal and decimal
+ values, JSON requires decimal
+ values for mode bits. If
+ not specified, the volume
+ defaultMode will be used.
+ This might be in conflict
+ with other options that
+ affect the file mode, like
+ fsGroup, and the result
+ can be other mode bits set.'
+ format: int32
+ type: integer
+ path:
+ description: The relative
+ path of the file to map
+ the key to. May not be an
+ absolute path. May not contain
+ the path element '..'. May
+ not start with the string
+ '..'.
+ type: string
+ required:
+ - key
+ - path
+ type: object
+ type: array
+ name:
+ description: 'Name of the referent.
+ More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields.
+ apiVersion, kind, uid?'
+ type: string
+ optional:
+ description: Specify whether the
+ Secret or its key must be defined
+ type: boolean
+ type: object
+ serviceAccountToken:
+ description: information about the serviceAccountToken
+ data to project
+ properties:
+ audience:
+ description: Audience is the intended
+ audience of the token. A recipient
+ of a token must identify itself
+ with an identifier specified in
+ the audience of the token, and
+ otherwise should reject the token.
+ The audience defaults to the identifier
+ of the apiserver.
+ type: string
+ expirationSeconds:
+ description: ExpirationSeconds is
+ the requested duration of validity
+ of the service account token.
+ As the token approaches expiration,
+ the kubelet volume plugin will
+ proactively rotate the service
+ account token. The kubelet will
+ start trying to rotate the token
+ if the token is older than 80
+ percent of its time to live or
+ if the token is older than 24
+ hours. Defaults to 1 hour and must
+ be at least 10 minutes.
+ format: int64
+ type: integer
+ path:
+ description: Path is the path relative
+ to the mount point of the file
+ to project the token into.
+ type: string
+ required:
+ - path
+ type: object
+ type: object
+ type: array
+ type: object
+ quobyte:
+ description: Quobyte represents a Quobyte mount
+ on the host that shares a pod's lifetime
+ properties:
+ group:
+ description: Group to map volume access to
+ Default is no group
+ type: string
+ readOnly:
+ description: ReadOnly here will force the
+ Quobyte volume to be mounted with read-only
+ permissions. Defaults to false.
+ type: boolean
+ registry:
+ description: Registry represents a single
+ or multiple Quobyte Registry services specified
+ as a string as host:port pair (multiple
+ entries are separated with commas) which
+ acts as the central registry for volumes
+ type: string
+ tenant:
+ description: Tenant owning the given Quobyte
+ volume in the Backend Used with dynamically
+ provisioned Quobyte volumes, value is set
+ by the plugin
+ type: string
+ user:
+ description: User to map volume access to
+ Defaults to serviceaccount user
+ type: string
+ volume:
+ description: Volume is a string that references
+ an already created Quobyte volume by name.
+ type: string
+ required:
+ - registry
+ - volume
+ type: object
+ rbd:
+ description: 'RBD represents a Rados Block Device
+ mount on the host that shares a pod''s lifetime.
+ More info: https://examples.k8s.io/volumes/rbd/README.md'
+ properties:
+ fsType:
+ description: 'Filesystem type of the volume
+ that you want to mount. Tip: Ensure that
+ the filesystem type is supported by the
+ host operating system. Examples: "ext4",
+ "xfs", "ntfs". Implicitly inferred to be
+ "ext4" if unspecified. More info: https://kubernetes.io/docs/concepts/storage/volumes#rbd
+ TODO: how do we prevent errors in the filesystem
+ from compromising the machine'
+ type: string
+ image:
+ description: 'The rados image name. More info:
+ https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it'
+ type: string
+ keyring:
+ description: 'Keyring is the path to key ring
+ for RBDUser. Default is /etc/ceph/keyring.
+ More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it'
+ type: string
+ monitors:
+ description: 'A collection of Ceph monitors.
+ More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it'
+ items:
+ type: string
+ type: array
+ pool:
+ description: 'The rados pool name. Default
+ is rbd. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it'
+ type: string
+ readOnly:
+ description: 'ReadOnly here will force the
+ ReadOnly setting in VolumeMounts. Defaults
+ to false. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it'
+ type: boolean
+ secretRef:
+ description: 'SecretRef is name of the authentication
+ secret for RBDUser. If provided overrides
+ keyring. Default is nil. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it'
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ user:
+ description: 'The rados user name. Default
+ is admin. More info: https://examples.k8s.io/volumes/rbd/README.md#how-to-use-it'
+ type: string
+ required:
+ - image
+ - monitors
+ type: object
+ scaleIO:
+ description: ScaleIO represents a ScaleIO persistent
+ volume attached and mounted on Kubernetes nodes.
+ properties:
+ fsType:
+ description: Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Ex. "ext4", "xfs", "ntfs".
+ Default is "xfs".
+ type: string
+ gateway:
+ description: The host address of the ScaleIO
+ API Gateway.
+ type: string
+ protectionDomain:
+ description: The name of the ScaleIO Protection
+ Domain for the configured storage.
+ type: string
+ readOnly:
+ description: Defaults to false (read/write).
+ ReadOnly here will force the ReadOnly setting
+ in VolumeMounts.
+ type: boolean
+ secretRef:
+ description: SecretRef references to the secret
+ for ScaleIO user and other sensitive information.
+ If this is not provided, Login operation
+ will fail.
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ sslEnabled:
+ description: Flag to enable/disable SSL communication
+ with Gateway, default false
+ type: boolean
+ storageMode:
+ description: Indicates whether the storage
+ for a volume should be ThickProvisioned
+ or ThinProvisioned. Default is ThinProvisioned.
+ type: string
+ storagePool:
+ description: The ScaleIO Storage Pool associated
+ with the protection domain.
+ type: string
+ system:
+ description: The name of the storage system
+ as configured in ScaleIO.
+ type: string
+ volumeName:
+ description: The name of a volume already
+ created in the ScaleIO system that is associated
+ with this volume source.
+ type: string
+ required:
+ - gateway
+ - secretRef
+ - system
+ type: object
+ secret:
+ description: 'Secret represents a secret that
+ should populate this volume. More info: https://kubernetes.io/docs/concepts/storage/volumes#secret'
+ properties:
+ defaultMode:
+ description: 'Optional: mode bits used to
+ set permissions on created files by default.
+ Must be an octal value between 0000 and
+ 0777 or a decimal value between 0 and 511.
+ YAML accepts both octal and decimal values,
+ JSON requires decimal values for mode bits.
+ Defaults to 0644. Directories within the
+ path are not affected by this setting. This
+ might be in conflict with other options
+ that affect the file mode, like fsGroup,
+ and the result can be other mode bits set.'
+ format: int32
+ type: integer
+ items:
+ description: If unspecified, each key-value
+ pair in the Data field of the referenced
+ Secret will be projected into the volume
+ as a file whose name is the key and content
+ is the value. If specified, the listed keys
+ will be projected into the specified paths,
+ and unlisted keys will not be present. If
+ a key is specified which is not present
+ in the Secret, the volume setup will error
+ unless it is marked optional. Paths must
+ be relative and may not contain the '..'
+ path or start with '..'.
+ items:
+ description: Maps a string key to a path
+ within a volume.
+ properties:
+ key:
+ description: The key to project.
+ type: string
+ mode:
+ description: 'Optional: mode bits used
+ to set permissions on this file. Must
+ be an octal value between 0000 and
+ 0777 or a decimal value between 0
+ and 511. YAML accepts both octal and
+ decimal values, JSON requires decimal
+ values for mode bits. If not specified,
+ the volume defaultMode will be used.
+ This might be in conflict with other
+ options that affect the file mode,
+ like fsGroup, and the result can be
+ other mode bits set.'
+ format: int32
+ type: integer
+ path:
+ description: The relative path of the
+ file to map the key to. May not be
+ an absolute path. May not contain
+ the path element '..'. May not start
+ with the string '..'.
+ type: string
+ required:
+ - key
+ - path
+ type: object
+ type: array
+ optional:
+ description: Specify whether the Secret or
+ its keys must be defined
+ type: boolean
+ secretName:
+ description: 'Name of the secret in the pod''s
+ namespace to use. More info: https://kubernetes.io/docs/concepts/storage/volumes#secret'
+ type: string
+ type: object
+ storageos:
+ description: StorageOS represents a StorageOS
+ volume attached and mounted on Kubernetes nodes.
+ properties:
+ fsType:
+ description: Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Ex. "ext4", "xfs", "ntfs".
+ Implicitly inferred to be "ext4" if unspecified.
+ type: string
+ readOnly:
+ description: Defaults to false (read/write).
+ ReadOnly here will force the ReadOnly setting
+ in VolumeMounts.
+ type: boolean
+ secretRef:
+ description: SecretRef specifies the secret
+ to use for obtaining the StorageOS API credentials. If
+ not specified, default values will be attempted.
+ properties:
+ name:
+ description: 'Name of the referent. More
+ info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names
+ TODO: Add other useful fields. apiVersion,
+ kind, uid?'
+ type: string
+ type: object
+ volumeName:
+ description: VolumeName is the human-readable
+ name of the StorageOS volume. Volume names
+ are only unique within a namespace.
+ type: string
+ volumeNamespace:
+ description: VolumeNamespace specifies the
+ scope of the volume within StorageOS. If
+ no namespace is specified then the Pod's
+ namespace will be used. This allows the
+ Kubernetes name scoping to be mirrored within
+ StorageOS for tighter integration. Set VolumeName
+ to any name to override the default behaviour.
+ Set to "default" if you are not using namespaces
+ within StorageOS. Namespaces that do not
+ pre-exist within StorageOS will be created.
+ type: string
+ type: object
+ vsphereVolume:
+ description: VsphereVolume represents a vSphere
+ volume attached and mounted on kubelets host
+ machine
+ properties:
+ fsType:
+ description: Filesystem type to mount. Must
+ be a filesystem type supported by the host
+ operating system. Ex. "ext4", "xfs", "ntfs".
+ Implicitly inferred to be "ext4" if unspecified.
+ type: string
+ storagePolicyID:
+ description: Storage Policy Based Management
+ (SPBM) profile ID associated with the StoragePolicyName.
+ type: string
+ storagePolicyName:
+ description: Storage Policy Based Management
+ (SPBM) profile name.
+ type: string
+ volumePath:
+ description: Path that identifies vSphere
+ volume vmdk
+ type: string
+ required:
+ - volumePath
+ type: object
+ required:
+ - name
+ type: object
+ type: array
+ required:
+ - containers
+ type: object
+ type: object
+ type: object
+ description: 'INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
+ Important: Run "make" to regenerate code after modifying this file
+ Defines replica spec for replica type'
+ type: object
+ ttlSecondsAfterFinished:
+ default: 86400
+ description: TTLSecondsAfterFinished is the TTL to clean up jobs.
+ It may take extra ReconcilePeriod seconds for the cleanup, since
+ reconcile gets called periodically. Defaults to 86400 (one day).
+ format: int64
+ type: integer
+ required:
+ - fedReplicaSpecs
+ type: object
+ status:
+ description: FedAppStatus defines the observed state of FedApp
+ properties:
+ conditions:
+ items:
+ description: FedAppCondition describes the current state of a job.
+ properties:
+ lastTransitionTime:
+ description: Last time the condition transitioned from one status
+ to another.
+ format: date-time
+ type: string
+ message:
+ description: Human readable message indicating details about
+ last transition.
+ type: string
+ reason:
+ description: (brief) reason for the condition's last transition.
+ type: string
+ status:
+ description: Status of the condition, one of True, False, Unknown.
+ type: string
+ type:
+ description: Type of job condition.
+ type: string
+ required:
+ - status
+ - type
+ type: object
+ type: array
+ startTime:
+ format: date-time
+ type: string
+ terminatedPodsMap:
+ additionalProperties:
+ description: TerminatedPods holds the names of Pods that have terminated.
+ properties:
+ failed:
+ description: Failed holds the names of failed Pods.
+ items:
+ additionalProperties:
+ type: object
+ type: object
+ type: array
+ succeeded:
+ description: Succeeded holds the names of succeeded Pods.
+ items:
+ additionalProperties:
+ type: object
+ type: object
+ type: array
+ type: object
+ description: 'Records the names of Pods that have terminated; a workaround
+ for Pod GC that runs too fast. TODO: when Pod GC is so fast that the fedapp
+ controller does not have enough time to record Pods in the TerminatedPodsMap
+ field, use a finalizer to avoid it.'
+ type: object
+ type: object
+ type: object
+ served: true
+ storage: true
+ subresources:
+ status: {}
+status:
+ acceptedNames:
+ kind: ""
+ plural: ""
+ conditions: []
+ storedVersions: []
diff --git a/operator/deploy_charts/hl-test-manager.yaml b/operator/deploy_charts/hl-test-manager.yaml
new file mode 100644
index 000000000..f93be49a4
--- /dev/null
+++ b/operator/deploy_charts/hl-test-manager.yaml
@@ -0,0 +1,64 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: controller-manager
+ namespace: fedlearner
+ labels:
+ control-plane: controller-manager
+spec:
+ selector:
+ matchLabels:
+ control-plane: controller-manager
+ replicas: 1
+ template:
+ metadata:
+ annotations:
+ kubectl.kubernetes.io/default-container: manager
+ labels:
+ control-plane: controller-manager
+ spec:
+ securityContext:
+ runAsNonRoot: true
+ volumes:
+ - name: hl-kubeconfig
+ configMap:
+ name: hl-kubeconfig
+ defaultMode: 420
+ containers:
+ - command:
+ - /manager
+ env:
+ - name: KUBECONFIG
+ value: /.kube/config
+ args:
+ - --ingress-extra-host-suffix=.fl-bytedance.com
+ - --namespace=fedlearner
+ image: artifact.bytedance.com/fedlearner/pp_fedapp_operator:0.2.2
+ name: manager
+ securityContext:
+ allowPrivilegeEscalation: false
+ livenessProbe:
+ httpGet:
+ path: /healthz
+ port: 8081
+ initialDelaySeconds: 15
+ periodSeconds: 20
+ readinessProbe:
+ httpGet:
+ path: /readyz
+ port: 8081
+ initialDelaySeconds: 5
+ periodSeconds: 10
+ # TODO(user): Configure the resources accordingly based on the project requirements.
+ # More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/
+ resources:
+ limits:
+ cpu: 500m
+ memory: 128Mi
+ requests:
+ cpu: 10m
+ memory: 64Mi
+ volumeMounts:
+ - name: hl-kubeconfig
+ mountPath: /.kube/
+ terminationGracePeriodSeconds: 10
\ No newline at end of file
diff --git a/operator/deploy_charts/third_deployment.yaml b/operator/deploy_charts/third_deployment.yaml
new file mode 100644
index 000000000..b2bd6691a
--- /dev/null
+++ b/operator/deploy_charts/third_deployment.yaml
@@ -0,0 +1,56 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: controller-manager
+ namespace: default
+ labels:
+ control-plane: controller-manager
+spec:
+ selector:
+ matchLabels:
+ control-plane: controller-manager
+ replicas: 1
+ template:
+ metadata:
+ annotations:
+ kubectl.kubernetes.io/default-container: manager
+ labels:
+ control-plane: controller-manager
+ spec:
+ imagePullSecrets:
+ - name: regcred
+ serviceAccountName: fedlearner-operator
+ securityContext:
+ runAsNonRoot: true
+ containers:
+ - name: manager
+ command:
+ - /manager
+ args:
+ - --ingress-extra-host-suffix=.fl-tryit1.com
+ - --namespace=default
+ image: artifact.bytedance.com/fedlearner/pp_fedapp_operator:0.2.2
+ securityContext:
+ allowPrivilegeEscalation: false
+ livenessProbe:
+ httpGet:
+ path: /healthz
+ port: 8081
+ initialDelaySeconds: 15
+ periodSeconds: 20
+ readinessProbe:
+ httpGet:
+ path: /readyz
+ port: 8081
+ initialDelaySeconds: 5
+ periodSeconds: 10
+ # TODO(user): Configure the resources accordingly based on the project requirements.
+ # More info: https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/
+ resources:
+ limits:
+ cpu: 500m
+ memory: 128Mi
+ requests:
+ cpu: 10m
+ memory: 64Mi
+ terminationGracePeriodSeconds: 10
\ No newline at end of file
diff --git a/operator/hack/boilerplate.go.txt b/operator/hack/boilerplate.go.txt
new file mode 100644
index 000000000..65b862271
--- /dev/null
+++ b/operator/hack/boilerplate.go.txt
@@ -0,0 +1,15 @@
+/*
+Copyright 2023.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
\ No newline at end of file
diff --git a/operator/main.go b/operator/main.go
new file mode 100644
index 000000000..26ed6b23e
--- /dev/null
+++ b/operator/main.go
@@ -0,0 +1,106 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package main
+
+import (
+ "flag"
+ "os"
+
+ // Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.)
+ // to ensure that exec-entrypoint and run can make use of them.
+ _ "k8s.io/client-go/plugin/pkg/client/auth"
+
+ "k8s.io/apimachinery/pkg/runtime"
+ utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+ clientgoscheme "k8s.io/client-go/kubernetes/scheme"
+ ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/healthz"
+ "sigs.k8s.io/controller-runtime/pkg/log/zap"
+
+ fedlearnerv1alpha1 "fedlearner.net/operator/api/v1alpha1"
+ "fedlearner.net/operator/controllers"
+ //+kubebuilder:scaffold:imports
+)
+
+var (
+ scheme = runtime.NewScheme()
+ setupLog = ctrl.Log.WithName("setup")
+)
+
+func init() {
+ utilruntime.Must(clientgoscheme.AddToScheme(scheme))
+ utilruntime.Must(fedlearnerv1alpha1.AddToScheme(scheme))
+ //+kubebuilder:scaffold:scheme
+}
+
+func main() {
+ var metricsAddr string
+ var enableLeaderElection bool
+ var probeAddr string
+ var namespace string
+ flag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the metric endpoint binds to.")
+ flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
+ flag.StringVar(&namespace, "namespace", "default", "namespace")
+ flag.BoolVar(&enableLeaderElection, "leader-elect", false,
+ "Enable leader election for controller manager. "+
+ "Enabling this will ensure there is only one active controller manager.")
+ flag.StringVar(&controllers.IngressExtraHostSuffix, "ingress-extra-host-suffix", ".fl-aliyun-test.com", "The extra suffix of hosts when creating ingress.")
+ opts := zap.Options{
+ Development: true,
+ }
+ opts.BindFlags(flag.CommandLine)
+ flag.Parse()
+
+ ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts)))
+
+ mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{
+ Namespace: namespace,
+ Scheme: scheme,
+ MetricsBindAddress: metricsAddr,
+ Port: 9443,
+ HealthProbeBindAddress: probeAddr,
+ LeaderElection: enableLeaderElection,
+ LeaderElectionID: "bc5e4174.k8s.io",
+ })
+ if err != nil {
+ setupLog.Error(err, "unable to start manager")
+ os.Exit(1)
+ }
+
+ if err = (&controllers.FedAppReconciler{
+ Client: mgr.GetClient(),
+ Scheme: mgr.GetScheme(),
+ }).SetupWithManager(mgr); err != nil {
+ setupLog.Error(err, "unable to create controller", "controller", "FedApp")
+ os.Exit(1)
+ }
+ //+kubebuilder:scaffold:builder
+
+ if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil {
+ setupLog.Error(err, "unable to set up health check")
+ os.Exit(1)
+ }
+ if err := mgr.AddReadyzCheck("readyz", healthz.Ping); err != nil {
+ setupLog.Error(err, "unable to set up ready check")
+ os.Exit(1)
+ }
+
+ setupLog.Info("starting manager")
+ if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
+ setupLog.Error(err, "problem running manager")
+ os.Exit(1)
+ }
+}
diff --git a/pp_lite/BUILD.bazel b/pp_lite/BUILD.bazel
new file mode 100644
index 000000000..a3923df98
--- /dev/null
+++ b/pp_lite/BUILD.bazel
@@ -0,0 +1,30 @@
+# gazelle:exclude spark/
+load("@rules_python//python:defs.bzl", "py_binary", "py_library")
+
+package_group(
+ name = "pp_lite_package",
+ packages = ["//pp_lite/..."],
+)
+
+py_library(
+ name = "pp_lite",
+ srcs = [
+ "cli.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join/psi_ot",
+ "//pp_lite/data_join/psi_rsa",
+ "//pp_lite/proto:py_proto",
+ "//web_console_v2/inspection:error_code_lib",
+ "@common_click//:pkg",
+ ],
+)
+
+py_binary(
+ name = "cli_bin",
+ srcs = ["cli.py"],
+ main = "cli.py",
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":pp_lite"],
+)
diff --git a/pp_lite/README.md b/pp_lite/README.md
new file mode 100644
index 000000000..e05d43b69
--- /dev/null
+++ b/pp_lite/README.md
@@ -0,0 +1,45 @@
+# Private Set Intersection (PSI) Light Client
+
+## File Structure
+The light client consists of the following parts:
+- images/pp_lite: defines the container entrypoint shell script and the Dockerfile.
+- pp_lite: defines the concrete business logic and the client packaging method.
+- tools/tcp_grpc_proxy: provides TCP-to-gRPC conversion, mainly used for OtPsi.
+```
+images
+.
+└── pp_lite
+    ├── nginx.tmpl # nginx template or other scripts
+    ├── entrypoint.sh # entrypoint
+    └── Dockerfile
+```
+
+```
+pp_lite
+.
+├── cli.py # entrypoint
+├── requirements.txt
+├── data_join # data join (PSI) implementations
+│   ├── psi_rsa # RSA-based PSI
+│   ├── psi_ot # OT-based PSI
+│   └── utils
+├── rpc # RPC-related code
+├── test # integration tests
+└── deploy # light client packaging scripts
+```
+- cli.py is implemented with click and wraps every feature the light client provides. images/psi/scripts/entrypoint.sh passes external arguments through to cli.py.
+- test contains the integration tests; some local unit tests live next to the files they test.
+- proto files are not kept in a separate directory; they live where they are used.
+
+## Image Management
+The whole light client is packaged into a single image:
+- The Dockerfile is stored at images/pp_lite/Dockerfile
+- The entrypoint script is images/pp_lite/entrypoint.sh
+
+Different container behaviors are selected by passing different arguments, as sketched below.
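+
+A hypothetical invocation, assuming the image is built from images/pp_lite/Dockerfile and tagged pp_lite:latest (the tag is a placeholder), and that entrypoint.sh forwards its command-line arguments to cli.py as described above:
+```shell
+# Run the RSA-based PSI as a light client; everything after the image name is
+# handed through to cli.py. /path/to/data is mounted at /app, the default
+# STORAGE_ROOT used by pp_lite/data_join/envs.py.
+docker run --rm -v /path/to/data:/app pp_lite:latest psi_rsa light_client
+```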
+
+## Usage
+- The server side performs the intersection through the data module of the platform.
+- For client-side usage, see pp_lite/deploy/README.md.
+
+For concrete demos, refer to the test cases under test.
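+
+A minimal local sketch using the Bazel target defined in pp_lite/BUILD.bazel. The psi_rsa / psi_ot / psi_hash commands and the client / server / light_client roles come from cli.py; STORAGE_ROOT and GRPC_CLIENT_TIMEOUT are optional environment knobs read by pp_lite/data_join/envs.py, and the values shown here are examples only:
+```shell
+# Optional environment knobs (defaults: STORAGE_ROOT=/app, GRPC_CLIENT_TIMEOUT=30).
+export STORAGE_ROOT=/data
+export GRPC_CLIENT_TIMEOUT=60
+
+# OT-based PSI as a light client.
+bazel run //pp_lite:cli_bin -- psi_ot light_client
+
+# Hash-based data join works the same way; each command's remaining options are
+# collected by its own get_arguments() helper.
+bazel run //pp_lite:cli_bin -- psi_hash light_client
+```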
\ No newline at end of file
diff --git a/pp_lite/cli.py b/pp_lite/cli.py
new file mode 100644
index 000000000..4101db2a7
--- /dev/null
+++ b/pp_lite/cli.py
@@ -0,0 +1,91 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import click
+import logging
+from pp_lite.data_join.psi_rsa import psi_client as psi_rsa_client
+from pp_lite.data_join.psi_rsa import psi_server as psi_rsa_server
+from pp_lite.data_join.psi_ot import client as psi_ot_and_hash_client
+from pp_lite.data_join.psi_ot import server as psi_ot_and_hash_server
+from pp_lite.data_join.psi_ot import arguments as psi_ot_and_hash_arguments
+from pp_lite.proto.common_pb2 import DataJoinType
+
+from web_console_v2.inspection.error_code import AreaCode, JobException, write_termination_message
+
+
+@click.group(name='pp_lite')
+def pp_lite():
+ pass
+
+
+# TODO(zhou.yi): add psi rsa options
+@pp_lite.command()
+@click.argument('role', type=click.Choice(['client', 'server', 'light_client']))
+def psi_rsa(role: str):
+ try:
+ if role in ['client', 'light_client']:
+ psi_rsa_client.run(psi_rsa_client.get_arguments())
+ else:
+ psi_rsa_server.run(psi_rsa_server.get_arguments())
+ except JobException as e:
+ logging.exception(e.message)
+ write_termination_message(AreaCode.PSI_RSA, e.error_type, e.message)
+ raise JobException(AreaCode.PSI_RSA, e.error_type, e.message) from e
+
+
+# TODO(zhou.yi): add psi ot options
+@pp_lite.command()
+@click.argument('role', type=click.Choice(['client', 'server', 'light_client']))
+def psi_ot(role: str):
+ try:
+ args = psi_ot_and_hash_arguments.get_arguments()
+ args.data_join_type = DataJoinType.OT_PSI
+ if role == 'client':
+ args.partitioned = True
+ psi_ot_and_hash_client.run(args)
+ elif role == 'light_client':
+ args.partitioned = False
+ psi_ot_and_hash_client.run(args)
+ else:
+ psi_ot_and_hash_server.run(args)
+ except JobException as e:
+ logging.exception(e.message)
+ write_termination_message(AreaCode.PSI_OT, e.error_type, e.message)
+ raise JobException(AreaCode.PSI_OT, e.error_type, e.message) from e
+
+
+# TODO(zhou.yi): add psi hash options
+@pp_lite.command()
+@click.argument('role', type=click.Choice(['client', 'server', 'light_client']))
+def psi_hash(role: str):
+ try:
+ args = psi_ot_and_hash_arguments.get_arguments()
+ args.data_join_type = DataJoinType.HASHED_DATA_JOIN
+ if role == 'client':
+ args.partitioned = True
+ psi_ot_and_hash_client.run(args)
+ elif role == 'light_client':
+ args.partitioned = False
+ psi_ot_and_hash_client.run(args)
+ else:
+ psi_ot_and_hash_server.run(args)
+ except JobException as e:
+ logging.exception(e.message)
+ write_termination_message(AreaCode.PSI_HASH, e.error_type, e.message)
+ raise JobException(AreaCode.PSI_HASH, e.error_type, e.message) from e
+
+
+if __name__ == '__main__':
+ pp_lite()
diff --git a/pp_lite/data_join/BUILD.bazel b/pp_lite/data_join/BUILD.bazel
new file mode 100644
index 000000000..0e039a07a
--- /dev/null
+++ b/pp_lite/data_join/BUILD.bazel
@@ -0,0 +1,7 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+py_library(
+ name = "envs",
+ srcs = ["envs.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+)
diff --git a/pp_lite/data_join/envs.py b/pp_lite/data_join/envs.py
new file mode 100644
index 000000000..3fd86cc7a
--- /dev/null
+++ b/pp_lite/data_join/envs.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+GRPC_CLIENT_TIMEOUT = int(os.environ.get('GRPC_CLIENT_TIMEOUT', 30))
+STORAGE_ROOT = os.environ.get('STORAGE_ROOT', '/app')
+CLIENT_CONNECT_RETRY_INTERVAL = 10
diff --git a/pp_lite/data_join/psi_ot/BUILD.bazel b/pp_lite/data_join/psi_ot/BUILD.bazel
new file mode 100644
index 000000000..e1a204975
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/BUILD.bazel
@@ -0,0 +1,47 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+py_library(
+ name = "psi_ot",
+ srcs = [
+ "arguments.py",
+ "client.py",
+ "data_join_control_servicer.py",
+ "data_join_manager.py",
+ "data_join_server.py",
+ "server.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join:envs",
+ "//pp_lite/data_join/psi_ot/joiner",
+ "//pp_lite/data_join/utils",
+ "//pp_lite/proto:py_grpc",
+ "//pp_lite/proto:py_proto",
+ "//pp_lite/rpc",
+ "//pp_lite/utils",
+ "//py_libs:metrics_lib",
+ "//web_console_v2/inspection:error_code_lib",
+ ],
+)
+
+py_test(
+ name = "data_join_server_test",
+ size = "small",
+ srcs = ["data_join_server_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":psi_ot",
+ "//pp_lite/testing",
+ ],
+)
+
+py_test(
+ name = "arguments_test",
+ size = "small",
+ srcs = ["arguments_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":psi_ot",
+ "//pp_lite/testing",
+ ],
+)
diff --git a/pp_lite/data_join/psi_ot/arguments.py b/pp_lite/data_join/psi_ot/arguments.py
new file mode 100644
index 000000000..5b7fe62b8
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/arguments.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from os import getenv
+import json
+from pp_lite.proto.arguments_pb2 import Arguments
+from web_console_v2.inspection.error_code import AreaCode, ErrorType, JobException
+
+
+def get_arguments() -> Arguments:
+ for i in ['INPUT_PATH', 'OUTPUT_PATH', 'KEY_COLUMN', 'SERVER_PORT', 'JOINER_PORT']:
+ if getenv(i) is None:
+ raise JobException(AreaCode.UNKNOWN, ErrorType.INPUT_PARAMS_ERROR, f'Environment variable {i} is missing.')
+ args = Arguments()
+ args.input_path = getenv('INPUT_PATH')
+ args.output_path = getenv('OUTPUT_PATH')
+ args.key_column = getenv('KEY_COLUMN')
+ args.server_port = int(getenv('SERVER_PORT'))
+ args.joiner_port = int(getenv('JOINER_PORT'))
+ args.worker_rank = int(getenv('INDEX', '0'))
+ if getenv('NUM_WORKERS'):
+ args.num_workers = int(getenv('NUM_WORKERS'))
+ role = getenv('ROLE', '')
+ if getenv('CLUSTER_SPEC') and role != 'light_client':
+ cluster_spec = json.loads(getenv('CLUSTER_SPEC'))
+
+ # Only accept CLUSTER_SPEC in cluster environment,
+ # so that CLUSTER_SPEC from .env in light client environment can be omitted.
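+        # e.g. CLUSTER_SPEC='{"clusterSpec": {"Worker": ["worker-0", "worker-1"]}}'
+        # (this is the shape exercised by arguments_test.py).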
+ if 'clusterSpec' in cluster_spec:
+ args.cluster_spec.workers.extend(cluster_spec['clusterSpec']['Worker'])
+ args.num_workers = len(args.cluster_spec.workers)
+ assert args.num_workers > 0
+ return args
diff --git a/pp_lite/data_join/psi_ot/arguments_test.py b/pp_lite/data_join/psi_ot/arguments_test.py
new file mode 100644
index 000000000..0ad23650d
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/arguments_test.py
@@ -0,0 +1,93 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import unittest
+
+from pp_lite.data_join.psi_ot.arguments import get_arguments
+from pp_lite.proto.arguments_pb2 import Arguments, ClusterSpec
+
+
+class ArgumentsTest(unittest.TestCase):
+
+ def test_get_client_arguments(self):
+ os.environ['INPUT_PATH'] = 'input'
+ os.environ['OUTPUT_PATH'] = 'output'
+ os.environ['KEY_COLUMN'] = 'raw_id'
+ os.environ['INDEX'] = '0'
+ os.environ['NUM_WORKERS'] = '1'
+ os.environ['JOINER_PORT'] = '12345'
+ os.environ['SERVER_PORT'] = '54321'
+ os.environ['ROLE'] = 'client'
+ args = get_arguments()
+ self.assertEqual(
+ args,
+ Arguments(input_path='input',
+ output_path='output',
+ key_column='raw_id',
+ server_port=54321,
+ joiner_port=12345,
+ worker_rank=0,
+ num_workers=1))
+ os.environ['CLUSTER_SPEC'] = json.dumps({'clusterSpec': {'Worker': ['worker-0', 'worker-1']}})
+ args = get_arguments()
+ cluster_spec = ClusterSpec()
+ cluster_spec.workers.extend(['worker-0', 'worker-1'])
+ self.assertEqual(
+ args,
+ Arguments(input_path='input',
+ output_path='output',
+ key_column='raw_id',
+ server_port=54321,
+ joiner_port=12345,
+ worker_rank=0,
+ num_workers=2,
+ cluster_spec=cluster_spec))
+
+ def test_get_light_client_arguments(self):
+ os.environ['INPUT_PATH'] = 'input'
+ os.environ['OUTPUT_PATH'] = 'output'
+ os.environ['KEY_COLUMN'] = 'raw_id'
+ os.environ['INDEX'] = '0'
+ os.environ['NUM_WORKERS'] = '5'
+ os.environ['JOINER_PORT'] = '12345'
+ os.environ['SERVER_PORT'] = '54321'
+ os.environ['ROLE'] = 'light_client'
+ args = get_arguments()
+ self.assertEqual(
+ args,
+ Arguments(input_path='input',
+ output_path='output',
+ key_column='raw_id',
+ server_port=54321,
+ joiner_port=12345,
+ worker_rank=0,
+ num_workers=5))
+ os.environ['CLUSTER_SPEC'] = json.dumps({'clusterSpec': {'Worker': ['worker-0', 'worker-1']}}) # omitted
+ args = get_arguments()
+ self.assertEqual(
+ args,
+ Arguments(input_path='input',
+ output_path='output',
+ key_column='raw_id',
+ server_port=54321,
+ joiner_port=12345,
+ worker_rank=0,
+ num_workers=5))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_ot/client.py b/pp_lite/data_join/psi_ot/client.py
new file mode 100644
index 000000000..6610d02cd
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/client.py
@@ -0,0 +1,135 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import time
+import logging
+import logging.config
+import tempfile
+import copy
+from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor, wait
+
+from pp_lite.data_join import envs
+
+from pp_lite.proto.arguments_pb2 import Arguments
+from pp_lite.proto.common_pb2 import DataJoinType, FileType
+from pp_lite.rpc.data_join_control_client import DataJoinControlClient
+from pp_lite.data_join.psi_ot.data_join_manager import DataJoinManager
+from pp_lite.data_join.psi_ot.joiner.ot_psi_joiner import OtPsiJoiner
+from pp_lite.data_join.psi_ot.joiner.hashed_data_joiner import HashedDataJoiner
+from pp_lite.data_join.utils.example_id_writer import ExampleIdWriter
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader, PartitionInfo
+from pp_lite.data_join.utils.partitioner import Partitioner
+from pp_lite.utils.logging_config import logging_config, log_path
+from pp_lite.utils.tools import get_partition_ids
+
+
+def wait_for_ready(client: DataJoinControlClient):
+ while True:
+ try:
+ client.health_check()
+ break
+ except grpc.RpcError:
+ logging.info('server is not ready')
+ time.sleep(10)
+
+
+def wait_for_ready_and_verify(client: DataJoinControlClient, num_partitions: int, num_workers: int):
+ wait_for_ready(client=client)
+ logging.info('[DataJoinControlClient] start verify parameter')
+ resp = client.verify_parameter(num_partitions=num_partitions, num_workers=num_workers)
+ logging.info(
+ f'[DataJoinControlClient] Server num_partitions: {resp.num_partitions}, num_workers: {resp.num_workers}')
+    if not resp.succeeded:
+        if resp.num_partitions == 0:
+            logging.info(f'[DataJoinControlClient] Server num_partitions: {resp.num_partitions}, server will quit.')
+        logging.info('[DataJoinControlClient] joiner must have the same parameters')
+ return False
+ return True
+
+
+def run_joiner(args: Arguments):
+ logging.config.dictConfig(logging_config(file_path=log_path(log_dir=envs.STORAGE_ROOT)))
+ client = DataJoinControlClient(args.server_port)
+ reader = ExampleIdReader(input_path=args.input_path, file_type=FileType.CSV, key_column=args.key_column)
+ writer = ExampleIdWriter(output_path=args.output_path, key_column=args.key_column)
+ if args.data_join_type == DataJoinType.HASHED_DATA_JOIN:
+ joiner = HashedDataJoiner(joiner_port=args.joiner_port)
+ else:
+ joiner = OtPsiJoiner(joiner_port=args.joiner_port)
+ partition_info = PartitionInfo(args.input_path)
+ partition_ids = get_partition_ids(args.worker_rank, args.num_workers, partition_info.num_partitions)
+ logging.info(f'allocated partitions {partition_ids} to worker {args.worker_rank}')
+ manager = DataJoinManager(joiner, client, reader, writer)
+ num_partitions = reader.num_partitions
+ if num_partitions == 0:
+        logging.info('[run_joiner] num_partitions of client is zero, closing grpc and quitting.')
+ client.finish()
+ return
+ ret = wait_for_ready_and_verify(client, num_partitions, args.num_workers)
+ if ret:
+ manager.run(partition_ids=partition_ids)
+ client.close()
+ else:
+ client.finish() # Close _stub
+ client.close()
+
+
+def run(args: Arguments):
+ logging.config.dictConfig(logging_config(file_path=log_path(log_dir=envs.STORAGE_ROOT)))
+ if args.partitioned:
+ run_joiner(args)
+ else:
+ partitioned_path = tempfile.mkdtemp()
+ logging.info(f'[DataJoinControlClient] input not partitioned, start partitioning to {partitioned_path}...')
+ client = DataJoinControlClient(args.server_port)
+ wait_for_ready(client=client)
+ parameter_response = client.get_parameter()
+ # DataJoinControlClient includes a gRPC channel.
+ # Creating a gRPC channel on the same port again may cause error.
+ # So the client needs to be closed after use.
+ client.close()
+ num_partitions = parameter_response.num_partitions
+ logging.info(f'[DataJoinControlClient] data will be partitioned to {num_partitions} partition(s).')
+ partitioner = Partitioner(input_path=args.input_path,
+ output_path=partitioned_path,
+ num_partitions=num_partitions,
+ block_size=1000000,
+ key_column=args.key_column,
+ queue_size=40,
+ reader_thread_num=20,
+ writer_thread_num=20)
+ partitioner.partition_data()
+ logging.info('[DataJoinControlClient] partition finished.')
+ futures = []
+ pool = ProcessPoolExecutor()
+ for worker in range(args.num_workers):
+ worker_args = copy.deepcopy(args)
+ worker_args.worker_rank = worker
+ worker_args.input_path = partitioned_path
+ worker_args.server_port = args.server_port + worker * 2
+ worker_args.joiner_port = args.joiner_port + worker
+ futures.append(pool.submit(run_joiner, worker_args))
+ res = wait(futures, return_when=FIRST_EXCEPTION)
+ for future in res.done:
+ if future.exception():
+                # stop all subprocesses early when an exception is caught
+ for pid, process in pool._processes.items(): # pylint: disable=protected-access
+ process.terminate()
+ raise Exception('Joiner subprocess failed') from future.exception()
+
+ combine_writer = ExampleIdWriter(output_path=args.output_path, key_column=args.key_column)
+ combine_writer.combine(num_partitions)
diff --git a/pp_lite/data_join/psi_ot/data_join_control_servicer.py b/pp_lite/data_join/psi_ot/data_join_control_servicer.py
new file mode 100644
index 000000000..821f8fcac
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/data_join_control_servicer.py
@@ -0,0 +1,87 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import logging
+from typing import Callable
+from google.protobuf import empty_pb2
+
+from pp_lite.rpc.server import IServicer
+from pp_lite.data_join.psi_ot.data_join_server import DataJoinServer
+from pp_lite.proto.common_pb2 import Pong, Ping
+from pp_lite.proto import data_join_control_service_pb2 as service_pb2
+from pp_lite.proto import data_join_control_service_pb2_grpc as service_pb2_grpc
+from pp_lite.proto.arguments_pb2 import ClusterSpec
+
+
+class DataJoinControlServicer(service_pb2_grpc.DataJoinControlServiceServicer, IServicer):
+
+ def __init__(self, data_join_server: DataJoinServer, cluster_spec: ClusterSpec):
+ self._stop_hook = None
+ self._data_join_server = data_join_server
+ self._cluster_spec = cluster_spec
+
+ def HealthCheck(self, request: Ping, context):
+ logging.info('[DataJoinControlServicer] Receive HealthCheck Request')
+ return Pong(message=request.message)
+
+ def VerifyParameter(self, request: service_pb2.VerifyParameterRequest, context):
+ logging.info('[DataJoinControlServicer] Receive VerifyParameter Request')
+ succeeded = True
+ num_partitions = self._data_join_server.num_partitions
+ num_workers = len(self._cluster_spec.workers)
+ if request.num_partitions != num_partitions:
+ logging.warning(
+ f'Server and client do not have the same partition num, {num_partitions} vs {request.num_partitions}')
+ succeeded = False
+ if request.num_workers != num_workers:
+ logging.warning(
+ f'Server and client do not have the same worker num, {num_workers} vs {request.num_workers}')
+ succeeded = False
+ return service_pb2.VerifyParameterResponse(succeeded=succeeded,
+ num_partitions=num_partitions,
+ num_workers=num_workers)
+
+ def GetParameter(self, request: service_pb2.GetParameterRequest, context):
+ logging.info('[DataJoinControlServicer] Receive GetParameter Request')
+ num_partitions = self._data_join_server.num_partitions
+ num_workers = len(self._cluster_spec.workers)
+ return service_pb2.GetParameterResponse(num_partitions=num_partitions, num_workers=num_workers)
+
+ def CreateDataJoin(self, request: service_pb2.CreateDataJoinRequest, context):
+ logging.info(f'[DataJoinControlServicer] Receive CreateDataJoin Request for partition {request.partition_id}')
+ assert request.type == self._data_join_server.data_join_type, 'joiner must have the same type'
+ self._data_join_server.stop()
+ if self._data_join_server.empty(partition_id=request.partition_id):
+ logging.info(f'[DataJoinControlServicer] skip joiner for partition {request.partition_id} with input 0 ids')
+ return service_pb2.CreateDataJoinResponse(succeeded=True, empty=True)
+ succeeded = self._data_join_server.start(partition_id=request.partition_id)
+ return service_pb2.CreateDataJoinResponse(succeeded=succeeded, empty=False)
+
+ def GetDataJoinResult(self, request: service_pb2.GetDataJoinResultRequest, context):
+ logging.info(
+ f'[DataJoinControlServicer] Receive GetDataJoinResult Request for partition {request.partition_id}')
+ finished = self._data_join_server.is_finished(request.partition_id)
+ logging.info(f'[DataJoinControlServicer] respond result {finished} to GetDataJoinResult request')
+ return service_pb2.DataJoinResult(finished=finished)
+
+ def Finish(self, request, context):
+ logging.info('[DataJoinControlServicer] Receive Finish Request')
+ self._stop_hook()
+ return empty_pb2.Empty()
+
+ def register(self, server: grpc.Server, stop_hook: Callable[[], None]):
+ self._stop_hook = stop_hook
+ service_pb2_grpc.add_DataJoinControlServiceServicer_to_server(self, server)
diff --git a/pp_lite/data_join/psi_ot/data_join_manager.py b/pp_lite/data_join/psi_ot/data_join_manager.py
new file mode 100644
index 000000000..357dad391
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/data_join_manager.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import logging
+from typing import List
+
+from pp_lite.data_join.psi_ot.joiner.joiner_interface import Joiner
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader
+from pp_lite.data_join.utils.example_id_writer import ExampleIdWriter
+from pp_lite.rpc.data_join_control_client import DataJoinControlClient
+from pp_lite.utils.decorators import retry_fn, timeout_fn
+from pp_lite.utils import metric_collector
+
+MAX_NUMBER = 16000000
+
+
+class DataJoinManager:
+
+ def __init__(self, joiner: Joiner, client: DataJoinControlClient, reader: ExampleIdReader, writer: ExampleIdWriter):
+ self._reader = reader
+ self._writer = writer
+ self._client = client
+ self._joiner = joiner
+
+ def _wait_for_server_finished(self, partition_id: int):
+ for i in range(10):
+ resp = self._client.get_data_join_result(partition_id=partition_id)
+ if resp.finished:
+ logging.info(f'[DataJoinManager] server is finished for partition {partition_id}')
+ return
+ logging.warning(f'[DataJoinManager] server is still not finished for partition {partition_id}')
+ time.sleep(10)
+ raise Exception('server is still not finished!')
+
+ @retry_fn(3)
+ @timeout_fn(1200)
+ def _run_task(self, joiner: Joiner, partition_id: int):
+ logging.info(f'[DataJoinManager] start partition {partition_id}')
+ # ensure input id is unique
+ with metric_collector.emit_timing('dataset.data_join.ot_or_hash_psi.read_data_timing', {'role': 'client'}):
+ ids = list(set(self._reader.read(partition_id)))
+ if len(ids) == 0:
+ logging.info(f'[DataJoinManager] skip join for partition {partition_id} with client input 0 ids')
+ return
+ response = self._client.create_data_join(partition_id=partition_id, data_join_type=joiner.type)
+ if response.empty:
+ logging.info(f'[DataJoinManager] skip join for partition {partition_id} with server input 0 ids')
+ return
+ logging.info(f'[DataJoinManager] start join for partition {partition_id} with input {len(ids)} ids')
+ assert len(ids) < MAX_NUMBER, f'the number of id should be less than {MAX_NUMBER}'
+ metric_collector.emit_counter('dataset.data_join.ot_or_hash_psi.partition_start_join', 1, {'role': 'client'})
+ metric_collector.emit_counter('dataset.data_join.ot_or_hash_psi.row_num', len(ids), {'role': 'client'})
+ inter_ids = joiner.client_run(ids=ids)
+ metric_collector.emit_counter('dataset.data_join.ot_or_hash_psi.intersection', len(inter_ids),
+ {'role': 'client'})
+ logging.info(f'[DataJoinManager] finish join for partition {partition_id} with output {len(inter_ids)} ids')
+ self._writer.write(partition_id=partition_id, ids=inter_ids)
+ self._wait_for_server_finished(partition_id=partition_id)
+ self._writer.write_success_tag(partition_id=partition_id)
+ logging.info(f'[DataJoinManager] finish writing result for partition {partition_id}')
+
+ def run(self, partition_ids: List[int]):
+ logging.info('[DataJoinManager] data join start!')
+ for partition_id in partition_ids:
+ if self._writer.success_tag_exists(partition_id=partition_id):
+ logging.warning(f'[DataJoinManager] skip partition {partition_id} since success tag exists')
+ continue
+ self._run_task(joiner=self._joiner, partition_id=partition_id)
+ self._client.finish()
+ logging.info('[DataJoinManager] data join is finished!')
diff --git a/pp_lite/data_join/psi_ot/data_join_server.py b/pp_lite/data_join/psi_ot/data_join_server.py
new file mode 100644
index 000000000..c78743bf0
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/data_join_server.py
@@ -0,0 +1,117 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from threading import Lock
+import time
+import logging
+from typing import List
+from multiprocessing import get_context
+
+from pp_lite.proto.common_pb2 import DataJoinType
+from pp_lite.data_join.psi_ot.joiner.joiner_interface import Joiner
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader
+from pp_lite.data_join.utils.example_id_writer import ExampleIdWriter
+from pp_lite.data_join.psi_ot.data_join_manager import MAX_NUMBER
+from pp_lite.utils import metric_collector
+
+
+def _run(joiner: Joiner, ids: List[str], writer: ExampleIdWriter, partition_id: int):
+ # since spawn method is used, logging config is not forked from parent process,
+ # so the log level should be set to INFO.
+ # TODO(hangweiqiang): find a better way to initialize the process
+ logging.getLogger().setLevel(logging.INFO)
+ metric_collector.emit_counter('dataset.data_join.ot_or_hash_psi.partition_start_join', 1, {'role': 'server'})
+ metric_collector.emit_counter('dataset.data_join.ot_or_hash_psi.row_num', len(ids), {'role': 'server'})
+ inter_ids = joiner.server_run(ids)
+ metric_collector.emit_counter('dataset.data_join.ot_or_hash_psi.intersection', len(inter_ids), {'role': 'server'})
+ logging.info(f'[DataJoinServer] finish data join for partition {partition_id} with {len(inter_ids)} ids')
+ writer.write(partition_id=partition_id, ids=inter_ids)
+ writer.write_success_tag(partition_id=partition_id)
+ logging.info(f'[DataJoinServer] finish write result to partition {partition_id}')
+
+
+class DataJoinServer:
+
+ def __init__(self, joiner: Joiner, reader: ExampleIdReader, writer: ExampleIdWriter):
+ self._reader = reader
+ self._writer = writer
+ self._process = None
+ self._joiner = joiner
+ self._prepared_partition_id = None
+ self._ids = None
+ self._mutex = Lock()
+ # Since DataJoinServer use multiprocessing.Process to initialize a new process to
+ # run joiner, it may be blocked during fork due to https://github.com/grpc/grpc/issues/21471.
+ # Setting start method to spawn will resolve this problem, the difference between fork
+ # and spawn can be found in https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
+ self._mp_ctx = get_context('spawn')
+
+ @property
+ def data_join_type(self) -> DataJoinType:
+ return self._joiner.type
+
+ @property
+ def num_partitions(self) -> int:
+ return self._reader.num_partitions
+
+ def is_finished(self, partition_id: int) -> bool:
+ return self._writer.success_tag_exists(partition_id=partition_id)
+
+ def _get_ids(self, partition_id: int) -> List[str]:
+ with self._mutex:
+ if self._prepared_partition_id == partition_id:
+ return self._ids
+ # ensure input id is unique
+ self._ids = list(set(self._reader.read(partition_id)))
+ self._prepared_partition_id = partition_id
+ return self._ids
+
+ def empty(self, partition_id: int) -> bool:
+ ids = self._get_ids(partition_id)
+ return len(ids) == 0
+
+ def start(self, partition_id: int) -> bool:
+ """Start non-blocking joiner"""
+ assert self._process is None
+
+ with metric_collector.emit_timing('dataset.data_join.ot_or_hash_psi.read_data_timing', {'role': 'server'}):
+ ids = self._get_ids(partition_id)
+ logging.info(f'[DataJoinServer] read {len(ids)} ids from partition {partition_id}')
+
+ if len(ids) < 1 or len(ids) > MAX_NUMBER:
+ logging.warning(f'[DataJoinServer] len(ids) should be positive and less than {MAX_NUMBER}')
+ return False
+
+ self._process = self._mp_ctx.Process(target=_run,
+ kwargs={
+ 'joiner': self._joiner,
+ 'ids': ids,
+ 'writer': self._writer,
+ 'partition_id': partition_id,
+ })
+ logging.info(f'[DataJoinServer] start joiner for partition {partition_id}')
+ self._process.start()
+        # wait for the data join server to become ready
+ time.sleep(10)
+ return True
+
+ def stop(self):
+ """kill the joiner process and release the resources"""
+ if self._process is None:
+ return
+ self._process.terminate()
+ self._process.join()
+ self._process.close()
+ self._process = None
diff --git a/pp_lite/data_join/psi_ot/data_join_server_test.py b/pp_lite/data_join/psi_ot/data_join_server_test.py
new file mode 100644
index 000000000..e132bd588
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/data_join_server_test.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import time
+import shutil
+import tempfile
+import unittest
+from typing import List, Optional
+
+from pp_lite.proto.common_pb2 import FileType, DataJoinType
+from pp_lite.data_join.psi_ot.data_join_server import DataJoinServer
+from pp_lite.data_join.psi_ot.joiner.joiner_interface import Joiner
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader
+from pp_lite.data_join.utils.example_id_writer import ExampleIdWriter
+from pp_lite.testing.make_data import _make_fake_data
+
+
+class TestJoiner(Joiner):
+
+ def __init__(self, wait_time: Optional[float] = None):
+ super().__init__(12345)
+ self._wait_time = wait_time
+
+ @property
+ def type(self) -> DataJoinType:
+ return DataJoinType.HASHED_DATA_JOIN
+
+ def client_run(self, ids: List[str]) -> List[str]:
+ return []
+
+ def server_run(self, ids: List[str]) -> List[str]:
+ if self._wait_time:
+ time.sleep(self._wait_time)
+ return ids
+
+
+class DataJoinServerTest(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._base_path = tempfile.mkdtemp()
+ self._input_path = os.path.join(self._base_path, 'input')
+ self._output_path = os.path.join(self._base_path, 'output')
+ os.makedirs(self._input_path)
+ os.makedirs(self._output_path)
+ _make_fake_data(self._input_path, 10, 10)
+ self._reader = ExampleIdReader(self._input_path, FileType.CSV, 'x_1')
+ self._writer = ExampleIdWriter(self._output_path, 'x_1')
+
+ def tearDown(self):
+ shutil.rmtree(self._base_path)
+
+ def test_start(self):
+ joiner = TestJoiner()
+ server = DataJoinServer(joiner, self._reader, self._writer)
+ server.start(partition_id=2)
+ time.sleep(0.1)
+ self.assertTrue(os.path.exists(os.path.join(self._output_path, 'partition_2')))
+
+ def test_stop(self):
+ joiner = TestJoiner(wait_time=11)
+ server = DataJoinServer(joiner, self._reader, self._writer)
+ server.start(partition_id=2)
+ server.stop()
+ time.sleep(2)
+ self.assertFalse(os.path.exists(os.path.join(self._output_path, 'partition_2')))
+ self.assertFalse(os.path.exists(os.path.join(self._output_path, '0002._SUCCESS')))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_ot/joiner/BUILD.bazel b/pp_lite/data_join/psi_ot/joiner/BUILD.bazel
new file mode 100644
index 000000000..4f54a769c
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/BUILD.bazel
@@ -0,0 +1,43 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+py_library(
+ name = "joiner",
+ srcs = [
+ "hashed_data_join_servicer.py",
+ "hashed_data_joiner.py",
+ "joiner_interface.py",
+ "ot_psi_joiner.py",
+ "utils.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join:envs",
+ "//pp_lite/proto:py_grpc",
+ "//pp_lite/proto:py_proto",
+ "//pp_lite/rpc",
+ "@common_cityhash//:pkg",
+ "@common_fsspec//:pkg",
+ "@common_pyarrow//:pkg", # keep
+ ],
+)
+
+py_test(
+ name = "ot_psi_joiner_test",
+ size = "small",
+ srcs = ["ot_psi_joiner_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":joiner",
+ "//pp_lite/data_join:envs",
+ ],
+)
+
+py_test(
+ name = "utils_test",
+ size = "small",
+ srcs = ["utils_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":joiner",
+ ],
+)
diff --git a/pp_lite/data_join/psi_ot/joiner/hashed_data_join_servicer.py b/pp_lite/data_join/psi_ot/joiner/hashed_data_join_servicer.py
new file mode 100644
index 000000000..39b2424a3
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/hashed_data_join_servicer.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from typing import Iterable, List, Callable
+
+from pp_lite.rpc.server import IServicer
+from pp_lite.proto import hashed_data_join_pb2_grpc, hashed_data_join_pb2
+
+
+class HashedDataJoinServicer(hashed_data_join_pb2_grpc.HashedDataJoinServiceServicer, IServicer):
+
+ def __init__(self, ids: List[str]):
+ self._ids = set(ids)
+ self._inter_ids = []
+ self._finished = False
+
+ def is_finished(self):
+ return self._finished
+
+ def get_data_join_result(self) -> List[str]:
+ assert self.is_finished(), 'Getting result before finished'
+ return self._inter_ids
+
+ def DataJoin(self, request_iterator: Iterable[hashed_data_join_pb2.DataJoinRequest], context):
+ for part in request_iterator:
+ current_inter_ids = [id for id in part.ids if id in self._ids]
+ self._inter_ids.extend(current_inter_ids)
+ yield hashed_data_join_pb2.DataJoinResponse(ids=current_inter_ids)
+ self._finished = True
+
+ def register(self, server: grpc.Server, stop_hook: Callable[[], None]):
+ hashed_data_join_pb2_grpc.add_HashedDataJoinServiceServicer_to_server(self, server)
diff --git a/pp_lite/data_join/psi_ot/joiner/hashed_data_joiner.py b/pp_lite/data_join/psi_ot/joiner/hashed_data_joiner.py
new file mode 100644
index 000000000..ca3e2224d
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/hashed_data_joiner.py
@@ -0,0 +1,54 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from typing import List
+from pp_lite.data_join.psi_ot.joiner.joiner_interface import Joiner
+from pp_lite.data_join.psi_ot.joiner.utils import HashValueSet
+from pp_lite.proto.common_pb2 import DataJoinType
+from pp_lite.rpc.server import RpcServer
+from pp_lite.rpc.hashed_data_join_client import HashedDataJoinClient
+from pp_lite.data_join.psi_ot.joiner.hashed_data_join_servicer import HashedDataJoinServicer
+
+
+class HashedDataJoiner(Joiner):
+
+ @property
+ def type(self) -> DataJoinType:
+ return DataJoinType.HASHED_DATA_JOIN
+
+ def client_run(self, ids: List[str]) -> List[str]:
+ client = HashedDataJoinClient(server_port=self.joiner_port)
+ hash_value_set = HashValueSet()
+ hash_value_set.add_raw_values(ids)
+ response_iterator = client.data_join(hash_value_set.get_hash_value_list())
+ resp_ids = []
+ for part in response_iterator:
+ for response_hashed_id in part.ids:
+ resp_ids.append(hash_value_set.get_raw_value(response_hashed_id))
+ return resp_ids
+
+ def server_run(self, ids: List[str]) -> List[str]:
+ hash_value_set = HashValueSet()
+ hash_value_set.add_raw_values(ids)
+ servicer = HashedDataJoinServicer(ids=hash_value_set.get_hash_value_list())
+ server = RpcServer(servicer=servicer, listen_port=self.joiner_port)
+ server.start()
+ for _ in range(1000):
+ if servicer.is_finished():
+ raw_ids = [hash_value_set.get_raw_value(hash_id) for hash_id in servicer.get_data_join_result()]
+ return raw_ids
+ time.sleep(1)
+ raise Exception('data join is not finished')
diff --git a/pp_lite/data_join/psi_ot/joiner/joiner_interface.py b/pp_lite/data_join/psi_ot/joiner/joiner_interface.py
new file mode 100644
index 000000000..8cced0ffb
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/joiner_interface.py
@@ -0,0 +1,39 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+from abc import ABCMeta, abstractmethod
+from pp_lite.proto.common_pb2 import DataJoinType
+
+
+class Joiner(metaclass=ABCMeta):
+
+ def __init__(self, joiner_port: int):
+ self.joiner_port = joiner_port
+
+ @property
+ @abstractmethod
+ def type(self) -> DataJoinType:
+ raise NotImplementedError
+
+ @abstractmethod
+ def client_run(self, ids: List[str]) -> List[str]:
+ """Run data join at client side. The id's in intersection set is returned"""
+ raise NotImplementedError
+
+ @abstractmethod
+ def server_run(self, ids: List[str]) -> List[str]:
+ """Run data join at server side. The id's in intersection set is returned"""
+ raise NotImplementedError
diff --git a/pp_lite/data_join/psi_ot/joiner/ot_psi_joiner.py b/pp_lite/data_join/psi_ot/joiner/ot_psi_joiner.py
new file mode 100644
index 000000000..1093a472a
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/ot_psi_joiner.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import fsspec
+import logging
+import datetime
+from enum import Enum
+from typing import List
+
+from pp_lite.data_join import envs
+from pp_lite.proto.common_pb2 import DataJoinType
+from pp_lite.data_join.psi_ot.joiner.joiner_interface import Joiner
+
+
+def _write_ids(filename: str, ids: List[str]):
+ with fsspec.open(filename, 'wt') as f:
+ f.write('\n'.join(ids))
+
+
+def _read_ids(filename: str) -> List[str]:
+ with fsspec.open(filename, 'rt') as f:
+ return f.read().splitlines()
+
+
+class _Role(Enum):
+ """the value is consistent with the argument of ot command"""
+ client = 0 # psi sender; tcp client
+ server = 1 # psi receiver; tcp server
+
+
+def _timestamp() -> str:
+ """Return string format of time to make test easier to mock"""
+ return datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
+
+
+class OtPsiJoiner(Joiner):
+
+ @property
+ def type(self) -> DataJoinType:
+ return DataJoinType.OT_PSI
+
+ def _run(self, ids: List[str], role: _Role):
+ timestamp = _timestamp()
+ input_path = f'{envs.STORAGE_ROOT}/data/{role.name}-input-{timestamp}'
+ output_path = f'{envs.STORAGE_ROOT}/data/{role.name}-output-{timestamp}'
+ _write_ids(input_path, ids)
+        try:
+            import psi_oprf  # pylint: disable=import-outside-toplevel
+            psi_oprf.PsiRun(role.value, input_path, output_path, f'localhost:{self.joiner_port}')
+            logging.info('[OtPsiJoiner] PsiRun finished.')
+            joined_ids = _read_ids(output_path)
+        except Exception as e:  # pylint: disable=broad-except
+            logging.exception('[OtPsiJoiner] error happened during ot psi!')
+            raise Exception('ot psi failed') from e
+ finally:
+ # delete the input and output file of ot program to release the storage volume
+ fs = fsspec.get_mapper(input_path).fs
+ fs.delete(input_path)
+ if fs.exists(output_path):
+ fs.delete(output_path)
+ return joined_ids
+
+ def client_run(self, ids: List[str]) -> List[str]:
+ return self._run(ids, _Role.client)
+
+ def server_run(self, ids: List[str]) -> List[str]:
+ return self._run(ids, _Role.server)
diff --git a/pp_lite/data_join/psi_ot/joiner/ot_psi_joiner_test.py b/pp_lite/data_join/psi_ot/joiner/ot_psi_joiner_test.py
new file mode 100644
index 000000000..9617436a1
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/ot_psi_joiner_test.py
@@ -0,0 +1,101 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import List
+from unittest.mock import patch
+from tempfile import TemporaryDirectory
+from concurrent.futures import ThreadPoolExecutor
+import importlib.util as imutil
+
+from pp_lite.data_join import envs
+from pp_lite.data_join.psi_ot.joiner.ot_psi_joiner import OtPsiJoiner
+
+
+def _write_fake_output(filename: str, ids: List[str]):
+ with open(filename, 'wt', encoding='utf-8') as f:
+ f.write('\n'.join(ids))
+
+
+def check_psi_oprf():
+ spec = imutil.find_spec('psi_oprf')
+ if spec is None:
+ psi_oprf_existed = False
+ else:
+ psi_oprf_existed = True
+ return psi_oprf_existed
+
+
+@unittest.skipUnless(check_psi_oprf(), 'require ot psi file')
+class OtPsiJoinerTest(unittest.TestCase):
+
+    @patch('pp_lite.data_join.psi_ot.joiner.ot_psi_joiner._timestamp')
+    @patch('psi_oprf.PsiRun')
+    def test_client_run(self, mock_run, mock_timestamp):
+ joiner = OtPsiJoiner(joiner_port=12345)
+ timestamp = '20220310-185545'
+ mock_timestamp.return_value = timestamp
+ with TemporaryDirectory() as temp_dir:
+ envs.STORAGE_ROOT = temp_dir
+ input_path = f'{envs.STORAGE_ROOT}/data/client-input-{timestamp}'
+ output_path = f'{envs.STORAGE_ROOT}/data/client-output-{timestamp}'
+ inter_ids = ['4', '5', '6']
+
+ def _side_effect(*args, **kwargs):
+ _write_fake_output(output_path, inter_ids)
+
+ mock_run.side_effect = _side_effect
+ ids = joiner.client_run(['1', '2', '3'])
+            mock_run.assert_called_with(0, input_path, output_path, f'localhost:{joiner.joiner_port}')
+ self.assertEqual(ids, inter_ids)
+
+    @patch('pp_lite.data_join.psi_ot.joiner.ot_psi_joiner._timestamp')
+    @patch('psi_oprf.PsiRun')
+    def test_server_run(self, mock_run, mock_timestamp):
+ joiner = OtPsiJoiner(joiner_port=12345)
+ timestamp = '20220310-185545'
+ mock_timestamp.return_value = timestamp
+ with TemporaryDirectory() as temp_dir:
+ envs.STORAGE_ROOT = temp_dir
+ input_path = f'{envs.STORAGE_ROOT}/data/server-input-{timestamp}'
+ output_path = f'{envs.STORAGE_ROOT}/data/server-output-{timestamp}'
+ inter_ids = ['4', '5', '6']
+
+ def _side_effect(*args, **kwargs):
+ _write_fake_output(output_path, inter_ids)
+
+ mock_run.side_effect = _side_effect
+ ids = joiner.server_run(['1', '2', '3'])
+            mock_run.assert_called_with(1, input_path, output_path, f'localhost:{joiner.joiner_port}')
+ self.assertEqual(ids, inter_ids)
+
+
+@unittest.skipUnless(check_psi_oprf(), 'require ot psi file')
+class OtPsiJoinerInContainerTest(unittest.TestCase):
+
+ def test_joiner(self):
+ client_ids = [str(i) for i in range(10000)]
+ server_ids = [str(i) for i in range(5000, 15000)]
+ joined_ids = [str(i) for i in range(5000, 10000)]
+ joiner = OtPsiJoiner(joiner_port=1212)
+ with ThreadPoolExecutor(max_workers=2) as pool:
+ client_fut = pool.submit(joiner.client_run, client_ids)
+ server_fut = pool.submit(joiner.server_run, server_ids)
+ client_result = client_fut.result()
+ server_result = server_fut.result()
+ self.assertEqual(client_result, server_result)
+ self.assertEqual(sorted(client_result), joined_ids)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_ot/joiner/utils.py b/pp_lite/data_join/psi_ot/joiner/utils.py
new file mode 100644
index 000000000..79e0a27ad
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/utils.py
@@ -0,0 +1,38 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+from cityhash import CityHash64 # pylint: disable=no-name-in-module
+
+
+class HashValueSet(object):
+
+ def __init__(self) -> None:
+ self._hash_map = {}
+
+ def add_raw_values(self, values: List[str]):
+ self._hash_map.update({str(CityHash64(value)): value for value in values})
+
+ def get_raw_value(self, hashed_value: str) -> Optional[str]:
+ try:
+ return self._hash_map[hashed_value]
+ except KeyError as e:
+ raise Exception('Hashed value not found in hash map.') from e
+
+ def get_hash_value_list(self) -> List[str]:
+ return list(self._hash_map.keys())
+
+ def exists(self, hashed_value: str) -> bool:
+ return hashed_value in self._hash_map
diff --git a/pp_lite/data_join/psi_ot/joiner/utils_test.py b/pp_lite/data_join/psi_ot/joiner/utils_test.py
new file mode 100644
index 000000000..1c43ddc79
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/joiner/utils_test.py
@@ -0,0 +1,48 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pp_lite.data_join.psi_ot.joiner.utils import HashValueSet
+
+
+class HashValueSetTest(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self._hash_value_set = HashValueSet()
+ self._hash_value_set.add_raw_values(['1', '2'])
+
+ def test_add(self):
+ self.assertDictEqual(
+ self._hash_value_set._hash_map, # pylint: disable=protected-access
+ {
+ '9304157803607034849': '1',
+ '6920640749119438759': '2'
+ })
+
+ def test_get(self):
+ self.assertEqual(self._hash_value_set.get_raw_value('9304157803607034849'), '1')
+        self.assertRaises(Exception, self._hash_value_set.get_raw_value, '123')
+
+ def test_list(self):
+ self.assertListEqual(self._hash_value_set.get_hash_value_list(), ['9304157803607034849', '6920640749119438759'])
+
+ def test_exists(self):
+ self.assertTrue(self._hash_value_set.exists('9304157803607034849'))
+ self.assertFalse(self._hash_value_set.exists('123'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_ot/server.py b/pp_lite/data_join/psi_ot/server.py
new file mode 100644
index 000000000..338d3c9f7
--- /dev/null
+++ b/pp_lite/data_join/psi_ot/server.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import logging.config
+
+from pp_lite.data_join import envs
+
+from pp_lite.rpc.server import RpcServer
+from pp_lite.proto.arguments_pb2 import Arguments
+from pp_lite.proto.common_pb2 import FileType, DataJoinType
+from pp_lite.data_join.psi_ot.data_join_control_servicer import DataJoinControlServicer
+from pp_lite.data_join.psi_ot.data_join_server import DataJoinServer
+from pp_lite.data_join.psi_ot.joiner.ot_psi_joiner import OtPsiJoiner
+from pp_lite.data_join.psi_ot.joiner.hashed_data_joiner import HashedDataJoiner
+from pp_lite.data_join.utils.example_id_writer import ExampleIdWriter
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader
+from pp_lite.utils.logging_config import logging_config, log_path
+
+
+def run(args: Arguments):
+ logging.config.dictConfig(logging_config(log_path(log_dir=envs.STORAGE_ROOT)))
+ reader = ExampleIdReader(input_path=args.input_path, file_type=FileType.CSV, key_column=args.key_column)
+ writer = ExampleIdWriter(output_path=args.output_path, key_column=args.key_column)
+ if args.data_join_type == DataJoinType.HASHED_DATA_JOIN:
+ joiner = HashedDataJoiner(joiner_port=args.joiner_port)
+ else:
+ joiner = OtPsiJoiner(joiner_port=args.joiner_port)
+ data_join_server = DataJoinServer(joiner, reader=reader, writer=writer)
+ servicer = DataJoinControlServicer(data_join_server=data_join_server, cluster_spec=args.cluster_spec)
+ server = RpcServer(servicer, listen_port=args.server_port)
+ server.start()
+ server.wait()
+ logging.info('server is finished!')
diff --git a/pp_lite/data_join/psi_rsa/BUILD.bazel b/pp_lite/data_join/psi_rsa/BUILD.bazel
new file mode 100644
index 000000000..4b5288aec
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/BUILD.bazel
@@ -0,0 +1,35 @@
+load("@rules_python//python:defs.bzl", "py_binary", "py_library")
+
+py_library(
+ name = "psi_rsa",
+ srcs = [
+ "psi_client.py",
+ "psi_server.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join:envs",
+ "//pp_lite/data_join/psi_rsa/client",
+ "//pp_lite/data_join/psi_rsa/server",
+ "//pp_lite/data_join/utils",
+ "//pp_lite/rpc",
+ "//pp_lite/utils",
+ "//py_libs:metrics_lib",
+ ],
+)
+
+py_binary(
+ name = "server_bin",
+ srcs = ["psi_server.py"],
+ main = "psi_server.py",
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":psi_rsa"],
+)
+
+py_binary(
+ name = "client_bin",
+ srcs = ["psi_client.py"],
+ main = "psi_client.py",
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":psi_rsa"],
+)
diff --git a/pp_lite/data_join/psi_rsa/client/BUILD.bazel b/pp_lite/data_join/psi_rsa/client/BUILD.bazel
new file mode 100644
index 000000000..b5837de16
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/client/BUILD.bazel
@@ -0,0 +1,29 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+py_library(
+ name = "client",
+ srcs = [
+ "data_joiner.py",
+ "signer.py",
+ "syncronizer.py",
+ "task_producer.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join/utils",
+ "//pp_lite/rpc",
+ "//pp_lite/utils",
+ "@common_cityhash//:pkg",
+ "@common_gmpy2//:pkg",
+ "@common_pandas//:pkg",
+ "@common_rsa//:pkg",
+ ],
+)
+
+py_test(
+ name = "signer_test",
+ size = "small",
+ srcs = ["signer_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = ["client"],
+)
diff --git a/pp_lite/data_join/psi_rsa/client/data_joiner.py b/pp_lite/data_join/psi_rsa/client/data_joiner.py
new file mode 100644
index 000000000..dd9364c3c
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/client/data_joiner.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import List, Optional, Iterable
+
+from pp_lite.rpc.client import DataJoinClient
+from pp_lite.utils.decorators import time_log
+
+
+class DataJoiner:
+
+ def __init__(self, client: DataJoinClient, partition_batch_size: int = 10):
+ self._client = client
+ self._partition_batch_size = partition_batch_size
+
+ def _get_partition(self, partition_id: int) -> Iterable[Optional[List[str]]]:
+ partition_list = [partition_id]
+ if partition_id == -1:
+ part_num = self._client.get_partition_number().partition_num
+ partition_list = list(range(part_num))
+
+ for i in range(0, len(partition_list), self._partition_batch_size):
+ batch = partition_list[i:i + self._partition_batch_size]
+ for resp in self._client.get_signed_ids(partition_ids=batch):
+ yield resp.ids
+
+ @time_log('Joiner')
+ def join(self, signed_ids: List[str], partition_id: int = -1):
+ hash_table = set(signed_ids)
+ intersection_ids = []
+ # TODO: implement multiprocess consumer
+ for ids in self._get_partition(partition_id):
+ if ids is None:
+ continue
+ logging.info(f'[Joiner] {len(ids)} ids received from server')
+ inter = [i for i in ids if i in hash_table]
+ intersection_ids.extend(inter)
+ logging.info(f'[Joiner] {len(intersection_ids)} ids joined')
+ # remove duplicate elements
+ return list(set(intersection_ids))
diff --git a/pp_lite/data_join/psi_rsa/client/signer.py b/pp_lite/data_join/psi_rsa/client/signer.py
new file mode 100644
index 000000000..a91c7ac71
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/client/signer.py
@@ -0,0 +1,110 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import rsa
+import random
+import logging
+from typing import List, Tuple, Iterable, Optional
+from concurrent.futures import ThreadPoolExecutor
+
+from cityhash import CityHash64 # pylint: disable=no-name-in-module
+from gmpy2 import powmod, divm, mpz # pylint: disable=no-name-in-module
+
+from pp_lite.rpc.client import DataJoinClient
+from pp_lite.utils.decorators import time_log
+from pp_lite.data_join.utils.generators import make_ids_iterator_from_list
+
+
+class Signer:
+
+ def __init__(self, client: DataJoinClient, num_workers: Optional[int] = None):
+ self._client = client
+ self._public_key = self._get_public_key()
+ self._pool = None
+ if num_workers is not None:
+ self._pool = ThreadPoolExecutor(max_workers=num_workers)
+
+ def _get_public_key(self) -> rsa.PublicKey:
+ resp = self._client.get_public_key()
+ return rsa.PublicKey(int(resp.n), int(resp.e))
+
+ @staticmethod
+ def _blind(ids: List[str], public_key: rsa.PublicKey) -> Tuple[List[int], List[int]]:
+ """Blind raw id by random number
+ blind id by id * r^e % n, where r is the blind number, randomly sampled from (0, 2^256),
+ (e, n) is the rsa public key.
+ Args:
+ ids: list of raw id
+ public_key: rsa public key
+ Returns:
+ blinded id
+ """
+ blind_numbers = [random.SystemRandom().getrandbits(256) for i in ids]
+ hashed_ids = [CityHash64(i) for i in ids]
+ e = public_key.e
+ n = public_key.n
+ blinded_ids = [(powmod(r, e, n) * x) % n for r, x in zip(blind_numbers, hashed_ids)]
+ return blinded_ids, blind_numbers
+
+ @staticmethod
+ def _deblind(blind_signed_ids: List[int], blind_numbers: List[int], public_key: rsa.PublicKey) -> List[mpz]:
+ n = public_key.n
+ signed_ids = [divm(x, r, n) for x, r in zip(blind_signed_ids, blind_numbers)]
+ return signed_ids
+
+ @staticmethod
+    def _one_way_hash(ids: List[int]) -> List[str]:
+ hashed_ids = [hex(CityHash64(str(i)))[2:] for i in ids]
+ return hashed_ids
+
+ def _remote_sign(self, blinded_ids: List[int]):
+ blinded_ids = [str(i) for i in blinded_ids]
+ resp = self._client.sign(blinded_ids)
+ return [int(i) for i in resp.signed_ids]
+
+ def sign_batch(self, ids: List[str]) -> List[str]:
+ """Sign raw id by calling service from remote server
+ sign id by
+ 1. generate blind number r;
+ 2. blind raw id by r: id * r^e % n;
+ 3. calling blind sign service: id^d * r^d^e % n = id^d * r % n
+ 4. deblind blinded signed id by r: id^d % n
+ 5. hash signed id: hash(id^d%n)
+ Args:
+ ids: raw id
+ Returns:
+ signed ids
+ """
+ blinded_ids, blind_numbers = self._blind(ids, self._public_key)
+ blinded_signed_ids = self._remote_sign(blinded_ids)
+ signed_ids = self._deblind(blinded_signed_ids, blind_numbers, self._public_key)
+ hashed_ids = self._one_way_hash(signed_ids)
+ return hashed_ids
+
+ def sign_iterator(self, ids_iterator: Iterable[List[str]]):
+ if self._pool:
+ yield from self._pool.map(self.sign_batch, ids_iterator)
+ else:
+ for ids in ids_iterator:
+ yield self.sign_batch(ids)
+
+ @time_log('Signer')
+ def sign_list(self, ids: List[str], batch_size=4096):
+ ids_iterator = make_ids_iterator_from_list(ids, batch_size)
+ signed_ids = []
+ for sids in self.sign_iterator(ids_iterator=ids_iterator):
+ signed_ids.extend(sids)
+ logging.info(f'[Signer] {len(signed_ids)} ids signed')
+ return signed_ids
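+
+
+# A round-trip sketch of the blinding math used above (toy key, illustration only,
+# not used by the client): deblinding the blind-signed value recovers id^d % n.
+#
+#     import rsa
+#     from gmpy2 import powmod, divm
+#
+#     pub, priv = rsa.newkeys(512)
+#     x, r = 123456789, 987654321                        # hashed id and blinding factor
+#     blinded = (powmod(r, pub.e, pub.n) * x) % pub.n
+#     blind_signed = powmod(blinded, priv.d, priv.n)     # what the server computes
+#     assert divm(blind_signed, r, pub.n) == powmod(x, priv.d, priv.n)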
diff --git a/pp_lite/data_join/psi_rsa/client/signer_test.py b/pp_lite/data_join/psi_rsa/client/signer_test.py
new file mode 100644
index 000000000..06783a51d
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/client/signer_test.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch
+from pp_lite.proto import data_join_service_pb2
+from pp_lite.rpc.client import DataJoinClient
+from pp_lite.data_join.psi_rsa.client.signer import Signer
+import rsa
+from gmpy2 import powmod # pylint: disable=no-name-in-module
+
+
+class SignerTest(unittest.TestCase):
+
+ @patch('pp_lite.rpc.client.DataJoinClient.get_public_key')
+ def setUp(self, get_public_key) -> None:
+ get_public_key.return_value = data_join_service_pb2.PublicKeyResponse(n=str(9376987687101647609), e=str(65537))
+ self.private_key = rsa.PrivateKey(9376987687101647609, 65537, 332945516441048573, 15236990059, 615409451)
+ self.client = DataJoinClient()
+ self.signer = Signer(client=self.client)
+
+ def test_sign(self):
+ self.signer._client.sign = MagicMock( # pylint: disable=protected-access
+ side_effect=lambda x: data_join_service_pb2.SignResponse(
+ signed_ids=[str(powmod(int(i), self.private_key.d, self.private_key.n)) for i in x]))
+
+ signed_ids = self.signer.sign_batch(['1', '2', '3'])
+ correct_signed_ids = ['288f534080870918', '19ade65d522c7915', 'ab2fa2127da06b98']
+ self.assertEqual(signed_ids, correct_signed_ids)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_rsa/client/syncronizer.py b/pp_lite/data_join/psi_rsa/client/syncronizer.py
new file mode 100644
index 000000000..7a42be9bc
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/client/syncronizer.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Iterable
+
+from pp_lite.rpc.client import DataJoinClient
+from pp_lite.utils.decorators import retry_fn, time_log
+from pp_lite.data_join.utils.generators import make_ids_iterator_from_list
+
+
+class ResultSynchronizer:
+
+ def __init__(self, client: DataJoinClient, batch_size: int = 4096):
+ self._client = client
+ self._batch_size = batch_size
+
+ def sync_from_iterator(self, ids_iterator: Iterable[List[str]], partition_id: int = -1):
+ self._client.sync_data_join_result(ids_iterator, partition_id)
+
+ @time_log('Synchronizer')
+ @retry_fn(retry_times=3)
+ def sync(self, ids: List[str], partition_id: int = -1):
+ ids_iterator = make_ids_iterator_from_list(ids, self._batch_size)
+ self.sync_from_iterator(ids_iterator, partition_id)
diff --git a/pp_lite/data_join/psi_rsa/client/task_producer.py b/pp_lite/data_join/psi_rsa/client/task_producer.py
new file mode 100644
index 000000000..6dd204b77
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/client/task_producer.py
@@ -0,0 +1,96 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import pandas
+import logging
+from typing import List
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader
+
+from pp_lite.rpc.client import DataJoinClient
+from pp_lite.data_join.psi_rsa.client.data_joiner import DataJoiner
+from pp_lite.data_join.psi_rsa.client.signer import Signer
+from pp_lite.data_join.psi_rsa.client.syncronizer import ResultSynchronizer
+from pp_lite.utils.metrics import emit_counter
+
+
+class TaskProducer:
+
+ def __init__(self,
+ client: DataJoinClient,
+ reader: ExampleIdReader,
+ output_dir: str,
+ key_column: str,
+ batch_size: int,
+ num_sign_parallel: int = 5):
+ self._client = client
+ self._reader = reader
+ self._output_dir = output_dir
+ self._key_column = key_column
+ self._batch_size = batch_size
+ self._signer = Signer(client=self._client, num_workers=num_sign_parallel)
+ self._joiner = DataJoiner(client=self._client)
+ self._synchronizer = ResultSynchronizer(client=self._client)
+ self._create_dirs()
+
+ def run(self, partition_id: int):
+        logging.info(f'[TaskProducer] dealing with partition {partition_id} ...')
+ ids = self._read_ids(partition_id)
+ if len(ids) == 0:
+            logging.error('[DataReader] Input data is empty, exiting now.')
+            raise ValueError('[DataReader] Input data is empty, please check the input path')
+ logging.info(f'[DataReader] the input data count is {len(ids)}')
+ # sign
+ signed_ids = self._signer.sign_list(ids, self._batch_size)
+ signed_df = pandas.DataFrame({self._key_column: ids, 'sign': signed_ids})
+ signed_df.to_csv(self._get_signed_path(partition_id), index=False)
+ # join
+ joined_signed_ids = self._joiner.join(signed_ids=signed_ids, partition_id=partition_id)
+        signed2id = dict(zip(signed_ids, ids))
+        joined_ids = [signed2id[i] for i in joined_signed_ids]
+ # TODO (zhou.yi) use ExampleIdWriter to write file
+ joined_df = pandas.DataFrame({self._key_column: joined_ids})
+ joined_df.to_csv(self._get_joined_path(partition_id), index=False)
+ # synchronize
+ self._synchronizer.sync(ids=joined_ids, partition_id=partition_id)
+ # update audit info
+ emit_counter('Input data count', len(ids))
+ emit_counter('Joined data count', len(joined_ids))
+
+ def _read_ids(self, partition_id: int) -> List[str]:
+
+ if partition_id < 0:
+            # partition_id < 0 means the client and the server did not partition data with the same logic,
+            # so all client data is intersected with all server data.
+ return self._reader.read_all()
+ return self._reader.read(partition_id)
+
+ def _create_dirs(self):
+ os.makedirs(os.path.join(self._output_dir, 'signed'), exist_ok=True)
+ os.makedirs(os.path.join(self._output_dir, 'joined'), exist_ok=True)
+
+ def _get_signed_path(self, partition_id: int) -> str:
+ if partition_id < 0:
+ file_path = 'signed.csv'
+ else:
+ file_path = f'part-{str(partition_id).zfill(5)}-signed.csv'
+ return os.path.join(self._output_dir, 'signed', file_path)
+
+ def _get_joined_path(self, partition_id: int) -> str:
+ if partition_id < 0:
+ file_path = 'joined.csv'
+ else:
+ file_path = f'part-{str(partition_id).zfill(5)}-joined.csv'
+ return os.path.join(self._output_dir, 'joined', file_path)
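+
+
+# Output layout produced per partition (paths are relative to output_dir; the
+# partition number 3 is illustrative only):
+#   signed/part-00003-signed.csv   columns: <key_column>, sign
+#   joined/part-00003-joined.csv   columns: <key_column>
+# or signed/signed.csv and joined/joined.csv when partition_id < 0.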
diff --git a/pp_lite/data_join/psi_rsa/psi_client.py b/pp_lite/data_join/psi_rsa/psi_client.py
new file mode 100644
index 000000000..3b1a70e87
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/psi_client.py
@@ -0,0 +1,104 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import json
+import logging.config
+import os
+from os import getenv
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader
+from pp_lite.proto.common_pb2 import FileType
+
+from pp_lite.rpc.client import DataJoinClient
+from pp_lite.data_join.psi_rsa.client.task_producer import TaskProducer
+from pp_lite.utils.logging_config import logging_config, log_path
+from pp_lite.utils.metrics import get_audit_value, show_audit_info
+from pp_lite.utils.tools import get_partition_ids, print_named_dict
+
+
+def str_as_bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ if v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+# TODO(zhou.yi): change to CLI arguments
+def get_arguments():
+ arguments = {
+ 'input_dir': getenv('INPUT_DIR', '/app/workdir/input'),
+ 'output_dir': getenv('OUTPUT_DIR', '/app/workdir/output'),
+ 'key_column': getenv('KEY_COLUMN', 'raw_id'),
+ 'server_port': int(getenv('SERVER_PORT', '50051')),
+ 'batch_size': int(getenv('BATCH_SIZE', '4096')),
+ 'num_sign_parallel': int(getenv('NUM_SIGN_PARALLEL', '20')),
+ 'partitioned': str_as_bool(getenv('PARTITIONED', 'false')),
+ 'log_dir': getenv('LOG_DIR', '/app/workdir/log/'),
+ }
+ arguments['worker_rank'] = int(getenv('INDEX', '0'))
+ if getenv('NUM_WORKERS'):
+ arguments['num_workers'] = int(getenv('NUM_WORKERS'))
+ role = getenv('ROLE', '')
+ if getenv('CLUSTER_SPEC') and role != 'light_client':
+ cluster_spec = json.loads(getenv('CLUSTER_SPEC'))
+
+        # CLUSTER_SPEC is only honored in the cluster environment,
+        # so a CLUSTER_SPEC left in .env by a light client is ignored.
+ if 'clusterSpec' in cluster_spec:
+ arguments['num_workers'] = len(cluster_spec['clusterSpec']['Worker'])
+ return arguments
+
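+# Example environment for a light-client run (values are illustrative only; the
+# entrypoint that eventually calls run() is omitted here):
+#
+#   export INPUT_DIR=/app/workdir/input OUTPUT_DIR=/app/workdir/output
+#   export KEY_COLUMN=raw_id SERVER_PORT=50051 BATCH_SIZE=4096
+#   export PARTITIONED=true NUM_WORKERS=2 INDEX=0 ROLE=light_client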
+
+def _show_client_audit():
+ show_audit_info()
+ intersection_rate = format(get_audit_value('Joined data count') / get_audit_value('Input data count') * 100, '.2f')
+ logging.info('====================Result====================')
+ logging.info(f'Intersection rate {intersection_rate} %')
+    logging.info('Running log is located at workdir/log')
+    logging.info('Data join result is located at workdir/output/joined')
+ logging.info('==============================================')
+
+
+def run(args: dict):
+ if args.get('log_dir') is not None:
+ if not os.path.exists(args['log_dir']):
+ os.makedirs(args['log_dir'], exist_ok=True)
+ logging.config.dictConfig(logging_config(file_path=log_path(args['log_dir'])))
+ print_named_dict(name='Client Arguments', target_dict=args)
+ client = DataJoinClient(args['server_port'])
+ client.check_server_ready(timeout_seconds=60)
+ reader = ExampleIdReader(input_path=args['input_dir'], file_type=FileType.CSV, key_column=args['key_column'])
+ partition_list = [-1]
+ # If the client and the server use the same logic to partition, the partition can be intersected one by one.
+ # Otherwise, all client data intersect with all server data.
+ if args['partitioned']:
+ num_partitions = reader.num_partitions
+ partition_list = get_partition_ids(worker_rank=args['worker_rank'],
+ num_workers=args['num_workers'],
+ num_partitions=num_partitions)
+
+ task_producer = TaskProducer(client=client,
+ reader=reader,
+ output_dir=args['output_dir'],
+ key_column=args['key_column'],
+ batch_size=args['batch_size'],
+ num_sign_parallel=args['num_sign_parallel'])
+ for partition_id in partition_list:
+ task_producer.run(partition_id=partition_id)
+ _show_client_audit()
+ client.finish()
diff --git a/pp_lite/data_join/psi_rsa/psi_server.py b/pp_lite/data_join/psi_rsa/psi_server.py
new file mode 100644
index 000000000..66afc7363
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/psi_server.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging.config
+from os import getenv
+
+from pp_lite.data_join import envs
+
+from pp_lite.rpc.server import RpcServer
+from pp_lite.data_join.psi_rsa.server.data_join_servicer import DataJoinServiceServicer
+from pp_lite.data_join.psi_rsa.server.utils import load_private_rsa_key
+from pp_lite.utils.logging_config import logging_config, log_path
+from pp_lite.utils.tools import print_named_dict
+
+
+def get_arguments():
+ arguments = {
+ 'rsa_private_key_path': getenv('PRIVATE_KEY_PATH'),
+ 'input_dir': getenv('INPUT_DIR'),
+ 'output_dir': getenv('OUTPUT_DIR'),
+ 'signed_column': getenv('SIGNED_COLUMN', 'signed'),
+ 'key_column': getenv('KEY_COLUMN', 'raw_id'),
+ 'server_port': int(getenv('SERVER_PORT', '50051')),
+ 'batch_size': int(getenv('BATCH_SIZE', '4096')),
+ 'num_sign_parallel': int(getenv('NUM_SIGN_PARALLEL', '30'))
+ }
+ return arguments
+
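+# Example environment for the server (values are illustrative; PRIVATE_KEY_PATH,
+# INPUT_DIR and OUTPUT_DIR have no defaults and must be provided):
+#
+#   export PRIVATE_KEY_PATH=/app/workdir/private.pem
+#   export INPUT_DIR=/app/workdir/input OUTPUT_DIR=/app/workdir/output
+#   export SIGNED_COLUMN=signed KEY_COLUMN=raw_id SERVER_PORT=50051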
+
+def run(args: dict):
+ logging.config.dictConfig(logging_config(file_path=log_path(log_dir=envs.STORAGE_ROOT)))
+ print_named_dict(name='Server Arguments', target_dict=args)
+ private_key = load_private_rsa_key(args['rsa_private_key_path'])
+ servicer = DataJoinServiceServicer(private_key=private_key,
+ input_dir=args['input_dir'],
+ output_dir=args['output_dir'],
+ signed_column=args['signed_column'],
+ key_column=args['key_column'],
+ batch_size=args['batch_size'],
+ num_sign_parallel=args['num_sign_parallel'])
+ server = RpcServer(servicer=servicer, listen_port=args['server_port'])
+ server.start()
+ server.wait()
diff --git a/pp_lite/data_join/psi_rsa/server/BUILD.bazel b/pp_lite/data_join/psi_rsa/server/BUILD.bazel
new file mode 100644
index 000000000..cf64cd8fc
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/BUILD.bazel
@@ -0,0 +1,49 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+py_library(
+ name = "server",
+ srcs = [
+ "data_join_servicer.py",
+ "partition_reader.py",
+ "partition_writer.py",
+ "signer.py",
+ "utils.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join/utils",
+ "//pp_lite/proto:py_grpc",
+ "//pp_lite/proto:py_proto",
+ "//pp_lite/rpc",
+ "//pp_lite/utils",
+ "@common_fsspec//:pkg",
+ "@common_gmpy2//:pkg",
+ "@common_pandas//:pkg",
+ "@common_pyarrow//:pkg", # keep
+ "@common_rsa//:pkg",
+ ],
+)
+
+py_test(
+ name = "partition_writer_test",
+ size = "small",
+ srcs = ["partition_writer_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":server"],
+)
+
+py_test(
+ name = "partition_reader_test",
+ size = "small",
+ srcs = ["partition_reader_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":server"],
+)
+
+py_test(
+ name = "signer_test",
+ size = "small",
+ srcs = ["signer_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":server"],
+)
diff --git a/pp_lite/data_join/psi_rsa/server/data_join_servicer.py b/pp_lite/data_join/psi_rsa/server/data_join_servicer.py
new file mode 100644
index 000000000..57fecdea3
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/data_join_servicer.py
@@ -0,0 +1,115 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import logging
+from typing import Callable
+from google.protobuf import empty_pb2
+import rsa
+
+from pp_lite.rpc.server import IServicer
+from pp_lite.proto.common_pb2 import Pong, Ping
+from pp_lite.proto import data_join_service_pb2, data_join_service_pb2_grpc
+from pp_lite.data_join.psi_rsa.server.partition_writer import RsaServerPartitionWriter
+from pp_lite.data_join.psi_rsa.server.signer import RsaDataJoinSigner
+from pp_lite.data_join.psi_rsa.server.partition_reader import RsaServerPartitionReader
+from pp_lite.utils.metrics import show_audit_info, emit_counter
+from pp_lite.utils import metric_collector
+
+
+class DataJoinServiceServicer(data_join_service_pb2_grpc.DataJoinServiceServicer, IServicer):
+
+ def __init__(self,
+ private_key: rsa.PrivateKey,
+ input_dir: str,
+ output_dir: str,
+ signed_column: str,
+ key_column: str,
+ batch_size: int = 4096,
+ num_sign_parallel: int = 1):
+ self._writer = RsaServerPartitionWriter(output_dir=output_dir, key_column=key_column)
+ self._stop_hook = None
+ self._signer = RsaDataJoinSigner(private_key=private_key, num_workers=num_sign_parallel)
+ self._partition_reader = RsaServerPartitionReader(input_dir=input_dir,
+ signed_column=signed_column,
+ batch_size=batch_size)
+
+ def GetPartitionNumber(self, request, context):
+ emit_counter('get_partition_num', 1)
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.get_partition_num', 1, {'role': 'server'})
+ partition_num = self._partition_reader.get_partition_num()
+ logging.info(f'Receive request \'GetPartitionNum\' from client, partition num is {partition_num}')
+ return data_join_service_pb2.GetPartitionNumberResponse(partition_num=partition_num)
+
+ def GetSignedIds(self, request: data_join_service_pb2.GetSignedIdsRequest, context):
+ emit_counter('get_partition', 1)
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.get_partition', 1, {'role': 'server'})
+ partition_ids = request.partition_ids
+ tip = 'without partition' if not partition_ids else f'partition {partition_ids[0]} ~ {partition_ids[-1]}'
+ logging.info(f'Receive request \'GetPartition\' from client, {tip}')
+ total_num = 0
+ ids_generator = self._partition_reader.get_ids_generator(partition_ids)
+ for ids in ids_generator:
+ emit_counter('send_ids', len(ids))
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.send_ids', len(ids), {'role': 'server'})
+ total_num = total_num + len(ids)
+ logging.info(f'Sending data {tip}, sent {total_num} ids now')
+ yield data_join_service_pb2.GetSignedIdsResponse(ids=ids)
+
+ def GetPublicKey(self, request, context):
+ emit_counter('get_public_key', 1)
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.get_public_key', 1, {'role': 'server'})
+ logging.info('Receive request \'GetPublicKey\' from client')
+ public_key = self._signer.public_key
+ return data_join_service_pb2.PublicKeyResponse(e=str(public_key.e), n=str(public_key.n))
+
+ def Sign(self, request: data_join_service_pb2.SignRequest, context):
+ ids = [int(i) for i in request.ids]
+ emit_counter('sign_time', 1)
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.sign_time', 1, {'role': 'server'})
+ emit_counter('sign_ids', len(ids))
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.sign_ids', len(ids), {'role': 'server'})
+        logging.info(f'Receive request \'Sign\' from client, the number of ids to sign is {len(ids)}.')
+ signed_ids = self._signer.sign_ids(ids)
+ signed_ids = [str(i) for i in signed_ids]
+ return data_join_service_pb2.SignResponse(signed_ids=signed_ids)
+
+ def SyncDataJoinResult(self, request_iterator, context):
+ emit_counter('sync_time', 1)
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.sync_time', 1, {'role': 'server'})
+ logging.info('Receive request \'Synchronize\' from client')
+
+ def data_generator():
+ for request in request_iterator:
+ yield request.partition_id, request.ids
+
+ total_num = self._writer.write_data_join_result(data_generator())
+ emit_counter('sync_ids', total_num)
+ metric_collector.emit_counter('dataset.data_join.rsa_psi.sync_ids', total_num, {'role': 'server'})
+ return data_join_service_pb2.SyncDataJoinResultResponse(succeeded=True)
+
+ def Finish(self, request, context):
+ self._signer.stop()
+ show_audit_info()
+ self._stop_hook()
+ return empty_pb2.Empty()
+
+ def HealthCheck(self, request: Ping, context):
+ logging.info('Receive request \'HealthCheck\' from client')
+ return Pong(message=request.message)
+
+ def register(self, server: grpc.Server, stop_hook: Callable[[], None]):
+ self._stop_hook = stop_hook
+ data_join_service_pb2_grpc.add_DataJoinServiceServicer_to_server(self, server)
diff --git a/pp_lite/data_join/psi_rsa/server/partition_reader.py b/pp_lite/data_join/psi_rsa/server/partition_reader.py
new file mode 100644
index 000000000..3fd4b76f6
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/partition_reader.py
@@ -0,0 +1,157 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import threading
+import queue
+from time import sleep
+from typing import Iterator, Optional, Generator, List
+from pathlib import Path
+
+import pandas
+import fsspec
+
+from pp_lite.utils.decorators import retry_fn
+from pp_lite.data_join.utils.generators import make_ids_iterator_from_list
+
+
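+# Partition files are expected to be named like 'part-00003-<uuid>-c000.csv' (the
+# concrete suffix is an assumption); _get_part_id returns the first purely decimal
+# dash-separated component, e.g. 3 for the name above, or None when there is none.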
+def _get_part_id(filename: str) -> Optional[int]:
+ """extract partition id from filename"""
+ comp = filename.split('-')
+ for c in comp:
+ if c.isdecimal():
+ return int(c)
+ return None
+
+
+def _filter_files(files: Iterator[str], partition_ids: List[int]) -> List[str]:
+    """
+    Args:
+        files(Iterator): iterator of filenames under the input dir
+        partition_ids(List[int]): target partition numbers
+    Returns:
+        the files whose partition number is in partition_ids
+    """
+ file_list = []
+ for file in files:
+ if file.startswith('part-'):
+ if _get_part_id(file) in partition_ids:
+ file_list.append(file)
+ return file_list
+
+
+# TODO (zhou.yi): use ExampleIdReader to read
+class RsaServerPartitionReader:
+
+ def __init__(self, input_dir: str, signed_column: str, batch_size: int):
+ self._input_dir = input_dir
+ self._signed_column = signed_column
+ self._batch_size = batch_size
+ self._file_system: fsspec.AbstractFileSystem = fsspec.get_mapper(self._input_dir).fs
+ self._partition_num = self._set_partition_num()
+
+ def _get_files_under(self, path: str) -> List[str]:
+ return [
+ Path(file.get('name')).name
+ for file in self._file_system.ls(path, detail=True)
+ if file.get('type') == 'file'
+ ]
+
+ @staticmethod
+ def _is_valid_file(filename: str) -> bool:
+ return filename.startswith('part-')
+
+ def _set_partition_num(self) -> int:
+ files = self._get_files_under(self._input_dir)
+ return len(list(filter(self._is_valid_file, files)))
+
+ def get_partition_num(self) -> int:
+ return self._partition_num
+
+ def _get_files_by_partition_id(self, partition_ids: List[int]) -> List[str]:
+ files = self._get_files_under(self._input_dir)
+ if not partition_ids:
+ files = filter(self._is_valid_file, files)
+ else:
+ files = _filter_files(files=files, partition_ids=partition_ids)
+ files = [os.path.join(self._input_dir, file) for file in files]
+ return files
+
+ def get_ids_generator(self, partition_ids: List[int]) -> Generator:
+
+ files = self._get_files_by_partition_id(partition_ids)
+ with _BufferedReader(files, self._signed_column) as reader:
+ for keys in reader.read_keys():
+ for ids in make_ids_iterator_from_list(keys, self._batch_size):
+ yield ids
+
+
+class _BufferedReader(object):
+ """A reader to get keys from files.
+
+ It uses producer-consumer pattern to speed up.
+ """
+ _FILE_CAPACITY = 5
+ _WAIT_TIME_SECONDS = 1
+
+ def __init__(self, file_list: List[str], key_column: str):
+ self._read_thread = threading.Thread(target=self._get_keys_from_files, name='Buffered Reader', daemon=True)
+ self._data_queue = queue.Queue(maxsize=self._FILE_CAPACITY)
+ self._file_list = file_list
+ self._key_column = key_column
+ self._exception = None
+ self._finish = False
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not exc_type:
+ self._read_thread.join()
+
+ def read_keys(self):
+ self._read_thread.start()
+ while True:
+ if self._exception:
+ logging.exception(f'read keys with exception: {str(self._exception)}')
+ raise self._exception
+ # NOTE: this is not thread-safe
+ if self._data_queue.empty():
+ if self._finish:
+ break
+ sleep(self._WAIT_TIME_SECONDS)
+ else:
+ yield self._data_queue.get()
+
+ def _get_keys_from_files(self):
+ try:
+ # Reads keys per file
+ for file in self._file_list:
+ df = read_csv(file)
+ keys = df[self._key_column].astype('str')
+ self._data_queue.put(keys)
+ self._finish = True
+ # pylint: disable=broad-except
+ except Exception as e:
+ self._exception = e
+
+
+# Reading from HDFS may fail transiently, so retry a few times.
+@retry_fn(retry_times=3)
+def read_csv(file_path: str) -> pandas.DataFrame:
+ with fsspec.open(file_path, mode='r') as f:
+ logging.debug(f'Read file {file_path}...')
+ return pandas.read_csv(f)
diff --git a/pp_lite/data_join/psi_rsa/server/partition_reader_test.py b/pp_lite/data_join/psi_rsa/server/partition_reader_test.py
new file mode 100644
index 000000000..312276824
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/partition_reader_test.py
@@ -0,0 +1,57 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=protected-access
+import unittest
+import tempfile
+from shutil import rmtree
+
+from pp_lite.data_join.psi_rsa.server.partition_reader import RsaServerPartitionReader
+
+
+class RsaServerPartitionReaderTest(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls.input_dir: str = tempfile.mkdtemp()
+ cls.parts = []
+ for _ in range(5):
+ _, path = tempfile.mkstemp(prefix='part-', dir=cls.input_dir)
+ with open(path, mode='w', encoding='utf-8') as f:
+ f.write('signed_id\n1')
+ cls.helper = RsaServerPartitionReader(input_dir=cls.input_dir, signed_column='signed_id', batch_size=32)
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ rmtree(cls.input_dir)
+
+ def test_get_files_under(self):
+ files = self.helper._get_files_under(self.input_dir)
+ self.assertEqual(5, len(files))
+
+ def test_get_partition_num(self):
+ num = self.helper.get_partition_num()
+ self.assertEqual(5, num)
+
+ def test_get_files_by_partition_id(self):
+ ids = [id.partition('part-')[-1] for id in self.parts]
+ self.assertEqual(5, len(self.helper._get_files_by_partition_id(ids)))
+ self.assertEqual(5, len(self.helper._get_files_by_partition_id(None)))
+ self.assertEqual(3, len(self.helper._get_files_by_partition_id(ids)[:3]))
+ self.assertEqual(5, len(self.helper._get_files_by_partition_id([])))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_rsa/server/partition_writer.py b/pp_lite/data_join/psi_rsa/server/partition_writer.py
new file mode 100644
index 000000000..5df90da81
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/partition_writer.py
@@ -0,0 +1,58 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+from typing import Iterator, Tuple, List
+
+import fsspec
+
+
+class RsaServerPartitionWriter:
+
+ def __init__(self, output_dir: str, key_column: str):
+ self._output_dir = output_dir
+ self._key_column = key_column
+ self._file_system: fsspec.AbstractFileSystem = fsspec.get_mapper(self._output_dir).fs
+
+ def _get_output_filename(self, partition_id: int = -1) -> str:
+ if partition_id is None or partition_id < 0:
+ return os.path.join(self._output_dir, 'joined', 'output.csv')
+ return os.path.join(self._output_dir, 'joined', f'output_{partition_id}.csv')
+
+ # TODO(zhou.yi): refactor this function by ExampleIdWriter
+ def write_data_join_result(self, data_iterator: Iterator[Tuple[int, List[str]]]):
+ total_num = 0
+
+ partition_id, ids = next(data_iterator, (None, None))
+ if ids is None:
+ logging.warning('no joined ids received from client!')
+ ids = []
+ filename = self._get_output_filename(partition_id)
+ if self._file_system.exists(filename):
+ self._file_system.rm(filename)
+ if not self._file_system.exists(os.path.dirname(filename)):
+ self._file_system.makedirs(os.path.dirname(filename))
+ with fsspec.open(filename, mode='w', encoding='utf-8') as f:
+ f.write(self._key_column + '\n')
+ f.write('\n'.join(ids) + '\n')
+ tip = 'without partition' if partition_id == -1 else f'partition {partition_id}'
+ total_num = total_num + len(ids)
+ logging.info(f'Receive data {tip}, Synchronize {total_num} ids now')
+ for partition_id, ids in data_iterator:
+ f.write('\n'.join(ids) + '\n')
+ total_num = total_num + len(ids)
+ logging.info(f'Receive data {tip}, Synchronize {total_num} ids now')
+ return total_num
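+
+
+# Resulting file sketch for an un-partitioned join (toy ids, matching the unit
+# test): <output_dir>/joined/output.csv contains the key column header followed
+# by one joined id per line, e.g. 'raw_id\n1\n2\n3\n'.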
diff --git a/pp_lite/data_join/psi_rsa/server/partition_writer_test.py b/pp_lite/data_join/psi_rsa/server/partition_writer_test.py
new file mode 100644
index 000000000..7585de9b1
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/partition_writer_test.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import tempfile
+from shutil import rmtree
+
+from pp_lite.data_join.psi_rsa.server.partition_writer import RsaServerPartitionWriter
+
+
+class RsaServerPartitionWriterTest(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self.input_dir: str = tempfile.mkdtemp()
+ self.output_dir: str = tempfile.mkdtemp()
+ self.parts = []
+ for _ in range(5):
+ _, path = tempfile.mkstemp(prefix='part-', dir=self.input_dir)
+ with open(path, mode='w', encoding='utf-8') as f:
+ f.write('raw_id\n1')
+ self.writer = RsaServerPartitionWriter(output_dir=self.output_dir, key_column='raw_id')
+
+ def tearDown(self) -> None:
+ rmtree(self.input_dir)
+ rmtree(self.output_dir)
+
+ def test_write_data_join_result(self):
+ self.writer.write_data_join_result(iter([(-1, ['1', '2', '3']), (-1, ['1', '2', '3'])]))
+ with open(f'{self.output_dir}/joined/output.csv', mode='r', encoding='utf-8') as f:
+ self.assertEqual('raw_id\n1\n2\n3\n1\n2\n3\n', f.read())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_rsa/server/signer.py b/pp_lite/data_join/psi_rsa/server/signer.py
new file mode 100644
index 000000000..9877d5ffa
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/signer.py
@@ -0,0 +1,58 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import multiprocessing
+from concurrent.futures import ProcessPoolExecutor
+from typing import List
+
+import rsa
+from gmpy2 import powmod # pylint: disable=no-name-in-module
+
+
+class RsaDataJoinSigner:
+
+ def __init__(self, private_key: rsa.PrivateKey, num_workers: int = 1):
+ self._private_key = private_key
+ self._public_key = rsa.PublicKey(self._private_key.n, self._private_key.e)
+ mp_context = multiprocessing.get_context('spawn')
+ self._pool = ProcessPoolExecutor(max_workers=num_workers, mp_context=mp_context)
+
+ @property
+ def private_key(self) -> rsa.PrivateKey:
+ return self._private_key
+
+ @property
+ def public_key(self) -> rsa.PublicKey:
+ return self._public_key
+
+ @staticmethod
+ def _sign_ids(ids: List[int], private_key: rsa.PrivateKey) -> List[int]:
+ return [powmod(i, private_key.d, private_key.n) for i in ids]
+
+ def sign_ids(self, ids: List[int]) -> List[int]:
+
+ future = self._pool.submit(self._sign_ids, ids, self._private_key)
+ return future.result()
+
+ def stop(self):
+        # Worker processes that have not exited yet would block the server process from exiting,
+        # so each subprocess is terminated explicitly.
+ for pid, process in self._pool._processes.items(): # pylint:disable=protected-access
+ process.terminate()
+ logging.info(f'send SIGTERM to process {pid}!')
+ self._pool.shutdown(wait=True)
+ self._pool = None
+ logging.info('data join signer stopped')
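+
+
+# Minimal usage sketch with the toy key from signer_test.py (illustration only):
+#
+#     import rsa
+#     signer = RsaDataJoinSigner(
+#         rsa.PrivateKey(9376987687101647609, 65537, 332945516441048573, 15236990059, 615409451))
+#     assert signer.sign_ids([2]) == [5558008899394433345]
+#     signer.stop()   # terminate the worker processes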
diff --git a/pp_lite/data_join/psi_rsa/server/signer_test.py b/pp_lite/data_join/psi_rsa/server/signer_test.py
new file mode 100644
index 000000000..4caad436b
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/signer_test.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import rsa
+
+from pp_lite.data_join.psi_rsa.server.signer import RsaDataJoinSigner
+
+
+class RsaDataJoinSignerTest(unittest.TestCase):
+
+ def setUp(self):
+ self._signer = RsaDataJoinSigner(
+ rsa.PrivateKey(9376987687101647609, 65537, 332945516441048573, 15236990059, 615409451))
+
+ def test_sign(self):
+ self.assertListEqual(self._signer.sign_ids([2, 3, 4]),
+ [5558008899394433345, 4817922342110581069, 672854883936409540])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/psi_rsa/server/utils.py b/pp_lite/data_join/psi_rsa/server/utils.py
new file mode 100644
index 000000000..73c86750d
--- /dev/null
+++ b/pp_lite/data_join/psi_rsa/server/utils.py
@@ -0,0 +1,23 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import rsa
+import fsspec
+
+
+def load_private_rsa_key(private_key_path: str) -> rsa.PrivateKey:
+ with fsspec.open(private_key_path, mode='rb') as f:
+ private_key = rsa.PrivateKey.load_pkcs1(f.read())
+ return private_key
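+
+
+# The key file is expected to be a PKCS#1 PEM; a sketch of generating one with the
+# same rsa library (the key size is an assumption):
+#
+#     import rsa
+#     _, private_key = rsa.newkeys(2048)
+#     with open('/app/workdir/private.pem', 'wb') as f:
+#         f.write(private_key.save_pkcs1())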
diff --git a/pp_lite/data_join/utils/BUILD.bazel b/pp_lite/data_join/utils/BUILD.bazel
new file mode 100644
index 000000000..dcc461742
--- /dev/null
+++ b/pp_lite/data_join/utils/BUILD.bazel
@@ -0,0 +1,52 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+py_library(
+ name = "utils",
+ srcs = [
+ "example_id_reader.py",
+ "example_id_writer.py",
+ "generators.py",
+ "partitioner.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/proto:py_proto",
+ "@common_cityhash//:pkg",
+ "@common_fsspec//:pkg",
+ "@common_pandas//:pkg",
+ "@common_pyarrow//:pkg", # keep
+ ],
+)
+
+py_test(
+ name = "example_id_reader_test",
+ size = "small",
+ srcs = ["example_id_reader_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":utils",
+ "//pp_lite/testing",
+ ],
+)
+
+py_test(
+ name = "example_id_writer_test",
+ size = "small",
+ srcs = ["example_id_writer_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":utils",
+ "//pp_lite/testing",
+ ],
+)
+
+py_test(
+ name = "partitioner_test",
+ size = "medium",
+ srcs = ["partitioner_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":utils",
+ "//pp_lite/data_join/psi_ot",
+ ],
+)
diff --git a/pp_lite/data_join/utils/example_id_reader.py b/pp_lite/data_join/utils/example_id_reader.py
new file mode 100644
index 000000000..ca2bee7ee
--- /dev/null
+++ b/pp_lite/data_join/utils/example_id_reader.py
@@ -0,0 +1,112 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import fsspec
+import pandas as pd
+from typing import List, Iterator
+from pp_lite.proto.common_pb2 import FileType
+
+CHUNK_SIZE = 16 * 1024 * 1024 * 1024
+
+
+class PartitionInfo:
+
+ def __init__(self, input_path: str):
+ self._input_path = input_path
+ self._num_partitions = None
+ self._files = None
+ self._fs = fsspec.get_mapper(input_path).fs
+
+ @staticmethod
+ def _is_valid_file(filename: str) -> bool:
+ return os.path.split(filename)[1].startswith('part-')
+
+ def _list_files(self):
+ if self._files is None:
+ files = [file['name'] for file in self._fs.listdir(self._input_path)]
+ self._files = list(filter(self._is_valid_file, files))
+ return self._files
+
+ @property
+ def num_partitions(self) -> int:
+ if self._num_partitions is None:
+ self._num_partitions = len(self._list_files())
+ return self._num_partitions
+
+ def get_all_files(self) -> List[str]:
+ if self._fs.isfile(self._input_path):
+ return [self._input_path]
+ return [file['name'] for file in self._fs.listdir(self._input_path)]
+
+ def get_files(self, partition_id: int) -> List[str]:
+ """return file given partition id"""
+ files = []
+ try:
+ for file in self._list_files():
+ comp_list = os.path.split(file)[1].split('-')
+ assert len(comp_list) > 1, f'split file err, file is {file}'
+ comp = comp_list[1] # as: part-04988-24de412a-0741-4157-bf0f-2e5dc4ebe2d5-c000.csv
+ if comp.isdigit():
+ if int(comp) == partition_id:
+ files.append(file)
+ except RuntimeError:
+            logging.warning(f'[example_id_reader] failed to get files for partition {partition_id}.')
+ return files
+
+
+class ExampleIdReader:
+
+ def __init__(self, input_path: str, file_type: FileType, key_column: str):
+ self._file_type = file_type
+ self._input_path = input_path
+ self._key_column = key_column
+ self._partition_info = PartitionInfo(input_path)
+ self._fs = fsspec.get_mapper(input_path).fs
+
+ @property
+ def num_partitions(self) -> int:
+ return self._partition_info.num_partitions
+
+ def _iter_data(self, filename) -> Iterator[pd.DataFrame]:
+ if self._file_type == FileType.CSV:
+ with self._fs.open(filename, 'r') as fin:
+ df = pd.read_csv(fin, chunksize=CHUNK_SIZE)
+ for chunk in df:
+ yield chunk
+ else:
+ raise NotImplementedError('tfrecord is not supported')
+
+ def _data_iterator(self, partition_id: int) -> Iterator[pd.DataFrame]:
+ assert partition_id < self.num_partitions
+ files = self._partition_info.get_files(partition_id)
+ for file in files:
+ iterator = self._iter_data(filename=file)
+ for data in iterator:
+ yield data
+
+ def read(self, partition_id: int) -> List:
+ values = []
+ for data in self._data_iterator(partition_id):
+ values.extend(data[self._key_column].astype('str').to_list())
+ return values
+
+ def read_all(self) -> List[str]:
+ ids = []
+ for filename in self._partition_info.get_all_files():
+ for part in self._iter_data(filename):
+ ids.extend(part[self._key_column].astype('str').to_list())
+ return ids
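+
+
+# Minimal usage sketch (mirrors psi_client.py; the path is illustrative only):
+#
+#     reader = ExampleIdReader(input_path='/app/workdir/input', file_type=FileType.CSV,
+#                              key_column='raw_id')
+#     all_ids = reader.read_all()            # or reader.read(partition_id=0)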
diff --git a/pp_lite/data_join/utils/example_id_reader_test.py b/pp_lite/data_join/utils/example_id_reader_test.py
new file mode 100644
index 000000000..9d2f820c9
--- /dev/null
+++ b/pp_lite/data_join/utils/example_id_reader_test.py
@@ -0,0 +1,73 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+import tempfile
+from pp_lite.proto.common_pb2 import FileType
+from pp_lite.data_join.utils.example_id_reader import ExampleIdReader, PartitionInfo
+from pp_lite.testing.make_data import _make_fake_data
+
+
+class PartitionInfoTest(unittest.TestCase):
+
+ def test_num_partitions(self):
+ with tempfile.TemporaryDirectory() as input_dir:
+ _make_fake_data(input_dir, num_partitions=10, line_num=10)
+ partition_reader = PartitionInfo(input_dir)
+ self.assertEqual(partition_reader.num_partitions, 10)
+
+ def test_get_file(self):
+ with tempfile.TemporaryDirectory() as input_dir:
+ _make_fake_data(input_dir, num_partitions=10, line_num=10, partitioned=True, spark=False)
+ partition_reader = PartitionInfo(input_dir)
+ files = partition_reader.get_files(partition_id=1)
+ self.assertEqual(files, [os.path.join(input_dir, 'part-1')])
+
+ def test_get_file_spark(self):
+ with tempfile.TemporaryDirectory() as input_dir:
+ _make_fake_data(input_dir, num_partitions=10, line_num=10, partitioned=True, spark=True)
+ partition_reader = PartitionInfo(input_dir)
+ files = partition_reader.get_files(partition_id=1)
+ files = [files[0].split('-')[0] + '-' + files[0].split('-')[1]]
+ self.assertEqual(files, [os.path.join(input_dir, 'part-1')])
+
+ def test_get_all_files(self):
+ with tempfile.TemporaryDirectory() as input_dir:
+ _make_fake_data(input_dir, num_partitions=10, line_num=10, partitioned=False)
+ partition_reader = PartitionInfo(input_dir)
+ files = partition_reader.get_all_files()
+ self.assertEqual(sorted(files), [f'{input_dir}/abcd-{str(i)}' for i in range(10)])
+
+
+class ExampleIdReaderTest(unittest.TestCase):
+
+ def test_read(self):
+ with tempfile.TemporaryDirectory() as input_dir:
+ _make_fake_data(input_dir, num_partitions=10, line_num=10)
+ reader = ExampleIdReader(input_dir, FileType.CSV, key_column='part_id')
+ values = reader.read(partition_id=1)
+ self.assertEqual(values, ['1'] * 10)
+
+ def test_read_all(self):
+ with tempfile.TemporaryDirectory() as input_dir:
+ _make_fake_data(input_dir, num_partitions=10, line_num=10, partitioned=False)
+ reader = ExampleIdReader(input_dir, FileType.CSV, key_column='part_id')
+ values = reader.read_all()
+ self.assertEqual(len(values), 100)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/utils/example_id_writer.py b/pp_lite/data_join/utils/example_id_writer.py
new file mode 100644
index 000000000..6b0507242
--- /dev/null
+++ b/pp_lite/data_join/utils/example_id_writer.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import fsspec
+import logging
+import pandas as pd
+from typing import List
+
+
+class ExampleIdWriter:
+
+ def __init__(self, output_path: str, key_column: str):
+ self._output_path = output_path
+ self._key_column = key_column
+ self._fs = fsspec.get_mapper(output_path).fs
+
+ def write(self, partition_id: int, ids: List[str]):
+ if not self._fs.exists(self._output_path):
+ self._fs.makedirs(self._output_path, exist_ok=True)
+ filename = os.path.join(self._output_path, f'partition_{partition_id}')
+ logging.debug(f'[ExampleIdWriter] start writing {len(ids)} ids for partition {partition_id} to {filename}')
+ with self._fs.open(filename, 'w') as f:
+ df = pd.DataFrame(data={self._key_column: ids})
+ df.to_csv(f, index=False)
+ logging.debug(f'[ExampleIdWriter] finish writing for partition {partition_id}')
+
+ def combine(self, num_partitions: int):
+ if not os.path.isfile(os.path.join(self._output_path, 'partition_0')):
+            logging.warning('[ExampleIdWriter] combine failed: partition_0 file not found')
+ return
+ self._fs.copy(os.path.join(self._output_path, 'partition_0'), os.path.join(self._output_path, 'output.csv'))
+        for partition in range(1, num_partitions):
+            with self._fs.open(os.path.join(self._output_path, 'output.csv'), 'ab') as output_file:
+                with self._fs.open(os.path.join(self._output_path, f'partition_{partition}')) as part_file:
+                    # skip the header line of each partition file before appending
+                    part_file.readline()
+                    output_file.write(part_file.read())
+
+ def _success_tag(self, partition_id: int) -> str:
+ return os.path.join(self._output_path, f'{partition_id:04}._SUCCESS')
+
+ def write_success_tag(self, partition_id: int):
+ self._fs.touch(self._success_tag(partition_id))
+ logging.debug(f'[ExampleIdWriter] write success tag for partition {partition_id}')
+
+ def success_tag_exists(self, partition_id: int) -> bool:
+ return self._fs.exists(self._success_tag(partition_id))
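+
+
+# combine() keeps the header of partition_0 and appends the remaining partition
+# files without their headers; e.g. after write(0, ['a']) and write(1, ['b']),
+# combine(2) produces an output.csv containing 'raw_id\na\nb\n' (as exercised by
+# the unit test).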
diff --git a/pp_lite/data_join/utils/example_id_writer_test.py b/pp_lite/data_join/utils/example_id_writer_test.py
new file mode 100644
index 000000000..a0df306e4
--- /dev/null
+++ b/pp_lite/data_join/utils/example_id_writer_test.py
@@ -0,0 +1,58 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+import tempfile
+from pathlib import Path
+from pp_lite.data_join.utils.example_id_writer import ExampleIdWriter
+
+
+class ExampleIdWriterTest(unittest.TestCase):
+
+ def test_write(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ writer = ExampleIdWriter(temp_dir, key_column='raw_id')
+ writer.write(partition_id=0, ids=['a', 'b'])
+ with open(os.path.join(temp_dir, 'partition_0'), encoding='utf-8') as f:
+ content = f.read()
+ self.assertEqual(content, 'raw_id\na\nb\n')
+
+ def test_write_success_tag(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ writer = ExampleIdWriter(temp_dir, key_column='raw_id')
+ writer.write_success_tag(partition_id=1)
+ self.assertTrue(os.path.exists(os.path.join(temp_dir, '0001._SUCCESS')))
+
+ def test_success_tag_exists(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ writer = ExampleIdWriter(temp_dir, key_column='raw_id')
+ self.assertFalse(writer.success_tag_exists(1))
+ Path(os.path.join(temp_dir, '0001._SUCCESS')).touch()
+ self.assertTrue(writer.success_tag_exists(1))
+
+ def test_combine(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ writer = ExampleIdWriter(temp_dir, key_column='raw_id')
+ writer.write(partition_id=0, ids=[1, 2])
+ writer.write(partition_id=1, ids=[3, 4])
+ writer.write(partition_id=2, ids=[5, 6])
+ writer.combine(3)
+ with open(os.path.join(temp_dir, 'output.csv'), 'r', encoding='utf-8') as f:
+ self.assertEqual(f.read(), 'raw_id\n1\n2\n3\n4\n5\n6\n')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/data_join/utils/generators.py b/pp_lite/data_join/utils/generators.py
new file mode 100644
index 000000000..572a9ea72
--- /dev/null
+++ b/pp_lite/data_join/utils/generators.py
@@ -0,0 +1,23 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+
+
+def make_ids_iterator_from_list(ids: List[str], batch_size=4096):
+ num_parts = (len(ids) + batch_size - 1) // batch_size
+ for part_id in range(num_parts):
+ id_part = ids[part_id * batch_size:(part_id + 1) * batch_size]
+ yield id_part
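+
+
+# For example, list(make_ids_iterator_from_list(['a', 'b', 'c'], batch_size=2))
+# yields [['a', 'b'], ['c']].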
diff --git a/pp_lite/data_join/utils/partitioner.py b/pp_lite/data_join/utils/partitioner.py
new file mode 100644
index 000000000..d838949ed
--- /dev/null
+++ b/pp_lite/data_join/utils/partitioner.py
@@ -0,0 +1,137 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import csv
+import fcntl
+import shutil
+import threading
+import logging
+import pyarrow as pa
+import pyarrow.csv as _csv
+from cityhash import CityHash64 # pylint: disable=no-name-in-module
+from queue import Queue, Empty
+from concurrent.futures import ThreadPoolExecutor
+
+
+def get_partition_path(output_path: str, partition_id: int):
+ return os.path.join(output_path, f'part-{partition_id}')
+
+
+def read_ids(input_path: str, key_column: str, block_size: int, num_partitions: int, queue: Queue):
+ t = threading.current_thread()
+ read_options = _csv.ReadOptions(block_size=block_size)
+ with _csv.open_csv(input_path, read_options=read_options) as reader:
+ for chunk in reader:
+ if chunk is None:
+ break
+ raw_df = chunk.to_pandas()
+ raw_df[key_column] = raw_df[key_column].astype('str')
+ raw_df['partition_id'] = [CityHash64(i) % num_partitions for i in raw_df[key_column]]
+ groups = raw_df.groupby('partition_id')
+ for group in groups:
+ partition_id, data = group
+ data.drop(columns=['partition_id'], inplace=True)
+ table = pa.Table.from_pandas(data, preserve_index=False)
+ group = (partition_id, table)
+ queue.put(group)
+ logging.info(f'[Reader]: Put {table.num_rows} ids with partition id {partition_id} into queue of size '
+ f'{queue.qsize()} ------ Thread_id: {t.ident}')
+
+
+def write_partitioned_ids(output_path: str, queue: Queue):
+ try:
+ while True:
+ t = threading.current_thread()
+ partition_id, table = queue.get(timeout=30)
+ logging.info(f'[Writer]: Get {table.num_rows} ids with partition id {partition_id} from queue of size '
+ f'{queue.qsize()} ------ Thread_id: {t.ident}')
+ path = get_partition_path(output_path, partition_id)
+ with open(path, 'ab') as f:
+ fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+ option = _csv.WriteOptions(include_header=False)
+ _csv.write_csv(table, f, option)
+    except Empty:
+        logging.info('writer exits because no data arrived from the queue within the timeout')
+
+
+class PartReader:
+
+ def __init__(self, input_path: str, num_partitions: int, block_size: int, key_column: str, reader_thread_num: int):
+ self._input_path = input_path
+ self._num_partitions = num_partitions
+ self._block_size = block_size
+ self._key_column = key_column
+ self._pool = ThreadPoolExecutor(max_workers=reader_thread_num)
+
+ def __del__(self):
+ self._pool.shutdown(wait=True)
+ logging.info('[Reader] ThreadPoolExecutor has shutdown.')
+
+ def read(self, queue: Queue):
+ for filename in os.listdir(self._input_path):
+ self._pool.submit(read_ids, os.path.join(self._input_path, filename), self._key_column, self._block_size,
+ self._num_partitions, queue)
+
+
+class PartWriter:
+
+    def __init__(self, output_path: str, num_partitions: int, writer_thread_num: int):
+        self._output_path = output_path
+        self._num_partitions = num_partitions
+        self._writer_thread_num = writer_thread_num
+        self._pool = ThreadPoolExecutor(max_workers=writer_thread_num)
+
+    def __del__(self):
+        self._pool.shutdown(wait=True)
+        logging.info('[Writer] ThreadPoolExecutor has shutdown.')
+
+    def write(self, queue: Queue):
+        # submit one long-running writer task per configured writer thread
+        for _ in range(self._writer_thread_num):
+            self._pool.submit(write_partitioned_ids, self._output_path, queue)
+
+
+class Partitioner:
+
+ def __init__(self, input_path: str, output_path: str, num_partitions: int, block_size: int, key_column: str,
+ queue_size: int, reader_thread_num: int, writer_thread_num: int):
+ self._input_path = input_path
+ self._output_path = output_path
+ self._num_partitions = num_partitions
+ self._block_size = block_size
+ self._key_column = key_column
+ self._queue = Queue(queue_size)
+ self._reader_thread_num = reader_thread_num
+ self._writer_thread_num = writer_thread_num
+ shutil.rmtree(self._output_path, ignore_errors=True)
+ os.makedirs(self._output_path, exist_ok=True)
+
+ def partition_data(self) -> None:
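+        """Write the CSV header to every partition file, then start the reader and writer
+        thread pools that repartition the input files by the key column.
+        """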
+ header = [self._key_column]
+ for filename in os.listdir(self._input_path):
+ input_path = os.path.join(self._input_path, filename)
+ with open(input_path, 'r', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ header = reader.fieldnames
+ break
+
+ for i in range(self._num_partitions):
+ with open(get_partition_path(self._output_path, i), 'w', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, header)
+ writer.writeheader()
+ reader = PartReader(self._input_path, self._num_partitions, self._block_size, self._key_column,
+ self._reader_thread_num)
+ writer = PartWriter(self._output_path, self._num_partitions, self._writer_thread_num)
+ reader.read(self._queue)
+ writer.write(self._queue)
diff --git a/pp_lite/data_join/utils/partitioner_test.py b/pp_lite/data_join/utils/partitioner_test.py
new file mode 100644
index 000000000..c96f6d4fa
--- /dev/null
+++ b/pp_lite/data_join/utils/partitioner_test.py
@@ -0,0 +1,137 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import csv
+import os
+import logging
+import pandas
+import shutil
+import tempfile
+import time
+import unittest
+from queue import Queue
+from pp_lite.data_join.utils.partitioner import read_ids, write_partitioned_ids, Partitioner
+
+
+def make_data(num: int, path: str, num_line: int) -> None:
+ shutil.rmtree(path, ignore_errors=True)
+ os.makedirs(path, exist_ok=True)
+
+ for i in range(num):
+ data = range(i * num_line, (i + 1) * num_line)
+ ids = [{'oaid': oaid, 'x1': oaid} for oaid in data]
+ with open(os.path.join(path, f'part-{i}'), 'wt', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, ['oaid', 'x1'])
+ writer.writeheader()
+ writer.writerows(ids)
+
+
+class PartitionTest(unittest.TestCase):
+
+ def test_read_ids(self):
+ with tempfile.TemporaryDirectory() as input_path:
+ make_data(1, input_path, 10)
+ filename = f'{input_path}/part-0'
+ queue = Queue(20)
+ read_ids(filename, 'oaid', 1000, 2, queue)
+ id1, table1 = queue.get()
+ self.assertEqual(id1, 0)
+ self.assertEqual(table1.to_pandas().values.tolist(), [['7', 7]])
+ id2, table2 = queue.get()
+ self.assertEqual(id2, 1)
+ self.assertEqual(table2.to_pandas().values.tolist(),
+ [['0', 0], ['1', 1], ['2', 2], ['3', 3], ['4', 4], ['5', 5], ['6', 6], ['8', 8], ['9', 9]])
+
+ # 1 partition
+ filename = f'{input_path}/part-0'
+ queue = Queue(20)
+ read_ids(filename, 'oaid', 1000, 1, queue)
+ id1, table1 = queue.get()
+ self.assertEqual(id1, 0)
+ self.assertEqual(
+ table1.to_pandas().values.tolist(),
+ [['0', 0], ['1', 1], ['2', 2], ['3', 3], ['4', 4], ['5', 5], ['6', 6], ['7', 7], ['8', 8], ['9', 9]])
+
+ def test_write_ids(self):
+ with tempfile.TemporaryDirectory() as input_path:
+ make_data(1, input_path, 10)
+ filename = f'{input_path}/part-0'
+ queue = Queue(20)
+ num_partitions = 2
+ block_size = 1000
+ read_ids(filename, 'oaid', block_size, num_partitions, queue)
+ with tempfile.TemporaryDirectory() as output_path:
+ write_partitioned_ids(output_path, queue)
+ self.assertEqual(len(os.listdir(output_path)), 2)
+ self.assertTrue(os.path.exists(os.path.join(output_path, 'part-0')))
+ self.assertTrue(os.path.exists(os.path.join(output_path, 'part-1')))
+ with open(os.path.join(output_path, 'part-0'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), '"7",7\n')
+ with open(os.path.join(output_path, 'part-1'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), '"0",0\n"1",1\n"2",2\n"3",3\n"4",4\n"5",5\n"6",6\n"8",8\n"9",9\n')
+
+ # 1 partition
+ read_ids(filename, 'oaid', block_size, 1, queue)
+ with tempfile.TemporaryDirectory() as output_path:
+ write_partitioned_ids(output_path, queue)
+ self.assertEqual(len(os.listdir(output_path)), 1)
+ self.assertTrue(os.path.exists(os.path.join(output_path, 'part-0')))
+ with open(os.path.join(output_path, 'part-0'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(),
+ '"0",0\n"1",1\n"2",2\n"3",3\n"4",4\n"5",5\n"6",6\n"7",7\n"8",8\n"9",9\n')
+
+ def test_partitioner(self):
+ with tempfile.TemporaryDirectory() as input_path:
+ make_data(20, input_path, 10)
+ with tempfile.TemporaryDirectory() as output_path:
+ timeout = 30
+ partitioner = Partitioner(input_path=input_path,
+ output_path=output_path,
+ num_partitions=20,
+ block_size=10000000,
+ key_column='oaid',
+ queue_size=40,
+ reader_thread_num=20,
+ writer_thread_num=20)
+ start = time.time()
+ partitioner.partition_data()
+                logging.info(f'Partitioner took {time.time() - start - timeout}s (excluding the {timeout}s queue timeout)')
+ self.assertEqual(len(os.listdir(output_path)), 20)
+ df = pandas.read_csv(os.path.join(output_path, 'part-0'))
+ self.assertEqual(sorted(df.values.tolist()),
+ [[22, 22], [71, 71], [94, 94], [120, 120], [127, 127], [136, 136], [173, 173]])
+
+ # 1 partition
+ with tempfile.TemporaryDirectory() as output_path:
+ timeout = 30
+ partitioner = Partitioner(input_path=input_path,
+ output_path=output_path,
+ num_partitions=1,
+ block_size=10000000,
+ key_column='oaid',
+ queue_size=40,
+ reader_thread_num=20,
+ writer_thread_num=20)
+ start = time.time()
+ partitioner.partition_data()
+                logging.info(f'Partitioner took {time.time() - start - timeout}s (excluding the {timeout}s queue timeout)')
+ self.assertEqual(len(os.listdir(output_path)), 1)
+ df = pandas.read_csv(os.path.join(output_path, 'part-0'))
+ self.assertEqual(len(df.values.tolist()), 200)
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
+ unittest.main()
diff --git a/pp_lite/deploy/BUILD.bazel b/pp_lite/deploy/BUILD.bazel
new file mode 100644
index 000000000..0ffbafac5
--- /dev/null
+++ b/pp_lite/deploy/BUILD.bazel
@@ -0,0 +1,28 @@
+load("@rules_python//python:defs.bzl", "py_binary", "py_library")
+
+py_binary(
+ name = "archiver",
+ srcs = ["archiver.py"],
+ data = ["//pp_lite/deploy/static"],
+ main = "archiver.py",
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":deploy"],
+)
+
+py_library(
+ name = "deploy",
+ srcs = [
+ "archiver.py",
+ "certificate.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//deploy/auto_cert",
+ "//deploy/auto_cert:authenticator_lib",
+ "//deploy/auto_cert:auto_cert_py_proto",
+ "//deploy/container:containers",
+ "//pp_lite/deploy/configs:deploy_config",
+ "//pp_lite/deploy/configs:logging_config",
+ "@common_click//:pkg",
+ ],
+)
diff --git a/pp_lite/deploy/README.md b/pp_lite/deploy/README.md
new file mode 100644
index 000000000..9ed40a42b
--- /dev/null
+++ b/pp_lite/deploy/README.md
@@ -0,0 +1,37 @@
+# PP Lite Client Archiver
+
+This is the packaging tool for PP Lite Client. Whenever you need to send a client a copy of our inspiring PP Lite, use me.
+
+## How to Make an Archive?
+
+### 1. Configuration
+
+To make an archive for your client, first prepare a YAML configuration file.
+
+Save it somewhere you can find again; its path is passed to the archiver below.
+
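+A minimal sketch of such a file (the field names follow `pp_lite/deploy/configs/models.proto`; the values below are the placeholders used by the test fixture `pp_lite/deploy/test_data/deploy_config.yaml`):
+
+```yaml
+pure_domain_name: some_company                                 # prefix of the certificate common name
+image_uri: artifact.bytedance.com/fedlearner/pp_lite:2.3.25.4  # client image to ship
+include_image: false                                           # whether to bundle the image tarball in the archive
+auto_cert_api_key: some_key                                    # API key for the certificate service
+```
+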
+### 2. Make Zip File with Bazel
+
+Using Bazel, you can make your zip file blazing fast:
+
+```bash
+# --run_under option makes relative path usable
+bazelisk run --run_under="cd $PWD && " //pp_lite/deploy:archiver -- -c <yaml_config_path> -o <output_path> [-f <zip|tar>]
+```
+
+## How to Use the Archive?
+
+Using the archive is as simple as eating an apple:
+
+```bash
+# Choose the command that matches the archive format you received
+unzip pp_lite_client.zip
+# OR
+tar xf pp_lite_client.tar
+
+cd pp_lite
+# You may find the UUID under [LIGHT_CLIENT_PSI] -> [more information] -> [Click and check the workflow]
+bash start.sh <UUID> <INPUT_DIR> [<NUM_WORKERS>]
+```
+
+You can certainly modify this bootstrap script, and the `.env` file as well, but **make sure you know what you are doing**.
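+
+For reference, a sketch of the only `.env` entries that `start.sh` rewrites (angle brackets mark the script arguments; other settings in the file are not touched by the script):
+
+```bash
+SERVICE_ID=<UUID>-lc-start-server-worker-0
+NUM_WORKERS=<NUM_WORKERS>   # only rewritten when .env contains psi-ot
+```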
diff --git a/pp_lite/deploy/archiver.py b/pp_lite/deploy/archiver.py
new file mode 100644
index 000000000..4320eba17
--- /dev/null
+++ b/pp_lite/deploy/archiver.py
@@ -0,0 +1,91 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import logging.config
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from shutil import make_archive, copy
+
+from click import command, option, Path as PathType, Choice
+
+from deploy.auto_cert.authenticator import ApiKeyAuthenticator
+from deploy.container.containers import pull_image_as_docker_archive
+from pp_lite.deploy.configs.logging_config import LOGGING_CONFIG
+from pp_lite.deploy.certificate import get_certificate, write_certificate
+from pp_lite.deploy.configs.controllers import get_deploy_config_from_yaml
+
+
+def _pull_image(image_uri: str, pp_lite_path: Path):
+ logging.info('include_image is set to true; pulling pp_lite_client...')
+ pull_image_as_docker_archive(image_uri, pp_lite_path / 'client_image.tar', 'privacy_computing_platform')
+ logging.info('include_image is set to true; pulling pp_lite_client... [DONE]')
+
+
+def _make_dirs(pp_lite_path: Path):
+ pp_lite_path.mkdir()
+ (pp_lite_path / 'input').mkdir()
+ (pp_lite_path / 'output').mkdir()
+ (pp_lite_path / 'log').mkdir()
+ (pp_lite_path / 'cert').mkdir()
+
+
+def _copy_static_files(pp_lite_path: Path):
+ path_of_this_file = Path(__file__).parent
+ copy(path_of_this_file / 'static' / '.env', Path(pp_lite_path / '.env'))
+ copy(path_of_this_file / 'static' / 'start.sh', Path(pp_lite_path / 'start.sh'))
+
+
+@command(name='PP Lite Client Archiver', help='I make archives for PP Lite Client for your clients.')
+@option('--yaml_config_path',
+ '-c',
+ help='How should I perform?',
+ type=PathType(exists=True, file_okay=True, dir_okay=False),
+ required=True)
+@option('--output_path',
+ '-o',
+ help='Where should I put the output archive?',
+ type=PathType(exists=True, file_okay=False, dir_okay=True),
+ required=True)
+@option('--output_format',
+ '-f',
+ help='In what format do you want your archive to be?',
+ type=Choice(['tar', 'zip']),
+ default='zip')
+def archive(yaml_config_path: str, output_path: str, output_format: str):
+ with TemporaryDirectory() as temp_dir:
+ config = get_deploy_config_from_yaml(Path(yaml_config_path).absolute())
+ cert = get_certificate(config.pure_domain_name, ApiKeyAuthenticator(config.auto_cert_api_key))
+
+ pp_lite_path = Path(temp_dir) / 'pp_lite'
+ _make_dirs(pp_lite_path)
+ _copy_static_files(pp_lite_path)
+ write_certificate(pp_lite_path / 'cert', cert)
+
+ if config.include_image:
+ _pull_image(config.image_uri, pp_lite_path)
+
+        logging.info(f'Making {output_format} archive...')
+        make_archive(Path(output_path) / 'pp_lite_client', output_format, Path(temp_dir).absolute())
+        logging.info(f'Making {output_format} archive... [DONE]')
+
+
+if __name__ == '__main__':
+ logging.config.dictConfig(LOGGING_CONFIG)
+ try:
+ archive() # pylint: disable=no-value-for-parameter
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(e)
+        raise
diff --git a/pp_lite/deploy/certificate.py b/pp_lite/deploy/certificate.py
new file mode 100644
index 000000000..a9b223523
--- /dev/null
+++ b/pp_lite/deploy/certificate.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from pathlib import Path
+
+from deploy.auto_cert.certificate_model_pb2 import CertificateFile
+from deploy.auto_cert.certificate_service import CertificateService
+from deploy.auto_cert.authenticator import ApiKeyAuthenticator
+from deploy.auto_cert.consts import BOE_NEXUS_CONFIG
+
+
+def get_certificate(company_name: str, authenticator: ApiKeyAuthenticator) -> CertificateFile:
+ service = CertificateService(authenticator, BOE_NEXUS_CONFIG)
+ common_name = f'{company_name}.fedlearner.net'
+ certs = service.get_certificates_by_name(common_name)
+ if len(certs) == 0:
+ logging.info(f'Certificate with company_name={company_name} not found; issuing...')
+ cert = service.issue_certificate(common_name, 365)
+ else:
+ cert = list(certs.values())[0]
+ return cert
+
+
+def write_certificate(cert_path: Path, cert: CertificateFile):
+ with open(cert_path / 'public.pem', mode='w', encoding='utf-8') as f:
+ f.write(cert.certificate)
+ with open(cert_path / 'intermediate.pem', mode='w', encoding='utf-8') as f:
+ f.write('\n'.join(cert.certificate_chain))
+ with open(cert_path / 'private.key', mode='w', encoding='utf-8') as f:
+ f.write(cert.private_key)
diff --git a/pp_lite/deploy/configs/BUILD.bazel b/pp_lite/deploy/configs/BUILD.bazel
new file mode 100644
index 000000000..55496e223
--- /dev/null
+++ b/pp_lite/deploy/configs/BUILD.bazel
@@ -0,0 +1,44 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+load("@rules_proto//proto:defs.bzl", "proto_library")
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library")
+
+py_library(
+ name = "deploy_config",
+ srcs = [
+ "controllers.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":models_py_proto",
+ "@common_pyyaml//:pkg",
+ ],
+)
+
+py_library(
+ name = "logging_config",
+ srcs = [
+ "logging_config.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+)
+
+proto_library(
+ name = "models_proto",
+ srcs = ["models.proto"],
+ visibility = ["//pp_lite:pp_lite_package"],
+)
+
+py_proto_library(
+ name = "models_py_proto",
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":models_proto"],
+)
+
+py_test(
+ name = "controllers_test",
+ size = "small",
+ srcs = ["controllers_test.py"],
+ data = ["//pp_lite/deploy/test_data"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = ["//pp_lite/deploy/configs:deploy_config"],
+)
diff --git a/pp_lite/deploy/configs/controllers.py b/pp_lite/deploy/configs/controllers.py
new file mode 100644
index 000000000..21bb7d69f
--- /dev/null
+++ b/pp_lite/deploy/configs/controllers.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from pathlib import Path
+
+from yaml import load, Loader
+
+from pp_lite.deploy.configs.models_pb2 import Config
+
+
+def get_deploy_config_from_yaml(path: Path) -> Config:
+ logging.info(f'Getting config from YAML file with path={path}...')
+ with open(path, mode='r', encoding='utf-8') as f:
+ content = load(f, Loader=Loader)
+ logging.info(f'Getting config from YAML file with path={path}... [DONE]')
+ return Config(pure_domain_name=content['pure_domain_name'],
+ image_uri=content['image_uri'],
+ include_image=content['include_image'],
+ auto_cert_api_key=content['auto_cert_api_key'])
diff --git a/pp_lite/deploy/configs/controllers_test.py b/pp_lite/deploy/configs/controllers_test.py
new file mode 100644
index 000000000..b6172c9cd
--- /dev/null
+++ b/pp_lite/deploy/configs/controllers_test.py
@@ -0,0 +1,33 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from unittest import main, TestCase
+from pathlib import Path
+
+from pp_lite.deploy.configs.controllers import get_deploy_config_from_yaml
+
+
+class ConfigTest(TestCase):
+
+ def test_get_config_from_yaml(self):
+ config = get_deploy_config_from_yaml(Path(__file__).parent.parent / 'test_data' / 'deploy_config.yaml')
+ self.assertEqual('some_company', config.pure_domain_name)
+ self.assertEqual('artifact.bytedance.com/fedlearner/pp_lite:2.3.25.4', config.image_uri)
+ self.assertEqual(False, config.include_image)
+ self.assertEqual('some_key', config.auto_cert_api_key)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/pp_lite/deploy/configs/logging_config.py b/pp_lite/deploy/configs/logging_config.py
new file mode 100644
index 000000000..154036407
--- /dev/null
+++ b/pp_lite/deploy/configs/logging_config.py
@@ -0,0 +1,37 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+LOGGING_CONFIG = {
+ 'version': 1,
+ 'disable_existing_loggers': False,
+ 'root': {
+ 'handlers': ['console'],
+ 'level': 'DEBUG'
+ },
+ 'handlers': {
+ 'console': {
+ 'class': 'logging.StreamHandler',
+ 'formatter': 'generic',
+ 'level': 'INFO'
+ },
+ },
+ 'formatters': {
+ 'generic': {
+ 'format': '%(asctime)s [%(process)d] [%(levelname)s] [PP Lite Client Archiver] %(message)s',
+ 'datefmt': '%Y-%m-%d %H:%M:%S',
+ 'class': 'logging.Formatter'
+ }
+ }
+}
diff --git a/pp_lite/deploy/configs/models.proto b/pp_lite/deploy/configs/models.proto
new file mode 100644
index 000000000..88f106a9f
--- /dev/null
+++ b/pp_lite/deploy/configs/models.proto
@@ -0,0 +1,27 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+package pp_lite.deploy.configs;
+
+message Config {
+  // Prefix of the certificate common name: {pure_domain_name}.fedlearner.net
+  string pure_domain_name = 1;
+  string image_uri = 2;
+  // Whether to include the pp_lite_client image in the produced archive
+ bool include_image = 3;
+ string auto_cert_api_key = 4;
+}
diff --git a/pp_lite/deploy/static/BUILD.bazel b/pp_lite/deploy/static/BUILD.bazel
new file mode 100644
index 000000000..6143e68fa
--- /dev/null
+++ b/pp_lite/deploy/static/BUILD.bazel
@@ -0,0 +1,8 @@
+filegroup(
+ name = "static",
+ srcs = [
+ ".env",
+ "start.sh",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+)
diff --git a/pp_lite/deploy/static/start.sh b/pp_lite/deploy/static/start.sh
new file mode 100755
index 000000000..6696ebb10
--- /dev/null
+++ b/pp_lite/deploy/static/start.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+if grep -q "psi-ot" .env
+then
+ if [[ $1 == '' || $2 == '' || $3 == '' ]]
+ then
+ echo "[Usage]: bash $0 [param1: UUID] [param2: INPUT_DIR] [param3: NUM_WORKERS]"
+ exit 1
+ fi
+ NUM_WORKERS=$3
+ sed -i".bak" -e "s/NUM_WORKERS.*/NUM_WORKERS=${NUM_WORKERS}/" .env
+else
+ if [[ $1 == '' || $2 == '' ]]
+ then
+ echo "[Usage]: bash $0 [param1: UUID] [param2: INPUT_DIR]"
+ exit 1
+ fi
+fi
+UUID=$1
+INPUT_DIR=$2
+
+set -e
+sed -i".bak" -e "s/SERVICE_ID.*/SERVICE_ID=${UUID}-lc-start-server-worker-0/" .env
+
+if test -z "$(docker images | grep pp_lite)"
+then
+ echo "Loading image..."
+ docker load -i client_image.tar
+else
+ echo "Image already satisfied."
+fi
+
+IMAGE_URI="$(docker images --format '{{ .Repository }}:{{ .Tag }}' | grep pp_lite | head -n 1)"
+echo "Using $IMAGE_URI to proceed."
+
+echo "Start Client..."
+docker run -it --rm --env-file .env \
+ -v "$PWD":/app/workdir \
+ -v "${INPUT_DIR}":/app/workdir/input \
+ "${IMAGE_URI}"
+echo "Start Client... [DONE]"
diff --git a/pp_lite/deploy/test_data/BUILD.bazel b/pp_lite/deploy/test_data/BUILD.bazel
new file mode 100644
index 000000000..57b02c214
--- /dev/null
+++ b/pp_lite/deploy/test_data/BUILD.bazel
@@ -0,0 +1,6 @@
+filegroup(
+ name = "test_data",
+ testonly = True,
+ srcs = ["deploy_config.yaml"],
+ visibility = ["//pp_lite:pp_lite_package"],
+)
diff --git a/pp_lite/deploy/test_data/deploy_config.yaml b/pp_lite/deploy/test_data/deploy_config.yaml
new file mode 100644
index 000000000..f3c38b5dd
--- /dev/null
+++ b/pp_lite/deploy/test_data/deploy_config.yaml
@@ -0,0 +1,4 @@
+pure_domain_name: some_company
+image_uri: artifact.bytedance.com/fedlearner/pp_lite:2.3.25.4
+include_image: false
+auto_cert_api_key: some_key
diff --git a/pp_lite/proto/BUILD.bazel b/pp_lite/proto/BUILD.bazel
new file mode 100644
index 000000000..8d9d5c188
--- /dev/null
+++ b/pp_lite/proto/BUILD.bazel
@@ -0,0 +1,31 @@
+load("@rules_proto//proto:defs.bzl", "proto_library")
+load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library")
+
+# gazelle:ignore otherwise gazelle will add this target to its own deps
+proto_library(
+ name = "proto",
+ srcs = [
+ "arguments.proto",
+ "common.proto",
+ "data_join_control_service.proto",
+ "data_join_service.proto",
+ "hashed_data_join.proto",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "@com_google_protobuf//:empty_proto",
+ ],
+)
+
+py_proto_library(
+ name = "py_proto",
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":proto"],
+)
+
+py_grpc_library(
+ name = "py_grpc",
+ srcs = [":proto"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [":py_proto"],
+)
diff --git a/pp_lite/proto/arguments.proto b/pp_lite/proto/arguments.proto
new file mode 100644
index 000000000..bc8d34530
--- /dev/null
+++ b/pp_lite/proto/arguments.proto
@@ -0,0 +1,76 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+package pp_lite.proto;
+import "pp_lite/proto/common.proto";
+
+message ClusterSpec {
+ // service name and port of workers
+ repeated string workers = 1;
+}
+
+message Arguments {
+ string input_path = 1;
+ string output_path = 2;
+ string key_column = 3;
+ DataJoinType data_join_type = 4;
+ int32 server_port = 5;
+ // 0-based worker rank
+ int32 worker_rank = 6;
+ ClusterSpec cluster_spec = 7;
+ int32 num_workers = 8;
+ int32 joiner_port = 9;
+ bool partitioned = 10;
+}
+
+message TrainerArguments {
+ string input_path = 1 [deprecated=true];
+ string output_path = 2 [deprecated=true];
+  // listen port for the server-side gRPC server.
+ int32 server_port = 3;
+  // cluster_spec for the TensorFlow gRPC server. Default: '{"clusterSpec": {"server": ["localhost:51001"], \
+ // "master": ["localhost:50101"], "ps": ["localhost:50102"], "worker": ["localhost:50103"]}}'.
+ string cluster_spec = 4;
+ // which model version to load if the pre-trained model exists. Default: 0.
+ int32 model_version = 5;
+ // the maximum number of model checkpoints to save. Default: 5.
+ int32 model_max_to_keep = 6;
+ // the tolerated number of the stale version compared to the current version for the
+ // model aggregation on the server. Default: 0.
+ int32 tolerated_version_gap = 7;
+ // the number of local steps for each client epoch. Default: 100.
+ int32 local_steps = 8;
+ // local training batch size. Default: 10.
+ int32 batch_size = 9;
+  // address for the client-side master gRPC server.
+ string master_addr = 10;
+ int32 worker_rank = 11;
+ int32 ps_rank = 12;
+ // task mode for client. Default: 'local'. Range: ['local', 'master', 'ps', 'worker'].
+ string task_mode = 13;
+  // listen port for the server-side TensorFlow gRPC server.
+ int32 tf_port = 14;
+ // total number of clients.
+ int32 num_clients = 15;
+ // export server model per save_version_gap.
+ int32 save_version_gap = 16;
+ // initial weight for aggregating the client model.
+ float client_model_weight = 17;
+ string data_path = 18;
+ string export_path = 19;
+ // filter out tfrecord files we want to use
+ string file_wildcard = 20;
+}
\ No newline at end of file
diff --git a/pp_lite/proto/common.proto b/pp_lite/proto/common.proto
new file mode 100644
index 000000000..e8f4e0a39
--- /dev/null
+++ b/pp_lite/proto/common.proto
@@ -0,0 +1,35 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+package pp_lite.proto;
+
+enum DataJoinType {
+ HASHED_DATA_JOIN = 0;
+ OT_PSI = 1;
+}
+
+enum FileType {
+ CSV = 0;
+ TFRECORD = 1;
+}
+
+message Ping {
+ string message = 1;
+}
+
+message Pong {
+ string message = 1;
+}
diff --git a/pp_lite/proto/data_join_control_service.proto b/pp_lite/proto/data_join_control_service.proto
new file mode 100644
index 000000000..ed028f47c
--- /dev/null
+++ b/pp_lite/proto/data_join_control_service.proto
@@ -0,0 +1,67 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+package pp_lite.proto;
+import "pp_lite/proto/common.proto";
+import "google/protobuf/empty.proto";
+
+message CreateDataJoinRequest {
+ DataJoinType type = 1;
+ int64 partition_id = 2;
+}
+
+message CreateDataJoinResponse {
+ bool succeeded = 1;
+ bool empty = 2;
+}
+
+message VerifyParameterRequest {
+ int64 num_partitions = 1;
+ int64 num_workers = 2;
+}
+
+message VerifyParameterResponse {
+ bool succeeded = 1;
+ int64 num_partitions = 2;
+ int64 num_workers = 3;
+}
+
+message GetParameterRequest {
+ string message = 1;
+}
+
+message GetParameterResponse {
+ int64 num_partitions = 1;
+ int64 num_workers = 2;
+}
+
+message GetDataJoinResultRequest {
+ int64 partition_id = 1;
+}
+
+message DataJoinResult {
+ int64 num_joined = 1;
+ bool finished = 2;
+}
+
+service DataJoinControlService {
+ rpc HealthCheck(Ping) returns (Pong) {}
+ rpc VerifyParameter(VerifyParameterRequest) returns (VerifyParameterResponse) {}
+ rpc GetParameter(GetParameterRequest) returns (GetParameterResponse) {}
+ rpc CreateDataJoin(CreateDataJoinRequest) returns (CreateDataJoinResponse) {}
+ rpc GetDataJoinResult(GetDataJoinResultRequest) returns (DataJoinResult) {}
+ rpc Finish(google.protobuf.Empty) returns (google.protobuf.Empty) {}
+}
\ No newline at end of file
diff --git a/pp_lite/proto/data_join_service.proto b/pp_lite/proto/data_join_service.proto
new file mode 100644
index 000000000..7a10581d0
--- /dev/null
+++ b/pp_lite/proto/data_join_service.proto
@@ -0,0 +1,63 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+package pp_lite.proto;
+import "pp_lite/proto/common.proto";
+import "google/protobuf/empty.proto";
+
+message GetPartitionNumberResponse {
+ int64 partition_num = 1;
+}
+
+message GetSignedIdsRequest {
+ repeated int64 partition_ids = 1;
+}
+
+message GetSignedIdsResponse {
+ repeated string ids = 1;
+}
+
+message PublicKeyResponse {
+ string e = 1;
+ string n = 2;
+}
+
+message SignRequest {
+ repeated string ids = 1;
+}
+
+message SignResponse {
+ repeated string signed_ids = 1;
+}
+
+message SyncDataJoinResultRequest {
+ int64 partition_id = 1;
+ repeated string ids = 2;
+}
+
+message SyncDataJoinResultResponse {
+ bool succeeded = 1;
+}
+
+service DataJoinService {
+ rpc Sign(SignRequest) returns (SignResponse) {}
+ rpc Finish(google.protobuf.Empty) returns (google.protobuf.Empty) {}
+ rpc GetPublicKey(google.protobuf.Empty) returns (PublicKeyResponse) {}
+ rpc GetSignedIds(GetSignedIdsRequest) returns (stream GetSignedIdsResponse) {}
+ rpc GetPartitionNumber(google.protobuf.Empty) returns (GetPartitionNumberResponse) {}
+ rpc SyncDataJoinResult(stream SyncDataJoinResultRequest) returns (SyncDataJoinResultResponse) {}
+ rpc HealthCheck(Ping) returns (Pong) {}
+}
\ No newline at end of file
diff --git a/pp_lite/proto/hashed_data_join.proto b/pp_lite/proto/hashed_data_join.proto
new file mode 100644
index 000000000..8aff081f0
--- /dev/null
+++ b/pp_lite/proto/hashed_data_join.proto
@@ -0,0 +1,29 @@
+/* Copyright 2023 The FedLearner Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+package pp_lite.proto;
+
+message DataJoinRequest {
+ repeated string ids = 1;
+}
+
+message DataJoinResponse {
+ repeated string ids = 1;
+}
+
+service HashedDataJoinService {
+ rpc DataJoin(stream DataJoinRequest) returns (stream DataJoinResponse) {}
+}
diff --git a/pp_lite/rpc/BUILD.bazel b/pp_lite/rpc/BUILD.bazel
new file mode 100644
index 000000000..886499438
--- /dev/null
+++ b/pp_lite/rpc/BUILD.bazel
@@ -0,0 +1,19 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+py_library(
+ name = "rpc",
+ srcs = [
+ "client.py",
+ "data_join_control_client.py",
+ "hashed_data_join_client.py",
+ "server.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join:envs",
+ "//pp_lite/data_join/utils",
+ "//pp_lite/proto:py_grpc",
+ "//pp_lite/proto:py_proto",
+ "//pp_lite/utils",
+ ],
+)
diff --git a/pp_lite/rpc/client.py b/pp_lite/rpc/client.py
new file mode 100644
index 000000000..c4474613b
--- /dev/null
+++ b/pp_lite/rpc/client.py
@@ -0,0 +1,78 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import List, Optional, Iterable
+
+import grpc
+from google.protobuf import empty_pb2
+
+from pp_lite.proto.common_pb2 import Ping, Pong
+from pp_lite.proto import data_join_service_pb2, data_join_service_pb2_grpc
+from pp_lite.data_join.envs import GRPC_CLIENT_TIMEOUT
+from pp_lite.utils.decorators import retry_fn
+
+
+class DataJoinClient:
+ """Rsa psi rpc client"""
+
+ def __init__(self, server_port: int = 50052, batch_size: int = 4096):
+ logging.info(f'RpcClient started: server_port:{server_port}')
+ self._host = 'localhost'
+ self._server_port = server_port
+ self._channel = grpc.insecure_channel(f'{self._host}:{self._server_port}')
+ self._stub = data_join_service_pb2_grpc.DataJoinServiceStub(self._channel)
+ self._batch_size = batch_size
+
+ @retry_fn(retry_times=30)
+ def check_server_ready(self, timeout_seconds=5):
+ # Check server ready via channel ready future instead of for-loop `HealthCheck` call.
+ # Ref: https://grpc.github.io/grpc/python/grpc.html#grpc.channel_ready_future
+ grpc.channel_ready_future(self._channel).result(timeout=timeout_seconds)
+
+ @retry_fn(retry_times=3)
+ def get_public_key(self) -> data_join_service_pb2.PublicKeyResponse:
+ return self._stub.GetPublicKey(empty_pb2.Empty(), timeout=GRPC_CLIENT_TIMEOUT)
+
+ def health_check(self) -> Pong:
+ return self._stub.HealthCheck(Ping(), timeout=GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3)
+ def sign(self, ids: List[str]) -> data_join_service_pb2.SignResponse:
+ request = data_join_service_pb2.SignRequest(ids=ids)
+ return self._stub.Sign(request, timeout=GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3)
+ def get_partition_number(self) -> data_join_service_pb2.GetPartitionNumberResponse:
+ return self._stub.GetPartitionNumber(empty_pb2.Empty())
+
+ @retry_fn(retry_times=3)
+ def get_signed_ids(self, partition_ids: List[int]) -> Iterable[data_join_service_pb2.GetSignedIdsResponse]:
+ request = data_join_service_pb2.GetSignedIdsRequest(partition_ids=partition_ids)
+ return self._stub.GetSignedIds(request)
+
+ def sync_data_join_result(self, ids_iterator: Iterable[List[str]], partition_id: Optional[int] = None) \
+ -> data_join_service_pb2.SyncDataJoinResultResponse:
+
+ def request_iterator():
+ for ids in ids_iterator:
+ yield data_join_service_pb2.SyncDataJoinResultRequest(ids=ids, partition_id=partition_id)
+
+ return self._stub.SyncDataJoinResult(request_iterator())
+
+ def finish(self):
+ logging.info('RpcClient stopped ! ! !')
+ request = empty_pb2.Empty()
+ self._stub.Finish(request)
diff --git a/pp_lite/rpc/data_join_control_client.py b/pp_lite/rpc/data_join_control_client.py
new file mode 100644
index 000000000..695da408c
--- /dev/null
+++ b/pp_lite/rpc/data_join_control_client.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import logging
+from google.protobuf import empty_pb2
+
+from pp_lite.data_join.envs import GRPC_CLIENT_TIMEOUT
+
+from pp_lite.proto.common_pb2 import DataJoinType, Ping, Pong
+from pp_lite.proto import data_join_control_service_pb2 as service_pb2
+from pp_lite.proto import data_join_control_service_pb2_grpc as service_pb2_grpc
+
+
+class DataJoinControlClient:
+ """Ot psi rpc client"""
+
+ def __init__(self, server_port: int):
+ logging.info(f'RpcClient started: server_port:{server_port}')
+ self._host = 'localhost'
+ self._server_port = server_port
+ self._channel = grpc.insecure_channel(f'{self._host}:{self._server_port}')
+ self._stub = service_pb2_grpc.DataJoinControlServiceStub(self._channel)
+
+ def health_check(self, message: str = '') -> Pong:
+ request = Ping(message=message)
+ return self._stub.HealthCheck(request, timeout=GRPC_CLIENT_TIMEOUT)
+
+ def verify_parameter(self, num_partitions: int, num_workers: int) -> service_pb2.VerifyParameterResponse:
+ request = service_pb2.VerifyParameterRequest(num_partitions=num_partitions, num_workers=num_workers)
+ return self._stub.VerifyParameter(request, timeout=GRPC_CLIENT_TIMEOUT)
+
+ def get_parameter(self, message: str = '') -> service_pb2.GetParameterResponse:
+ request = service_pb2.GetParameterRequest(message=message)
+ return self._stub.GetParameter(request, timeout=GRPC_CLIENT_TIMEOUT)
+
+ def create_data_join(self, partition_id: int, data_join_type: DataJoinType) -> service_pb2.CreateDataJoinResponse:
+ request = service_pb2.CreateDataJoinRequest(partition_id=partition_id, type=data_join_type)
+        # timeout is not set since the server may load data from slow HDFS
+ return self._stub.CreateDataJoin(request)
+
+ def get_data_join_result(self, partition_id: int) -> service_pb2.DataJoinResult:
+ request = service_pb2.GetDataJoinResultRequest(partition_id=partition_id)
+ return self._stub.GetDataJoinResult(request, timeout=GRPC_CLIENT_TIMEOUT)
+
+ def finish(self):
+ logging.info('RpcClient stopped ! ! !')
+ self._stub.Finish(empty_pb2.Empty())
+
+ def close(self):
+ self._channel.close()
diff --git a/pp_lite/rpc/hashed_data_join_client.py b/pp_lite/rpc/hashed_data_join_client.py
new file mode 100644
index 000000000..e1fc62355
--- /dev/null
+++ b/pp_lite/rpc/hashed_data_join_client.py
@@ -0,0 +1,40 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from typing import Iterator, List
+from pp_lite.proto import hashed_data_join_pb2 as service_pb2
+from pp_lite.proto import hashed_data_join_pb2_grpc as service_pb2_grpc
+from pp_lite.data_join.utils.generators import make_ids_iterator_from_list
+
+
+class HashedDataJoinClient:
+ """Hashed data join for integrated test"""
+
+ def __init__(self, server_port: int):
+ self._host = 'localhost'
+ self._server_port = server_port
+ self._channel = grpc.insecure_channel(f'{self._host}:{self._server_port}')
+ self._stub = service_pb2_grpc.HashedDataJoinServiceStub(self._channel)
+
+ def data_join(self, ids: List[str], batch_size: int = 4096) -> Iterator[service_pb2.DataJoinResponse]:
+
+ def request_iterator():
+ for part_ids in make_ids_iterator_from_list(ids, batch_size):
+ yield service_pb2.DataJoinRequest(ids=part_ids)
+
+ response_iterator = self._stub.DataJoin(request_iterator())
+
+ return response_iterator
diff --git a/pp_lite/rpc/server.py b/pp_lite/rpc/server.py
new file mode 100644
index 000000000..de4235858
--- /dev/null
+++ b/pp_lite/rpc/server.py
@@ -0,0 +1,69 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import logging
+import threading
+from typing import Callable
+from concurrent import futures
+from abc import ABCMeta, abstractmethod
+
+
+class IServicer(metaclass=ABCMeta):
+
+ @abstractmethod
+ def register(self, server: grpc.Server, stop_hook: Callable[[], None]):
+ raise NotImplementedError()
+
+
+class RpcServer:
+
+ def __init__(self, servicer: IServicer, listen_port: int):
+ self._lock = threading.Lock()
+ self._started = False
+ self._server = None
+ self._servicer = servicer
+ self._listen_port = listen_port
+
+ @property
+ def server(self):
+ return self._server
+
+ def start(self):
+ assert not self._started, 'already started'
+ with self._lock:
+ self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=20))
+ self._servicer.register(self._server, self.stop)
+ self._server.add_insecure_port(f'[::]:{self._listen_port}')
+ self._server.start()
+ self._started = True
+ logging.info(f'RpcServer started: listen_port:{self._listen_port}')
+
+ def wait(self, timeout=None):
+ self._server.wait_for_termination(timeout)
+
+ def stop(self):
+ if not self._started:
+ return
+ with self._lock:
+            # cannot stop immediately because the Finish response still needs to be returned
+ self._server.stop(grace=5)
+ del self._server
+ self._started = False
+ logging.info('RpcServer stopped ! ! !')
+
+ def is_alive(self):
+ with self._lock:
+ return hasattr(self, '_server')
diff --git a/pp_lite/test/BUILD.bazel b/pp_lite/test/BUILD.bazel
new file mode 100644
index 000000000..5fc9f25c6
--- /dev/null
+++ b/pp_lite/test/BUILD.bazel
@@ -0,0 +1,66 @@
+load("@rules_python//python:defs.bzl", "py_test")
+
+py_test(
+ name = "psi_ot_test",
+ size = "medium",
+ srcs = ["psi_ot_test.py"],
+ tags = ["exclusive"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join/psi_ot",
+ "//pp_lite/data_join/psi_ot/joiner",
+ "//pp_lite/proto:py_proto",
+ "//pp_lite/testing",
+ ],
+)
+
+py_test(
+ name = "psi_rsa_test",
+ size = "small",
+ srcs = ["psi_rsa_test.py"],
+ tags = ["exclusive"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join/psi_rsa",
+ "//pp_lite/testing",
+ ],
+)
+
+py_test(
+ name = "psi_rsa_partition_test",
+ size = "small",
+ srcs = ["psi_rsa_partition_test.py"],
+ tags = ["exclusive"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/data_join/psi_rsa",
+ "//pp_lite/testing",
+ ],
+)
+
+py_test(
+ name = "trainer_test",
+ size = "medium",
+ srcs = ["trainer_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite/proto:py_proto",
+ "//pp_lite/testing",
+ "//pp_lite/trainer",
+ ],
+)
+
+py_test(
+ name = "cli_test",
+ size = "small",
+ srcs = ["cli_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//pp_lite",
+ "//pp_lite/data_join/psi_ot",
+ "//pp_lite/data_join/psi_ot/joiner",
+ "//pp_lite/data_join/psi_rsa",
+ "//pp_lite/testing",
+ "@common_click//:pkg",
+ ],
+)
diff --git a/pp_lite/test/cli_test.py b/pp_lite/test/cli_test.py
new file mode 100644
index 000000000..33c8bd4d9
--- /dev/null
+++ b/pp_lite/test/cli_test.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+from unittest.mock import Mock, patch
+
+from click.testing import CliRunner
+
+from pp_lite import cli
+from web_console_v2.inspection.error_code import AreaCode, ErrorType
+
+
+class CliTest(unittest.TestCase):
+
+ # TODO(zhou.yi): create Env class to process environment variables
+ @patch('pp_lite.cli.write_termination_message')
+ def test_ot_missing_argument(self, mock_write_termination_message: Mock):
+ if 'INPUT_PATH' in os.environ:
+ del os.environ['INPUT_PATH']
+
+ runner = CliRunner()
+ with self.assertLogs(level='ERROR') as cm:
+ result = runner.invoke(cli=cli.pp_lite, args='psi-ot client')
+ # check logging
+ self.assertIn('Environment variable INPUT_PATH is missing.', cm.output[0])
+
+ # check termination log
+ mock_write_termination_message.assert_called_once_with(AreaCode.PSI_OT, ErrorType.INPUT_PARAMS_ERROR,
+ 'Environment variable INPUT_PATH is missing.')
+
+ # check exception that raise again
+ self.assertEqual(str(result.exception), '00071005-Environment variable INPUT_PATH is missing.')
+
+ @patch('pp_lite.cli.write_termination_message')
+ def test_hash_missing_argument(self, mock_write_termination_message: Mock):
+ if 'INPUT_PATH' in os.environ:
+ del os.environ['INPUT_PATH']
+
+ runner = CliRunner()
+ with self.assertLogs(level='ERROR') as cm:
+ result = runner.invoke(cli=cli.pp_lite, args='psi-hash client')
+
+ # check logging
+ self.assertIn('Environment variable INPUT_PATH is missing.', cm.output[0])
+
+ # check termination log
+ mock_write_termination_message.assert_called_once_with(AreaCode.PSI_HASH, ErrorType.INPUT_PARAMS_ERROR,
+ 'Environment variable INPUT_PATH is missing.')
+
+ # check exception that raise again
+ self.assertEqual(str(result.exception), '00091005-Environment variable INPUT_PATH is missing.')
+
+ @patch('pp_lite.cli.write_termination_message')
+ def test_trainer_client_missing_argument(self, mock_write_termination_message: Mock):
+ if 'TF_PORT' in os.environ:
+ del os.environ['TF_PORT']
+
+ runner = CliRunner()
+ with self.assertLogs(level='ERROR') as cm:
+ result = runner.invoke(cli=cli.pp_lite, args='trainer client')
+
+ # check logging
+ self.assertIn('Environment variable TF_PORT is missing.', cm.output[0])
+
+ # check termination log
+ mock_write_termination_message.assert_called_once_with(AreaCode.TRAINER, ErrorType.INPUT_PARAMS_ERROR,
+ 'Environment variable TF_PORT is missing.')
+
+ # check exception that raise again
+ self.assertEqual(str(result.exception), '00101005-Environment variable TF_PORT is missing.')
+
+ @patch('pp_lite.cli.write_termination_message')
+ def test_trainer_server_missing_argument(self, mock_write_termination_message: Mock):
+ if 'EXPORT_PATH' in os.environ:
+ del os.environ['EXPORT_PATH']
+
+ runner = CliRunner()
+ with self.assertLogs(level='ERROR') as cm:
+ result = runner.invoke(cli=cli.pp_lite, args='trainer server')
+
+ # check logging
+ self.assertIn('Environment variable EXPORT_PATH is missing.', cm.output[0])
+
+ # check termination log
+ mock_write_termination_message.assert_called_once_with(AreaCode.TRAINER, ErrorType.INPUT_PARAMS_ERROR,
+ 'Environment variable EXPORT_PATH is missing.')
+
+ # check exception that raise again
+ self.assertEqual(str(result.exception), '00101005-Environment variable EXPORT_PATH is missing.')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/test/psi_ot_test.py b/pp_lite/test/psi_ot_test.py
new file mode 100644
index 000000000..2716522ef
--- /dev/null
+++ b/pp_lite/test/psi_ot_test.py
@@ -0,0 +1,98 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import unittest
+import tempfile
+from concurrent.futures import ThreadPoolExecutor
+import importlib.util as imutil
+
+from pp_lite.data_join import envs
+from pp_lite.proto.arguments_pb2 import Arguments
+from pp_lite.proto.common_pb2 import DataJoinType
+from pp_lite.data_join.psi_ot.client import run as client_run
+from pp_lite.data_join.psi_ot.server import run as server_run
+from pp_lite.testing.make_data import make_data
+
+
+def check_psi_oprf():
+ spec = imutil.find_spec('psi_oprf')
+ if spec is None:
+ psi_oprf_existed = False
+ else:
+ psi_oprf_existed = True
+ return psi_oprf_existed
+
+
+class IntegratedTest(unittest.TestCase):
+
+ _PART_NUM = 2
+
+ def setUp(self) -> None:
+ self._temp_dir = tempfile.mkdtemp()
+ envs.STORAGE_ROOT = self._temp_dir
+ self._client_input_path = os.path.join(self._temp_dir, 'client')
+ self._server_input_path = os.path.join(self._temp_dir, 'server')
+ self._client_output_path = os.path.join(self._temp_dir, 'client_output')
+ self._server_output_path = os.path.join(self._temp_dir, 'server_output')
+ make_data(self._PART_NUM, self._client_input_path, self._server_input_path)
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self._temp_dir, ignore_errors=True)
+
+ def _run_client(self, data_join_type: DataJoinType):
+ args = Arguments(input_path=self._client_input_path,
+ output_path=self._client_output_path,
+ key_column='raw_id',
+ data_join_type=data_join_type,
+ server_port=50051,
+ joiner_port=50053,
+ worker_rank=0,
+ num_workers=1,
+ partitioned=True)
+ args.cluster_spec.workers.extend(['worker-0'])
+ client_run(args)
+
+ def _run_server(self, data_join_type: DataJoinType):
+ args = Arguments(input_path=self._server_input_path,
+ output_path=self._server_output_path,
+ key_column='raw_id',
+ data_join_type=data_join_type,
+ server_port=50051,
+ joiner_port=50053,
+ worker_rank=0)
+ args.cluster_spec.workers.extend(['worker-0'])
+ server_run(args)
+
+ def _run(self, data_join_type: DataJoinType):
+ pool = ThreadPoolExecutor(max_workers=2)
+ pool.submit(self._run_server, data_join_type)
+ self._run_client(data_join_type)
+ pool.shutdown()
+ # success tags are included
+ self.assertEqual(len(os.listdir(self._client_output_path)), self._PART_NUM * 2)
+ self.assertEqual(len(os.listdir(self._server_output_path)), self._PART_NUM * 2)
+
+ def test_run_hashed_data_join(self):
+ self._run(DataJoinType.HASHED_DATA_JOIN)
+
+ @unittest.skipUnless(check_psi_oprf(), 'require ot psi file')
+ def test_ot_psi(self):
+ self._run(DataJoinType.OT_PSI)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/test/psi_rsa_partition_test.py b/pp_lite/test/psi_rsa_partition_test.py
new file mode 100644
index 000000000..31c610d4f
--- /dev/null
+++ b/pp_lite/test/psi_rsa_partition_test.py
@@ -0,0 +1,129 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import csv
+import unittest
+import tempfile
+import shutil
+
+import rsa
+from gmpy2 import powmod # pylint: disable=no-name-in-module
+from cityhash import CityHash64 # pylint: disable=no-name-in-module
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from pp_lite.data_join import envs
+from pp_lite.data_join.psi_rsa.psi_client import run as client_run
+from pp_lite.data_join.psi_rsa.psi_server import run as server_run
+
+
+def sign(raw_id: str, private_key: rsa.PrivateKey) -> str:
+
+ def _sign(i: int):
+ return powmod(i, private_key.d, private_key.n).digits()
+
+ return hex(CityHash64(_sign(CityHash64(raw_id))))[2:]
+
+
+def _make_data(client_input: str, server_input: str, private_key: rsa.PrivateKey, part_num: int, part_size: int,
+ ex_size: int):
+ if not os.path.exists(client_input):
+ os.makedirs(client_input)
+ if not os.path.exists(server_input):
+ os.makedirs(server_input)
+ for part_id in range(part_num):
+ client_filename = os.path.join(client_input, f'part-{part_id}')
+ server_filename = os.path.join(server_input, f'part-{part_id}')
+ client_ids = range(part_id * (part_size + ex_size), part_id * (part_size + ex_size) + part_size)
+ server_ids = range(part_id * (part_size + ex_size) + ex_size, (part_id + 1) * (part_size + ex_size))
+ server_signed_ids = [sign(str(i), private_key) for i in server_ids]
+ with open(client_filename, 'wt', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, fieldnames=['example_id'])
+ writer.writeheader()
+ writer.writerows([{'example_id': str(i)} for i in client_ids])
+ with open(server_filename, 'wt', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, fieldnames=['signed_id'])
+ writer.writeheader()
+ writer.writerows([{'signed_id': str(i)} for i in server_signed_ids])
+
+
+class IntegratedTest(unittest.TestCase):
+
+ def setUp(self):
+ self._temp_dir = tempfile.mkdtemp()
+ envs.STORAGE_ROOT = self._temp_dir
+ envs.CLIENT_CONNECT_RETRY_INTERVAL = 1
+ self.client_input = os.path.join(self._temp_dir, 'client_input')
+ self.server_input = os.path.join(self._temp_dir, 'server_input')
+ self.client_output = os.path.join(self._temp_dir, 'client_output')
+ self.server_output = os.path.join(self._temp_dir, 'server_output')
+ _, private_key = rsa.newkeys(1024)
+ _make_data(self.client_input, self.server_input, private_key, 2, 1000, 200)
+ self.private_key_path = os.path.join(self.server_input, 'private.key')
+ with open(self.private_key_path, 'wb') as f:
+ f.write(private_key.save_pkcs1())
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self._temp_dir, ignore_errors=True)
+
+ @staticmethod
+ def _run_client(input_path, output_path):
+ args = {
+ 'input_dir': input_path,
+ 'output_dir': output_path,
+ 'key_column': 'example_id',
+ 'server_port': 50058,
+ 'batch_size': 4096,
+ 'worker_rank': 1,
+ 'num_workers': 5,
+ 'num_sign_parallel': 2,
+ 'partitioned': True,
+ 'partition_list': [],
+ }
+ client_run(args=args)
+
+ @staticmethod
+ def _run_server(input_path: str, output_path: str, private_key_path: str):
+ args = {
+ 'rsa_private_key_path': private_key_path,
+ 'input_dir': input_path,
+ 'output_dir': output_path,
+ 'signed_column': 'signed_id',
+ 'key_column': 'example_id',
+ 'server_port': 50058,
+ 'batch_size': 4096,
+ 'num_sign_parallel': 5
+ }
+ server_run(args=args)
+
+ def test(self):
+ futures = []
+ with ThreadPoolExecutor(max_workers=2) as pool:
+ futures.append(pool.submit(self._run_server, self.server_input, self.server_output, self.private_key_path))
+ futures.append(pool.submit(self._run_client, self.client_input, self.client_output))
+ for future in as_completed(futures):
+ future.result() # surface any exception raised in the worker thread
+ with open(os.path.join(self.client_output, 'joined', 'part-00001-joined.csv'), 'rt', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ ids = sorted([line['example_id'] for line in reader])
+ self.assertListEqual(ids, [str(id) for id in range(1400, 2200)])
+ with open(os.path.join(self.server_output, 'joined', 'output_1.csv'), 'rt', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ ids = sorted([line['example_id'] for line in reader])
+ self.assertListEqual(ids, [str(id) for id in range(1400, 2200)])
+
+
+if __name__ == '__main__':
+ unittest.main()
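The `range(1400, 2200)` assertion above follows directly from the id layout in `_make_data`: for partition 1 the client holds the first `part_size` ids of its block and the server holds the last `part_size`, so they overlap by `part_size - ex_size` ids. A minimal sketch of that arithmetic (illustrative only, not part of the diff):

```python
# Id layout of _make_data for partition 1 with part_size=1000, ex_size=200,
# matching the setUp() call above.
part_id, part_size, ex_size = 1, 1000, 200

client_ids = set(range(part_id * (part_size + ex_size),
                       part_id * (part_size + ex_size) + part_size))    # 1200..2199
server_ids = set(range(part_id * (part_size + ex_size) + ex_size,
                       (part_id + 1) * (part_size + ex_size)))          # 1400..2399

joined = sorted(client_ids & server_ids)
assert joined == list(range(1400, 2200))  # 800 ids survive the join
```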
diff --git a/pp_lite/test/psi_rsa_test.py b/pp_lite/test/psi_rsa_test.py
new file mode 100644
index 000000000..791b4eb8d
--- /dev/null
+++ b/pp_lite/test/psi_rsa_test.py
@@ -0,0 +1,151 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import csv
+import unittest
+import tempfile
+import shutil
+
+import rsa
+from gmpy2 import powmod # pylint: disable=no-name-in-module
+from cityhash import CityHash64 # pylint: disable=no-name-in-module
+# Use ProcessPoolExecutor to isolate logging config conflicts.
+from concurrent.futures import ProcessPoolExecutor, as_completed
+import multiprocessing
+
+from pp_lite.data_join import envs
+from pp_lite.data_join.psi_rsa.psi_client import run as client_run
+from pp_lite.data_join.psi_rsa.psi_server import run as server_run
+
+
+def sign(raw_id: str, private_key: rsa.PrivateKey) -> str:
+
+ def _sign(i: int):
+ return powmod(i, private_key.d, private_key.n).digits()
+
+ return hex(CityHash64(_sign(CityHash64(raw_id))))[2:]
+
+
+def _make_data(client_input: str, server_input: str, private_key: rsa.PrivateKey, part_num: int, part_size: int,
+ ex_size: int):
+ if not os.path.exists(client_input):
+ os.makedirs(client_input)
+ if not os.path.exists(server_input):
+ os.makedirs(server_input)
+ for part_id in range(part_num):
+ client_filename = os.path.join(client_input, f'part-{part_id}')
+ server_filename = os.path.join(server_input, f'part-{part_id}')
+ client_ids = range(part_id * (part_size + ex_size), part_id * (part_size + ex_size) + part_size)
+ server_ids = range(part_id * (part_size + ex_size) + ex_size, (part_id + 1) * (part_size + ex_size))
+ server_signed_ids = [sign(str(i), private_key) for i in server_ids]
+ with open(client_filename, 'wt', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, fieldnames=['raw_id'])
+ writer.writeheader()
+ writer.writerows([{'raw_id': str(i)} for i in client_ids])
+ with open(server_filename, 'wt', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, fieldnames=['signed_id'])
+ writer.writeheader()
+ writer.writerows([{'signed_id': str(i)} for i in server_signed_ids])
+
+
+class IntegratedTest(unittest.TestCase):
+
+ def setUp(self):
+ self._temp_dir = tempfile.mkdtemp()
+ envs.STORAGE_ROOT = self._temp_dir
+ envs.CLIENT_CONNECT_RETRY_INTERVAL = 1
+ self.client_input = os.path.join(self._temp_dir, 'client_input')
+ self.server_input = os.path.join(self._temp_dir, 'server_input')
+ self.client_output = os.path.join(self._temp_dir, 'client_output')
+ self.server_output = os.path.join(self._temp_dir, 'server_output')
+ _, private_key = rsa.newkeys(1024)
+ _make_data(self.client_input, self.server_input, private_key, 2, 1000, 200)
+ self.private_key_path = os.path.join(self.server_input, 'private.key')
+ with open(self.private_key_path, 'wb') as f:
+ f.write(private_key.save_pkcs1())
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self._temp_dir, ignore_errors=True)
+
+ @staticmethod
+ def _run_client(input_path, output_path, storage_root: str):
+ envs.STORAGE_ROOT = storage_root
+ args = {
+ 'input_dir': input_path,
+ 'output_dir': output_path,
+ 'key_column': 'raw_id',
+ 'server_port': 50058,
+ 'batch_size': 4096,
+ 'num_workers': 5,
+ 'num_sign_parallel': 2,
+ 'partitioned': False,
+ 'partition_list': [],
+ }
+ client_run(args)
+
+ @staticmethod
+ def _run_server(input_path: str, output_path: str, private_key_path: str, storage_root: str):
+ envs.STORAGE_ROOT = storage_root
+ args = {
+ 'rsa_private_key_path': private_key_path,
+ 'input_dir': input_path,
+ 'output_dir': output_path,
+ 'signed_column': 'signed_id',
+ 'key_column': 'raw_id',
+ 'server_port': 50058,
+ 'batch_size': 4096,
+ 'num_sign_parallel': 5
+ }
+ server_run(args=args)
+
+ def test(self):
+ futures = []
+ with ProcessPoolExecutor(max_workers=2) as pool:
+ futures.append(
+ pool.submit(self._run_server, self.server_input, self.server_output, self.private_key_path,
+ envs.STORAGE_ROOT))
+ futures.append(pool.submit(self._run_client, self.client_input, self.client_output, envs.STORAGE_ROOT))
+ for future in as_completed(futures):
+ future.result() # surface any exception raised in the worker process
+ with open(os.path.join(self.client_output, 'joined', 'joined.csv'), 'rt', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ self.assertEqual(len([line['raw_id'] for line in reader]), 1600)
+ with open(os.path.join(self.server_output, 'joined', 'output.csv'), 'rt', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ self.assertEqual(len([line['raw_id'] for line in reader]), 1600)
+
+ def test_client_input_file(self):
+ futures = []
+ with ProcessPoolExecutor(max_workers=2) as pool:
+ futures.append(
+ pool.submit(self._run_server, self.server_input, self.server_output, self.private_key_path,
+ envs.STORAGE_ROOT))
+ futures.append(
+ pool.submit(self._run_client, os.path.join(self.client_input, 'part-0'), self.client_output,
+ envs.STORAGE_ROOT))
+ for future in as_completed(futures):
+ future.result() # surface any exception raised in the worker process
+ with open(os.path.join(self.client_output, 'joined', 'joined.csv'), 'rt', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ self.assertEqual(len([line['raw_id'] for line in reader]), 800)
+ with open(os.path.join(self.server_output, 'joined', 'output.csv'), 'rt', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ self.assertEqual(len([line['raw_id'] for line in reader]), 800)
+
+
+if __name__ == '__main__':
+ multiprocessing.set_start_method('spawn')
+ unittest.main()
diff --git a/pp_lite/test/trainer_test.py b/pp_lite/test/trainer_test.py
new file mode 100755
index 000000000..ddb62de40
--- /dev/null
+++ b/pp_lite/test/trainer_test.py
@@ -0,0 +1,142 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import numpy as np
+import time
+import unittest
+import json
+
+import tensorflow.compat.v1 as tf
+import shutil
+from multiprocessing import get_context
+
+from pp_lite.trainer.client.client import main as client_main
+from pp_lite.trainer.server.server import main as server_main
+from pp_lite.proto.arguments_pb2 import TrainerArguments
+
+
+class IntegratedTest(unittest.TestCase):
+
+ def test_e2e(self):
+ logging.basicConfig(level=logging.INFO)
+ cluster_spec_str = json.dumps({
+ 'master': ['localhost:50101'],
+ 'ps': ['localhost:50102', 'localhost:50104'],
+ 'worker': ['localhost:50103', 'localhost:50105']
+ })
+ cluster_spec_dict = json.loads(cluster_spec_str)
+ if isinstance(cluster_spec_dict, dict):
+ for name, addrs in cluster_spec_dict.items():
+ if name in ['master', 'ps', 'worker'] and isinstance(addrs, list):
+ for addr in addrs:
+ if not isinstance(addr, str):
+ raise TypeError('Input cluster_spec type error')
+ else:
+ raise TypeError('Input cluster_spec type error')
+ else:
+ raise TypeError('Input cluster_spec type error')
+
+ args = TrainerArguments(data_path='pp_lite/trainer/data/',
+ file_wildcard='**/*',
+ export_path='pp_lite/trainer/model/',
+ server_port=55550,
+ tf_port=51001,
+ num_clients=1,
+ cluster_spec=cluster_spec_str,
+ model_version=0,
+ model_max_to_keep=5,
+ tolerated_version_gap=1,
+ task_mode='local',
+ local_steps=100,
+ batch_size=10,
+ master_addr='localhost:55555',
+ ps_rank=0,
+ worker_rank=0,
+ save_version_gap=10)
+
+ # generate TFRecord
+ if not tf.io.gfile.exists(args.data_path):
+ logging.info('Generating data ...')
+ tf.io.gfile.makedirs(args.data_path)
+
+ (x, y), _ = tf.keras.datasets.mnist.load_data()
+ x = x.reshape((x.shape[0], -1)) / 255
+ n = 1000
+ num = x.shape[0] // n
+ for idx in range(num):
+ np_to_tfrecords(x[idx * n:idx * n + n], y[idx * n:idx * n + n], f'{args.data_path}{idx}')
+
+ context = get_context('spawn')
+
+ process_server = context.Process(target=server_main, args=(args,), daemon=True)
+ process_server.start()
+ time.sleep(1)
+
+ cluster_spec_dict['master'] = ['localhost:50201']
+ args.cluster_spec = json.dumps(cluster_spec_dict)
+ client_main(args=args)
+ process_server.join()
+
+ shutil.rmtree(args.data_path, ignore_errors=False, onerror=None)
+ shutil.rmtree(args.export_path, ignore_errors=False, onerror=None)
+
+
+def _bytes_feature(value):
+ if isinstance(value, type(tf.constant(0))): # if value is a tensor
+ value = value.numpy() # get value of tensor
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _float_feature(value):
+ return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
+def _int64_feature(value):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def serialize_array(array):
+ array = tf.io.serialize_tensor(array)
+ return array
+
+
+def np_to_tfrecords(X, Y, file_path_prefix):
+ # Generate tfrecord writer
+ result_tf_file = file_path_prefix + '.tfrecords'
+ writer = tf.python_io.TFRecordWriter(result_tf_file)
+
+ # Iterate over each sample and serialize it as a tf.train.Example protobuf.
+ # Eager execution is enabled temporarily so that _bytes_feature can call
+ # Tensor.numpy(); it is disabled again after the loop.
+ tf.enable_eager_execution()
+ for idx in range(X.shape[0]):
+ x = X[idx]
+ y = Y[idx]
+ data = {
+ 'X': _bytes_feature(tf.serialize_tensor(x.astype(np.float32))),
+ 'X_size': _int64_feature(x.shape[0]),
+ 'Y': _int64_feature(y)
+ }
+ features = tf.train.Features(feature=data)
+ example = tf.train.Example(features=features)
+ serialized = example.SerializeToString()
+ writer.write(serialized)
+ writer.close()
+ tf.disable_eager_execution()
+
+
+if __name__ == '__main__':
+ unittest.main()
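For reference, records written by `np_to_tfrecords` can be read back with a feature spec mirroring the `X`/`X_size`/`Y` layout above. A minimal sketch using the same TF1-compat API; the file name assumes the record generated for idx 0 is still on disk:

```python
import tensorflow.compat.v1 as tf

def parse_example(serialized):
    # Mirrors np_to_tfrecords: X is a serialized float32 tensor,
    # X_size its length, Y the integer label.
    spec = {
        'X': tf.io.FixedLenFeature([], tf.string),
        'X_size': tf.io.FixedLenFeature([], tf.int64),
        'Y': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, spec)
    x = tf.io.parse_tensor(parsed['X'], out_type=tf.float32)
    return x, parsed['Y']

dataset = tf.data.TFRecordDataset('pp_lite/trainer/data/0.tfrecords').map(parse_example)
```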
diff --git a/pp_lite/testing/BUILD.bazel b/pp_lite/testing/BUILD.bazel
new file mode 100644
index 000000000..eca1ef331
--- /dev/null
+++ b/pp_lite/testing/BUILD.bazel
@@ -0,0 +1,8 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+py_library(
+ name = "testing",
+ testonly = True,
+ srcs = ["make_data.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+)
diff --git a/pp_lite/testing/make_data.py b/pp_lite/testing/make_data.py
new file mode 100644
index 000000000..c8110fc23
--- /dev/null
+++ b/pp_lite/testing/make_data.py
@@ -0,0 +1,72 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import csv
+import os
+import shutil
+from typing import List
+import uuid
+
+
+def _make_fake_data(input_dir: str,
+ num_partitions: int,
+ line_num: int,
+ partitioned: bool = True,
+ spark: bool = False) -> List[str]:
+ header = [f'x_{i}' for i in range(20)]
+ header.append('part_id')
+ header = sorted(header)
+ for pid in range(num_partitions):
+ filename_prefix = 'part' if partitioned else 'abcd'
+ filename = os.path.join(input_dir, f'{filename_prefix}-{pid}')
+ if spark:
+ filename = filename + '-' + str(uuid.uuid4())
+ with open(filename, 'w', encoding='utf-8') as file:
+ writer = csv.DictWriter(file, header)
+ writer.writeheader()
+ for i in range(line_num):
+ data = {h: pid + i + j for j, h in enumerate(header)}
+ data['part_id'] = pid
+ writer.writerow(data)
+ if partitioned:
+ with open(os.path.join(input_dir, '_SUCCESS'), 'w', encoding='utf-8') as f:
+ f.write('')
+ return header
+
+
+def make_data(num_partition, client_path: str, server_path: str):
+ shutil.rmtree(client_path, ignore_errors=True)
+ shutil.rmtree(server_path, ignore_errors=True)
+ os.makedirs(client_path, exist_ok=True)
+ os.makedirs(server_path, exist_ok=True)
+ num_lines = 1000
+ ex_lines = 200
+ for part_id in range(num_partition):
+ client_range = range(part_id * num_lines * 10, part_id * num_lines * 10 + num_lines)
+ server_range = range(part_id * num_lines * 10 - ex_lines, part_id * num_lines * 10 + num_lines - ex_lines)
+ client_ids = [{'raw_id': i} for i in client_range]
+ server_ids = [{'raw_id': i} for i in server_range]
+ with open(os.path.join(client_path, f'part-{part_id}'), 'wt', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, ['raw_id'])
+ writer.writeheader()
+ writer.writerows(client_ids)
+ with open(os.path.join(server_path, f'part-{part_id}'), 'wt', encoding='utf-8') as f:
+ writer = csv.DictWriter(f, ['raw_id'])
+ writer.writeheader()
+ writer.writerows(server_ids)
+
+
+if __name__ == '__main__':
+ make_data(2, 'client_input', 'server_input')
diff --git a/pp_lite/utils/BUILD.bazel b/pp_lite/utils/BUILD.bazel
new file mode 100644
index 000000000..10630b626
--- /dev/null
+++ b/pp_lite/utils/BUILD.bazel
@@ -0,0 +1,38 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+py_library(
+ name = "utils",
+ srcs = [
+ "decorators.py",
+ "logging_config.py",
+ "metric_collector.py",
+ "metrics.py",
+ "tools.py",
+ ],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ "//py_libs:metrics_lib",
+ ],
+)
+
+py_test(
+ name = "decorators_test",
+ size = "small",
+ srcs = ["decorators_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":utils",
+ "//pp_lite/testing",
+ ],
+)
+
+py_test(
+ name = "tools_test",
+ size = "small",
+ srcs = ["tools_test.py"],
+ visibility = ["//pp_lite:pp_lite_package"],
+ deps = [
+ ":utils",
+ "//pp_lite/testing",
+ ],
+)
diff --git a/pp_lite/utils/decorators.py b/pp_lite/utils/decorators.py
new file mode 100644
index 000000000..bb0487b5e
--- /dev/null
+++ b/pp_lite/utils/decorators.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import time
+import signal
+from functools import wraps
+
+
+def raise_timeout(signum, frame):
+ raise TimeoutError
+
+
+def timeout_fn(time_in_second: int = 60):
+ """Raise TimeoutError after time_in_second.
+ Note that this decorator should be used on main thread. Found more info in
+ ref: https://stackoverflow.com/questions/54749342/valueerror-signal-only-works-in-main-thread
+ """
+
+ def decorator_fn(f):
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ signal.signal(signal.SIGALRM, raise_timeout)
+ signal.alarm(time_in_second)
+ try:
+ return f(*args, **kwargs)
+ finally:
+ signal.signal(signal.SIGALRM, signal.SIG_IGN)
+
+ return wrapper
+
+ return decorator_fn
+
+
+def retry_fn(retry_times: int = 3):
+
+ def decorator_fn(f):
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ for i in range(retry_times - 1):
+ try:
+ return f(*args, **kwargs)
+ # pylint: disable=broad-except
+ except Exception:
+ logging.exception(f'Call to {f.__name__} failed, attempt {i + 1} of {retry_times}')
+ continue
+ return f(*args, **kwargs)
+
+ return wrapper
+
+ return decorator_fn
+
+
+def time_log(log_type: str = 'Function'):
+
+ def decorator_fn(f):
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ logging.info(f'[{log_type}] start ! ! !')
+ start_time = time.time()
+ res = f(*args, **kwargs)
+ logging.info(f'[{log_type}] used time: {time.time() - start_time} s')
+ return res
+
+ return wrapper
+
+ return decorator_fn
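A minimal usage sketch combining the decorators above; `fetch_partition` and its failure mode are hypothetical. Note that `timeout_fn` relies on `SIGALRM`, so this only works on the main thread of a Unix process:

```python
import logging

from pp_lite.utils.decorators import retry_fn, time_log, timeout_fn

@time_log('FetchPartition')
@retry_fn(retry_times=3)
@timeout_fn(time_in_second=10)
def fetch_partition(partition_id: int) -> bytes:
    # Hypothetical remote call; transient failures and timeouts are retried
    # up to 3 times by retry_fn, then the last exception propagates.
    raise TimeoutError(f'partition {partition_id} not ready')

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    try:
        fetch_partition(0)
    except TimeoutError:
        logging.info('gave up after retries')
```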
diff --git a/pp_lite/utils/decorators_test.py b/pp_lite/utils/decorators_test.py
new file mode 100644
index 000000000..5ec0c98be
--- /dev/null
+++ b/pp_lite/utils/decorators_test.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import unittest
+from pp_lite.utils.decorators import retry_fn, timeout_fn
+
+
+class DecoratorTest(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self._res = 0
+
+ def test_timeout_fn(self):
+
+ @timeout_fn(2)
+ def func() -> int:
+ time.sleep(1)
+ return 1
+
+ @timeout_fn(1)
+ def some_unstable_func() -> int:
+ time.sleep(2)
+ return 1
+
+ self.assertEqual(func(), 1)
+ with self.assertRaises(TimeoutError):
+ some_unstable_func()
+
+ def test_retry_fn(self):
+
+ @retry_fn(4)
+ def func():
+ self._res = self._res + 2
+
+ @retry_fn(4)
+ def some_unstable_func():
+ self._res = self._res + 2
+ raise TimeoutError
+
+ func()
+ self.assertEqual(self._res, 2)
+ with self.assertRaises(TimeoutError):
+ some_unstable_func()
+ self.assertEqual(self._res, 10)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pp_lite/utils/logging_config.py b/pp_lite/utils/logging_config.py
new file mode 100644
index 000000000..bec8d3eef
--- /dev/null
+++ b/pp_lite/utils/logging_config.py
@@ -0,0 +1,54 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from typing import Dict
+from datetime import datetime
+
+
+def log_path(log_dir: str) -> str:
+ return os.path.join(log_dir, f'{datetime.now().strftime("%Y%m%d_%H%M%S_%f")}.log')
+
+
+def logging_config(file_path: str) -> Dict:
+ return {
+ 'version': 1,
+ 'disable_existing_loggers': False,
+ 'root': {
+ 'handlers': ['console', 'file'],
+ 'level': 'DEBUG'
+ },
+ 'handlers': {
+ 'console': {
+ 'class': 'logging.StreamHandler',
+ 'formatter': 'generic',
+ 'level': 'INFO'
+ },
+ 'file': {
+ 'class': 'logging.FileHandler',
+ 'formatter': 'generic',
+ 'filename': file_path,
+ 'encoding': 'utf-8',
+ 'level': 'DEBUG'
+ }
+ },
+ 'formatters': {
+ 'generic': {
+ 'format': '%(asctime)s [%(process)d] [%(levelname)s] %(message)s',
+ 'datefmt': '%Y-%m-%d %H:%M:%S',
+ 'class': 'logging.Formatter'
+ }
+ }
+ }
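The dict returned by `logging_config` is meant to be fed to `logging.config.dictConfig`; a minimal wiring sketch, with the log directory as an assumption:

```python
import logging
import logging.config
import os

from pp_lite.utils.logging_config import log_path, logging_config

log_dir = '/tmp/pp_lite_logs'  # hypothetical directory
os.makedirs(log_dir, exist_ok=True)

# Console handler stays at INFO, the timestamped file captures DEBUG and above.
logging.config.dictConfig(logging_config(file_path=log_path(log_dir)))
logging.info('visible on console and in the file')
logging.debug('file only')
```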
diff --git a/pp_lite/utils/metric_collector.py b/pp_lite/utils/metric_collector.py
new file mode 100755
index 000000000..0e276db63
--- /dev/null
+++ b/pp_lite/utils/metric_collector.py
@@ -0,0 +1,35 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from py_libs import metrics
+from os import environ
+from typing import ContextManager, Union, Dict
+
+service_name = environ.get('METRIC_COLLECTOR_SERVICE_NAME', 'default_metric_service')
+endpoint = environ.get('METRIC_COLLECTOR_EXPORT_ENDPOINT')
+
+cluster_name = environ.get('CLUSTER', 'default_cluster')
+k8s_job_name = environ.get('APPLICATION_ID', 'default_k8s_job_name')
+global_service_label = {'k8s_job_name': k8s_job_name}
+if endpoint is not None:
+ metrics.add_handler(metrics.OpenTelemetryMetricsHandler.new_handler(cluster_name, endpoint, service_name))
+
+
+def emit_counter(name: str, value: Union[int, float], tags: Dict[str, str] = None):
+ metrics.emit_counter(name, value, global_service_label if tags is None else {**tags, **global_service_label})
+
+
+def emit_timing(name: str, tags: Dict[str, str] = None) -> ContextManager[None]:
+ return metrics.emit_timing(name, global_service_label if tags is None else {**tags, **global_service_label})
diff --git a/pp_lite/utils/metrics.py b/pp_lite/utils/metrics.py
new file mode 100644
index 000000000..dae3bbde0
--- /dev/null
+++ b/pp_lite/utils/metrics.py
@@ -0,0 +1,56 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pp_lite.utils.tools import print_named_dict
+
+
+class Handler(object):
+
+ def emit_counter(self, key: str, value: int):
+ """
+ Different handlers use different counting strategies.
+ Subclasses must implement this method; the base implementation raises NotImplementedError.
+ """
+ raise NotImplementedError('Emit must be implemented by Handler subclasses')
+
+
+class AuditHandler(Handler):
+
+ def __init__(self, interval: int = 50):
+ super().__init__()
+ self._audit_metrics = {}
+ self._step = 0
+ self._INTERVAL = interval
+
+ def emit_counter(self, key: str, value: int):
+ self._audit_metrics[key] = self._audit_metrics.get(key, 0) + value
+ self._step += 1
+ if self._step % self._INTERVAL == 0:
+ self.show_audit_info()
+
+ def get_value(self, key: str) -> int:
+ return self._audit_metrics.get(key, 0)
+
+ def show_audit_info(self):
+ if not self._audit_metrics:
+ return
+ print_named_dict(name='Audit', target_dict=self._audit_metrics)
+
+
+_audit_client = AuditHandler()
+
+emit_counter = _audit_client.emit_counter
+get_audit_value = _audit_client.get_value
+show_audit_info = _audit_client.show_audit_info
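The module-level `emit_counter` here accumulates values in-process and dumps them via `print_named_dict` every `interval` calls (50 by default). A minimal sketch of that behaviour:

```python
import logging

from pp_lite.utils import metrics

logging.basicConfig(level=logging.INFO)

# 50 calls hit the default interval, so the accumulated audit table is logged once.
for _ in range(50):
    metrics.emit_counter('joined_rows', 2)

assert metrics.get_audit_value('joined_rows') == 100
metrics.show_audit_info()  # print the audit table again on demand
```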
diff --git a/pp_lite/utils/tools.py b/pp_lite/utils/tools.py
new file mode 100644
index 000000000..e8edacbd7
--- /dev/null
+++ b/pp_lite/utils/tools.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Dict, List
+
+
+def print_named_dict(name: str, target_dict: Dict):
+ logging.info(f'===================={name}====================')
+ for key, value in target_dict.items():
+ logging.info(f'{key}: {value}')
+ logging.info(f'===================={"=" * len(name)}====================')
+
+
+def get_partition_ids(worker_rank: int, num_workers: int, num_partitions: int) -> List[int]:
+ return [i for i in range(num_partitions) if i % num_workers == worker_rank]
diff --git a/pp_lite/utils/tools_test.py b/pp_lite/utils/tools_test.py
new file mode 100644
index 000000000..1334cd8e7
--- /dev/null
+++ b/pp_lite/utils/tools_test.py
@@ -0,0 +1,30 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from pp_lite.utils.tools import get_partition_ids
+
+
+class ToolsTest(unittest.TestCase):
+
+ def test_get_partition_ids(self):
+ self.assertListEqual(get_partition_ids(1, 5, 12), [1, 6, 11])
+ self.assertListEqual(get_partition_ids(3, 5, 2), [])
+ self.assertListEqual(get_partition_ids(4, 5, 10), [4, 9])
+ self.assertListEqual(get_partition_ids(5, 5, 10), [])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/py_libs/BUILD.bazel b/py_libs/BUILD.bazel
new file mode 100644
index 000000000..587d28421
--- /dev/null
+++ b/py_libs/BUILD.bazel
@@ -0,0 +1,49 @@
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+ name = "metrics_lib",
+ srcs = [
+ "metrics.py",
+ ],
+ imports = [".."],
+ deps = [
+ "@common_opentelemetry_exporter_otlp//:pkg",
+ "@common_opentelemetry_sdk//:pkg",
+ ],
+)
+
+py_test(
+ name = "metrics_lib_test",
+ size = "small",
+ srcs = [
+ "metrics_test.py",
+ ],
+ imports = [".."],
+ main = "metrics_test.py",
+ deps = [
+ ":metrics_lib",
+ ],
+)
+
+py_library(
+ name = "sdk",
+ srcs = [
+ "sdk.py",
+ ],
+ imports = [
+ "..",
+ ],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "@common_requests//:pkg",
+ ],
+)
+
+py_library(
+ name = "logging_config",
+ srcs = [
+ "logging_config.py",
+ ],
+ imports = [".."],
+)
diff --git a/py_libs/logging_config.py b/py_libs/logging_config.py
new file mode 100644
index 000000000..4ebaa6b7d
--- /dev/null
+++ b/py_libs/logging_config.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from logging import LogRecord
+from typing import Optional
+import os
+
+LOGGING_LEVEL = os.environ.get('LOGGING_LEVEL', 'INFO')
+
+
+class LevelFilter(logging.Filter):
+
+ def filter(self, record: LogRecord):
+ if record.levelno <= logging.WARNING:
+ return False
+ return True
+
+
+def logging_config(role: str, log_file: Optional[str] = None):
+ # Remove all handlers associated with the root logger object.
+ for handler in logging.root.handlers:
+ logging.root.removeHandler(handler)
+
+ logging_format = f'%(asctime)s %(levelname)-7s [{role}] %(message)s'
+ logging.basicConfig(level=LOGGING_LEVEL, format=logging_format)
+ logging.getLogger('urllib3.connectionpool').addFilter(LevelFilter())
+ if log_file is not None:
+ os.makedirs(os.path.dirname(log_file), exist_ok=True)
+ file_handler = logging.FileHandler(log_file)
+ file_handler.setLevel(LOGGING_LEVEL)
+ file_handler.setFormatter(logging.Formatter(logging_format))
+ logging.getLogger().addHandler(file_handler)
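A short usage sketch for the helper above; the file path is an assumption. Note that `LevelFilter` is attached to `urllib3.connectionpool`, so only records above WARNING from that logger reach the handlers:

```python
import logging

from py_libs.logging_config import logging_config

# Console plus an optional file, both at LOGGING_LEVEL (INFO by default).
logging_config(role='client', log_file='/tmp/pp_lite/client.log')  # hypothetical path

logging.info('tagged with the [client] role in the format string')
# Connection retries logged by urllib3.connectionpool at WARNING or below are
# dropped by LevelFilter; only ERROR and above from that logger get through.
logging.getLogger('urllib3.connectionpool').warning('suppressed by LevelFilter')
```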
diff --git a/py_libs/metrics.py b/py_libs/metrics.py
new file mode 100644
index 000000000..bcdcbc8c7
--- /dev/null
+++ b/py_libs/metrics.py
@@ -0,0 +1,259 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from contextlib import contextmanager
+from datetime import datetime
+import logging
+from abc import ABCMeta, abstractmethod
+import sys
+from typing import ContextManager, Dict, Optional, Union
+from threading import Lock
+
+from opentelemetry import trace, _metrics as metrics
+from opentelemetry._metrics.instrument import UpDownCounter
+from opentelemetry._metrics.measurement import Measurement
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk._metrics import MeterProvider
+from opentelemetry.sdk._metrics.export import (PeriodicExportingMetricReader, ConsoleMetricExporter, MetricExporter,
+ MetricExportResult, Metric, Sequence)
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+from opentelemetry.exporter.otlp.proto.grpc._metric_exporter import OTLPMetricExporter
+from opentelemetry.sdk.trace.export import (BatchSpanProcessor, ConsoleSpanExporter, SpanExportResult, SpanExporter,
+ ReadableSpan)
+
+
+def _validate_tags(tags: Dict[str, str]):
+ if tags is None:
+ return
+ for k, v in tags.items():
+ if not isinstance(k, str) or not isinstance(v, str):
+ raise TypeError(f'Expected str, actually {type(k)}: {type(v)}')
+
+
+class DevNullSpanExporter(SpanExporter):
+
+ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
+ return SpanExportResult.SUCCESS
+
+ def shutdown(self):
+ pass
+
+
+class DevNullMetricExporter(MetricExporter):
+
+ def export(self, metrics: Sequence[Metric]) -> MetricExportResult: # pylint: disable=redefined-outer-name
+ return MetricExportResult.SUCCESS
+
+ def shutdown(self):
+ pass
+
+
+class MetricsHandler(metaclass=ABCMeta):
+
+ @abstractmethod
+ def emit_counter(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
+ """Emits counter metrics which will be accumulated.
+
+ Args:
+ name: name of the metrics, e.g. foo.bar
+ value: value of the metrics in integer, e.g. 43
+ tags: extra tags of the counter, e.g. {"is_test": True}
+ """
+
+ @abstractmethod
+ def emit_store(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
+ """Emits store metrics.
+
+ Args:
+ name: name of the metrics, e.g. foo.bar
+ value: value of the metrics in integer, e.g. 43
+ tags: extra tags of the counter, e.g. {"is_test": True}
+ """
+
+ @abstractmethod
+ def emit_timing(self, name: str, tags: Dict[str, str] = None) -> ContextManager[None]:
+ """Emits timing generator.
+
+
+ Args:
+ name: name of metrics, e.g. foo.bar
+ tags: extra tags of the counter, e.g. {"is_test": True}
+
+ Returns:
+ Generator of timing scope.
+ """
+
+
+class _DefaultMetricsHandler(MetricsHandler):
+
+ def emit_counter(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ tags = tags or {}
+ logging.info(f'[Metric][Counter] {name}: {value}, tags={tags}')
+
+ def emit_store(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ tags = tags or {}
+ logging.info(f'[Metric][Store] {name}: {value}, tags={tags}')
+
+ @contextmanager
+ def emit_timing(self, name: str, tags: Dict[str, str] = None) -> ContextManager[None]:
+ tags = tags or {}
+ logging.info(f'[Metrics][Timing] {name} started, tags={tags}')
+ started = datetime.timestamp(datetime.now())
+ yield None
+ ended = datetime.timestamp(datetime.now())
+ logging.info(f'[Metrics][Timing] {name}: {(ended - started):.2f}s ended, tags={tags}')
+
+
+class OpenTelemetryMetricsHandler(MetricsHandler):
+
+ class Callback:
+
+ def __init__(self) -> None:
+ self._measurement_list = []
+
+ def add(self, value: Union[int, float], tags: Dict[str, str]):
+ self._measurement_list.append(Measurement(value=value, attributes=tags))
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if len(self._measurement_list) == 0:
+ raise StopIteration
+ return self._measurement_list.pop(0)
+
+ def __call__(self):
+ return iter(self)
+
+ @classmethod
+ def new_handler(cls,
+ cluster: str,
+ apm_server_endpoint: str,
+ instrument_module_name: Optional[str] = None) -> 'OpenTelemetryMetricsHandler':
+ instrument_module_name = instrument_module_name or 'fedlearner_webconsole'
+ resource = Resource.create(attributes={
+ 'service.name': instrument_module_name,
+ 'deployment.environment': cluster,
+ })
+ # initialize tracing
+ if apm_server_endpoint == 'stdout':
+ span_exporter = ConsoleSpanExporter(out=sys.stdout)
+ elif apm_server_endpoint == '/dev/null':
+ span_exporter = DevNullSpanExporter()
+ else:
+ span_exporter = OTLPSpanExporter(endpoint=apm_server_endpoint)
+ tracer_provider = TracerProvider(resource=resource)
+ tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter))
+ trace.set_tracer_provider(tracer_provider)
+
+ # initialize the meter
+ if apm_server_endpoint == 'stdout':
+ metric_exporter = ConsoleMetricExporter(out=sys.stdout)
+ elif apm_server_endpoint == '/dev/null':
+ metric_exporter = DevNullMetricExporter()
+ else:
+ metric_exporter = OTLPMetricExporter(endpoint=apm_server_endpoint)
+ reader = PeriodicExportingMetricReader(metric_exporter, export_interval_millis=60000)
+ meter_provider = MeterProvider(metric_readers=[reader], resource=resource)
+ metrics.set_meter_provider(meter_provider=meter_provider)
+
+ return cls(tracer=tracer_provider.get_tracer(instrument_module_name),
+ meter=meter_provider.get_meter(instrument_module_name))
+
+ def __init__(self, tracer: trace.Tracer, meter: metrics.Meter):
+ self._tracer = tracer
+ self._meter = meter
+
+ self._lock = Lock()
+ self._cache: Dict[str, Union[UpDownCounter, OpenTelemetryMetricsHandler.Callback]] = {}
+
+ def emit_counter(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
+ # Note that the `values.` prefix is used for Elastic Index Dynamic Inference.
+ # Only take the lock when the instrument has not been created yet.
+ if name not in self._cache:
+ with self._lock:
+ # Double check `self._cache` content.
+ if name not in self._cache:
+ counter = self._meter.create_up_down_counter(name=f'values.{name}')
+ self._cache[name] = counter
+ assert isinstance(self._cache[name], UpDownCounter)
+ self._cache[name].add(value, attributes=tags)
+
+ def emit_store(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
+ # Note that the `values.` prefix is used for Elastic Index Dynamic Inference.
+ # Only take the lock when the instrument has not been created yet.
+ if name not in self._cache:
+ with self._lock:
+ # Double check `self._cache` content.
+ if name not in self._cache:
+ cb = OpenTelemetryMetricsHandler.Callback()
+ self._meter.create_observable_gauge(name=f'values.{name}', callback=cb)
+ self._cache[name] = cb
+ assert isinstance(self._cache[name], OpenTelemetryMetricsHandler.Callback)
+ self._cache[name].add(value=value, tags=tags)
+
+ def emit_timing(self, name: str, tags: Dict[str, str] = None) -> ContextManager[None]:
+ return self._tracer.start_as_current_span(name=name, attributes=tags)
+
+
+class _Client(MetricsHandler):
+ """A wrapper for all handlers.
+
+ Inspired by logging module, use this to avoid usage of global statement,
+ which will make the code more thread-safe."""
+ _handlers = []
+
+ def __init__(self):
+ self._handlers.append(_DefaultMetricsHandler())
+
+ def emit_counter(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ _validate_tags(tags)
+ for handler in self._handlers:
+ handler.emit_counter(name, value, tags)
+
+ def emit_store(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ _validate_tags(tags)
+ for handler in self._handlers:
+ handler.emit_store(name, value, tags)
+
+ @contextmanager
+ def emit_timing(self, name: str, tags: Dict[str, str] = None) -> ContextManager[None]:
+ _validate_tags(tags)
+ emit_timings = []
+ for handler in self._handlers:
+ emit_timings.append(handler.emit_timing(name, tags))
+ for e in emit_timings:
+ e.__enter__()
+ yield None
+ emit_timings.reverse()
+ for e in emit_timings:
+ e.__exit__(None, None, None)
+
+ def add_handler(self, handler):
+ self._handlers.append(handler)
+
+ def reset_handlers(self):
+ # Only keep the first one
+ del self._handlers[1:]
+
+
+# Exports all to module level
+_client = _Client()
+emit_counter = _client.emit_counter
+emit_store = _client.emit_store
+emit_timing = _client.emit_timing
+add_handler = _client.add_handler
+reset_handlers = _client.reset_handlers
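A minimal sketch of how a process typically wires the module-level client above: the default logging handler is always active, and `add_handler` attaches the OpenTelemetry exporter on top. The `'stdout'` endpoint routes everything through the console exporters; a real deployment would pass an OTLP collector endpoint instead:

```python
import logging

from py_libs import metrics

logging.basicConfig(level=logging.INFO)

metrics.add_handler(
    metrics.OpenTelemetryMetricsHandler.new_handler(cluster='dev_cluster',  # hypothetical cluster
                                                    apm_server_endpoint='stdout'))

metrics.emit_store('dataset.row_count', 1600, tags={'module': 'dataset'})
metrics.emit_counter('dataset.join_calls', 1, tags={'module': 'dataset'})
with metrics.emit_timing('dataset.join', tags={'module': 'dataset'}):
    pass  # timed section creates a span named 'dataset.join'
```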
diff --git a/py_libs/metrics_test.py b/py_libs/metrics_test.py
new file mode 100644
index 000000000..63c0c5774
--- /dev/null
+++ b/py_libs/metrics_test.py
@@ -0,0 +1,295 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import contextlib
+import io
+import json
+import logging
+import multiprocessing
+from multiprocessing import Process, Queue
+import time
+import unittest
+from io import StringIO
+from unittest.mock import patch
+from os import linesep
+from typing import ContextManager, Dict
+from contextlib import contextmanager
+
+from opentelemetry import trace as otel_trace, _metrics as otel_metrics
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.sdk._metrics import MeterProvider
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace.export import ConsoleSpanExporter
+from opentelemetry.sdk._metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
+
+from py_libs import metrics
+from py_libs.metrics import _DefaultMetricsHandler, MetricsHandler, OpenTelemetryMetricsHandler
+
+
+class _FakeMetricsHandler(MetricsHandler):
+
+ def emit_counter(self, name, value: int, tags: Dict[str, str] = None):
+ logging.info(f'[Test][Counter] {name} - {value}')
+
+ def emit_store(self, name, value: int, tags: Dict[str, str] = None):
+ logging.info(f'[Test][Store] {name} - {value}')
+
+ @contextmanager
+ def emit_timing(self, name: str, tags: Dict[str, str] = None) -> ContextManager[None]:
+ logging.info(f'[Test][Timing] {name} started')
+ yield None
+ logging.info(f'[Test][Timing] {name} ended')
+
+
+class DefaultMetricsHandlerTest(unittest.TestCase):
+
+ def setUp(self):
+ self._handler = _DefaultMetricsHandler()
+
+ def test_emit_counter(self):
+ with self.assertLogs() as cm:
+ self._handler.emit_counter('test', 1)
+ self._handler.emit_counter('test2', 2)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Counter] test: 1, tags={}', '[Metric][Counter] test2: 2, tags={}'])
+
+ def test_emit_store(self):
+ with self.assertLogs() as cm:
+ self._handler.emit_store('test', 199)
+ self._handler.emit_store('test2', 299)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Store] test: 199, tags={}', '[Metric][Store] test2: 299, tags={}'])
+
+ def test_emit_timing(self):
+ with self.assertLogs() as cm:
+ with self._handler.emit_timing('test'):
+ time.sleep(0.01)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(
+ logs, ['[Metrics][Timing] test started, tags={}', '[Metrics][Timing] test: 0.01s ended, tags={}'])
+
+
+class ClientTest(unittest.TestCase):
+
+ def setUp(self):
+ metrics.add_handler(_FakeMetricsHandler())
+
+ def tearDown(self):
+ metrics.reset_handlers()
+
+ def test_emit_counter(self):
+ with self.assertRaises(TypeError):
+ metrics.emit_counter('test', 1, tags={'name': 1})
+
+ with self.assertLogs() as cm:
+ metrics.emit_counter('test', 1)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Counter] test: 1, tags={}', '[Test][Counter] test - 1'])
+
+ def test_emit_store(self):
+ with self.assertRaises(TypeError):
+ metrics.emit_store('test', 1, tags={'name': 1})
+
+ with self.assertLogs() as cm:
+ metrics.emit_store('test', 199)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Store] test: 199, tags={}', '[Test][Store] test - 199'])
+
+ def test_emit_timing(self):
+ with self.assertRaises(TypeError):
+ metrics.emit_store('test', 1, tags={'name': 1})
+
+ with self.assertLogs() as cm:
+ with metrics.emit_timing('test'):
+ time.sleep(0.01)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, [
+ '[Metrics][Timing] test started, tags={}', '[Test][Timing] test started', '[Test][Timing] test ended',
+ '[Metrics][Timing] test: 0.01s ended, tags={}'
+ ])
+
+
+class OpenTelemetryMetricsHandlerClassMethodTest(unittest.TestCase):
+
+ def setUp(self):
+ self._span_out = StringIO()
+ self._span_exporter_patcher = patch('py_libs.metrics.OTLPSpanExporter',
+ lambda **kwargs: ConsoleSpanExporter(out=self._span_out))
+ self._metric_out = StringIO()
+ self._metric_exporter_patcher = patch('py_libs.metrics.OTLPMetricExporter',
+ lambda **kwargs: ConsoleMetricExporter(out=self._metric_out))
+ self._span_exporter_patcher.start()
+ self._metric_exporter_patcher.start()
+
+ def tearDown(self):
+ self._metric_exporter_patcher.stop()
+ self._span_exporter_patcher.stop()
+
+ def test_new_handler(self):
+ OpenTelemetryMetricsHandler.new_handler(cluster='default', apm_server_endpoint='stdout')
+ self.assertEqual(
+ otel_trace.get_tracer_provider().resource,
+ Resource(
+ attributes={
+ 'telemetry.sdk.language': 'python',
+ 'telemetry.sdk.name': 'opentelemetry',
+ 'telemetry.sdk.version': '1.10.0',
+ 'service.name': 'fedlearner_webconsole',
+ 'deployment.environment': 'default',
+ }))
+ self.assertEqual(
+ otel_metrics.get_meter_provider()._sdk_config.resource, # pylint: disable=protected-access
+ Resource(
+ attributes={
+ 'telemetry.sdk.language': 'python',
+ 'telemetry.sdk.name': 'opentelemetry',
+ 'telemetry.sdk.version': '1.10.0',
+ 'service.name': 'fedlearner_webconsole',
+ 'deployment.environment': 'default',
+ }))
+
+
+class OpenTelemetryMetricsHandlerTest(unittest.TestCase):
+
+ def setUp(self):
+ self._span_out = StringIO()
+ self._metric_out = StringIO()
+ tracer_provider = TracerProvider()
+ # We have to custom formatter for easing the streaming split json objects.
+ tracer_provider.add_span_processor(
+ BatchSpanProcessor(
+ ConsoleSpanExporter(
+ out=self._span_out,
+ formatter=lambda span: span.to_json(indent=None) + linesep,
+ )))
+ reader = PeriodicExportingMetricReader(ConsoleMetricExporter(out=self._metric_out),
+ export_interval_millis=60000)
+ meter_provider = MeterProvider(metric_readers=[reader])
+ self._tracer_provider = tracer_provider
+ self._meter_provider = meter_provider
+ self._handler = OpenTelemetryMetricsHandler(tracer=tracer_provider.get_tracer(__file__),
+ meter=meter_provider.get_meter(__file__))
+
+ def _force_flush(self):
+ self._meter_provider.force_flush()
+ self._metric_out.flush()
+ self._tracer_provider.force_flush()
+ self._span_out.flush()
+
+ def test_emit_store(self):
+ # Note that same instrument with different tags won't be aggregated.
+ # Aggregation rule for `emit_store` is delivering the last value of this interval.
+ # If no value at this interval, no `Metric` will be sent.
+ self._handler.emit_store(name='test_store', value=1, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._handler.emit_store(name='test_store', value=5, tags={'module': 'dataset', 'uuid': 'tag2'})
+ self._handler.emit_store(name='test_store', value=2, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self._force_flush()
+ self._force_flush()
+ self._handler.emit_store(name='test_store', value=0, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self.assertEqual(self._span_out.getvalue(), '')
+ self._metric_out.seek(0)
+ lines = self._metric_out.readlines()
+ measurements = []
+ for l in lines:
+ measurement = json.loads(l)
+ measurements.append(measurement)
+ self.assertEqual(len(measurements), 3)
+ self.assertEqual(measurements[0]['attributes'], {'uuid': 'tag1', 'module': 'dataset'})
+ self.assertEqual(measurements[1]['attributes'], {'uuid': 'tag2', 'module': 'dataset'})
+ self.assertEqual(measurements[0]['name'], 'values.test_store')
+ self.assertEqual([m['point']['value'] for m in measurements], [2, 5, 0])
+
+ def test_emit_counter(self):
+ # Note that same instrument with different tags won't be aggregated.
+ # Aggregation rule for `emit_counter` is delivering the accumulated value with the same tags during this interval. # pylint: disable=line-too-long
+ # If no value at this interval, a `Metric` with value of last interval will be sent.
+ self._handler.emit_counter(name='test_counter', value=1, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._handler.emit_counter(name='test_counter', value=5, tags={'module': 'dataset', 'uuid': 'tag2'})
+ self._handler.emit_counter(name='test_counter', value=2, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self._force_flush()
+ self._handler.emit_counter(name='test_counter', value=-1, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self.assertEqual(self._span_out.getvalue(), '')
+ self._metric_out.seek(0)
+ lines = self._metric_out.readlines()
+ measurements = []
+ for l in lines:
+ measurement = json.loads(l)
+ measurements.append(measurement)
+ self.assertEqual(len(measurements), 6)
+ self.assertEqual(measurements[0]['attributes'], {'uuid': 'tag1', 'module': 'dataset'})
+ self.assertEqual(measurements[1]['attributes'], {'uuid': 'tag2', 'module': 'dataset'})
+ self.assertEqual(measurements[0]['name'], 'values.test_counter')
+ self.assertEqual([m['point']['value'] for m in measurements], [3, 5, 3, 5, 2, 5])
+
+ def test_emit_timing(self):
+ with self._handler.emit_timing('test', {}):
+ time.sleep(0.1)
+ with self._handler.emit_timing('test', {}):
+ time.sleep(0.2)
+ with self._handler.emit_timing('test2', {}):
+ time.sleep(0.1)
+ self._force_flush()
+ self._span_out.seek(0)
+ lines = self._span_out.readlines()
+ measurements = []
+ for l in lines:
+ measurement = json.loads(l)
+ measurements.append(measurement)
+
+ self.assertEqual(len(measurements), 3)
+ self.assertEqual([m['name'] for m in measurements], ['test', 'test', 'test2'])
+
+
+class OpenTelemetryMetricsHandlerOutputTest(unittest.TestCase):
+
+ @staticmethod
+ def suite_test(q: Queue, test_case: str):
+ # `OpenTelemetryMetricsHandler.new_handler` sets global tracer/meter providers, which makes repeated test cases non-idempotent. # pylint: disable=line-too-long
+ # So each case runs in a child process.
+ f = io.StringIO()
+ with contextlib.redirect_stdout(f):
+ handler = OpenTelemetryMetricsHandler.new_handler(cluster='test_cluster', apm_server_endpoint=test_case)
+ handler.emit_store('test', 199)
+ handler.emit_counter('test2', 299)
+ otel_metrics.get_meter_provider().force_flush()
+ otel_trace.get_tracer_provider().force_flush()
+ q.put(f.getvalue())
+
+ def test_dev_null(self):
+
+ queue = multiprocessing.SimpleQueue()
+ test_process = Process(target=self.suite_test, args=(queue, '/dev/null'))
+ test_process.start()
+ test_process.join()
+ self.assertEqual(queue.get(), '')
+
+ def test_stdout(self):
+
+ queue = multiprocessing.SimpleQueue()
+ test_process = Process(target=self.suite_test, args=(queue, 'stdout'))
+ test_process.start()
+ test_process.join()
+ self.assertIn('test', queue.get())
+
+
+if __name__ == '__main__':
+ multiprocessing.set_start_method('spawn')
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/py_libs/sdk.py b/py_libs/sdk.py
new file mode 100644
index 000000000..afb2c0669
--- /dev/null
+++ b/py_libs/sdk.py
@@ -0,0 +1,355 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import logging
+from time import sleep
+import time
+import urllib.parse
+import requests
+from http import HTTPStatus
+from typing import Dict, Tuple, Optional, List
+from fedlearner_webconsole.mmgr.models import ModelJobType
+from fedlearner_webconsole.dataset.models import DatasetJobKind, DatasetKindV2 as DatasetKind
+
+
+def _get_response_data(resp: requests.Response) -> Tuple[int, Dict]:
+ return resp.status_code, json.loads(resp.content)
+
+
+class WebconsoleClient:
+
+ def __init__(self, domain_name: str):
+ self._domain_name = domain_name
+ self._session = None
+ self.sign_in()
+
+ def sign_in(self):
+ self._session = requests.Session()
+ payload = {'username': 'robot', 'password': 'ZmxAMTIzNDUu'}
+ resp = self._session.post(f'{self._domain_name}/api/v2/auth/signin', json=payload)
+ content = json.loads(resp.content)
+ access_token = content['data']['access_token']
+ self._session.headers.update({'Authorization': f'Bearer {access_token}'})
+
+ def get_system_info(self):
+ url = f'{self._domain_name}/api/v2/settings/system_info'
+ return _get_response_data(self._session.get(url))
+
+ def get_templates(self):
+ url = f'{self._domain_name}/api/v2/workflow_templates'
+ return _get_response_data(self._session.get(url))
+
+ def get_template(self, template_id):
+ url = f'{self._domain_name}/api/v2/workflow_templates/{template_id}'
+ return _get_response_data(self._session.get(url))
+
+ def get_projects(self):
+ url = f'{self._domain_name}/api/v2/projects'
+ return _get_response_data(self._session.get(url))
+
+ def get_project_by_id(self, project_id: int):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}'
+ return _get_response_data(self._session.get(url))
+
+ def get_data_sources(self, project_id: int):
+ url = f'{self._domain_name}/api/v2/data_sources?project_id={project_id}'
+ return _get_response_data(self._session.get(url))
+
+ def get_datasets(self, project_id: int, keyword: Optional[str] = None):
+ url = f'{self._domain_name}/api/v2/datasets'
+ filter_expression = urllib.parse.quote(f'(and(project_id={project_id})(name~="{keyword}"))')
+ return _get_response_data(self._session.get(f'{url}?filter={filter_expression}'))
+
+ def post_data_source(self, project_id: int, input_data_path: str, data_source_name: str, store_format: str):
+ url = f'{self._domain_name}/api/v2/data_sources'
+ payload = {
+ 'project_id': project_id,
+ 'data_source': {
+ 'data_source_url': input_data_path,
+ 'name': data_source_name,
+ 'store_format': store_format,
+ 'dataset_format': 'TABULAR',
+ }
+ }
+ resp = self._session.post(url, json=payload)
+ return _get_response_data(resp)
+
+ def post_raw_dataset(self, project_id: int, name: str):
+ url = f'{self._domain_name}/api/v2/datasets'
+ payload = {
+ 'dataset_format': 'TABULAR',
+ 'dataset_type': 'PSI',
+ 'import_type': 'COPY',
+ 'store_format': 'TFRECORDS',
+ 'name': name,
+ 'kind': DatasetKind.RAW.value,
+ 'need_publish': True,
+ 'project_id': project_id
+ }
+ resp = self._session.post(url, json=payload)
+ return _get_response_data(resp)
+
+ def post_intersection_dataset(self, project_id: int, name: str):
+ url = f'{self._domain_name}/api/v2/datasets'
+ payload = {
+ 'dataset_format': 'TABULAR',
+ 'dataset_type': 'PSI',
+ 'import_type': 'COPY',
+ 'store_format': 'TFRECORDS',
+ 'name': name,
+ 'kind': DatasetKind.PROCESSED.value,
+ 'is_published': True,
+ 'project_id': project_id
+ }
+ resp = self._session.post(url, json=payload)
+ return _get_response_data(resp)
+
+ def post_data_batches(self, dataset_id: int, data_source_id: int):
+ url = f'{self._domain_name}/api/v2/datasets/{str(dataset_id)}/batches'
+ payload = {'data_source_id': data_source_id}
+ resp = self._session.post(url, json=payload)
+ return _get_response_data(resp)
+
+ def get_participant_datasets(self, project_id: int, kind: DatasetKind):
+ url = f'{self._domain_name}/api/v2/project/{project_id}/participant_datasets?kind={kind.value}'
+ return _get_response_data(self._session.get(url))
+
+ def get_dataset_job_variables(self, dataset_job_kind: DatasetJobKind):
+ url = f'{self._domain_name}/api/v2/dataset_job_definitions/{dataset_job_kind.value}'
+ return _get_response_data(self._session.get(url))
+
+ def post_dataset_job(self, project_id: int, output_dataset_id: int, dataset_job_parameter: Dict):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/dataset_jobs'
+ payload = {'dataset_job_parameter': dataset_job_parameter, 'output_dataset_id': output_dataset_id}
+ resp = self._session.post(url, json=payload)
+ return _get_response_data(resp)
+
+ def get_model_job_groups(self, project_id: int, keyword: Optional[str] = None):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_job_groups'
+ resp = self._session.get(url, json={'keyword': keyword})
+ return _get_response_data(resp)
+
+ def get_model_job_group(self, project_id: int, group_id: int):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_job_groups/{group_id}'
+ resp = self._session.get(url)
+ return _get_response_data(resp)
+
+ def post_model_job_groups(self, project_id: int, name: str, dataset_id: int):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_job_groups'
+ payload = {'name': name, 'dataset_id': dataset_id, 'algorithm_type': 'NN_VERTICAL'}
+ resp = self._session.post(url, json=payload)
+ return _get_response_data(resp)
+
+ def put_model_job_group(self, project_id: int, group_id: int, algorithm_id: int, config: Dict):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_job_groups/{group_id}'
+ payload = {'authorized': True, 'algorithm_id': algorithm_id, 'config': config}
+ resp = self._session.put(url, json=payload)
+ return _get_response_data(resp)
+
+ def launch_model_job(self, project_id: int, group_id: int):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_job_groups/{group_id}:launch'
+ resp = self._session.post(url, json={'comment': 'created by automated scheduler'})
+ return _get_response_data(resp)
+
+ def get_model_jobs(self, project_id: int, keyword: Optional[str] = None):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_jobs'
+ resp = self._session.get(url, json={'keyword': keyword})
+ return _get_response_data(resp)
+
+ def get_model_job(self, project_id: int, model_job_id: int):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_jobs/{model_job_id}'
+ resp = self._session.get(url)
+ return _get_response_data(resp)
+
+ def post_model_jobs(self, project_id: int, name: str, model_job_type: ModelJobType, dataset_id: int,
+ algorithm_id: int, model_id: int, config: Dict):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_jobs'
+ payload = {
+ 'name': name,
+ 'model_job_type': model_job_type.name,
+ 'dataset_id': dataset_id,
+ 'algorithm_type': 'NN_VERTICAL',
+ 'algorithm_id': algorithm_id,
+ 'model_id': model_id,
+ 'config': config
+ }
+ resp = self._session.post(url, json=payload)
+ return _get_response_data(resp)
+
+ def put_model_job(self, project_id: int, model_job_id: int, algorithm_id: int, config: Dict):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_jobs/{model_job_id}'
+ payload = {'algorithm_id': algorithm_id, 'config': config}
+ resp = self._session.put(url, json=payload)
+ return _get_response_data(resp)
+
+ def get_peer_model_job_group(self, project_id: int, group_id: int, participant_id: int):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_job_groups/{group_id}/peers/{participant_id}'
+ return _get_response_data(self._session.get(url))
+
+ def patch_peer_model_job_group(self, project_id: int, group_id: int, participant_id: int, config: Dict):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/model_job_groups/{group_id}/peers/{participant_id}'
+ return _get_response_data(self._session.patch(url, json={'config': config}))
+
+ def get_models(self, project_id: int, keyword: str):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/models?keyword={keyword}'
+ return _get_response_data(self._session.get(url))
+
+ def get_algorithms(self, project_id: int):
+ url = f'{self._domain_name}/api/v2/projects/{project_id}/algorithms'
+ return _get_response_data(self._session.get(url))
+
+
+class WebconsoleService:
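+ """Name-based convenience wrapper around WebconsoleClient: resolves projects,
+ templates, datasets, models and model job groups by name and polls until
+ datasets become ready."""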
+
+ def __init__(self, client: WebconsoleClient):
+ self.client = client
+
+ def get_project_by_name(self, name: str) -> Optional[Dict]:
+ code, content = self.client.get_projects()
+ assert code == HTTPStatus.OK
+ for project in content['data']:
+ if project['name'] == name:
+ return project
+ return None
+
+ def get_project_id_by_name(self, name: str) -> int:
+ project = self.get_project_by_name(name=name)
+ assert project is not None
+ return project['id']
+
+ def get_template_by_name(self, name: str) -> Optional[Dict]:
+ code, content = self.client.get_templates()
+ assert code == HTTPStatus.OK
+ for template in content['data']:
+ if template['name'] == name:
+ code, content = self.client.get_template(template['id'])
+ assert code == HTTPStatus.OK
+ return content['data']
+ return None
+
+ def get_model_job_group_by_name(self, project_id: int, name: str) -> Optional[Dict]:
+ code, content = self.client.get_model_job_groups(project_id=project_id)
+ assert code == HTTPStatus.OK
+ for group in content['data']:
+ if group['name'] == name:
+ group_id = group['id']
+ code, content = self.client.get_model_job_group(project_id=project_id, group_id=group_id)
+ if code == HTTPStatus.OK:
+ return content['data']
+ return None
+
+ def get_model_job_by_name(self, project_id: int, name: str) -> Optional[Dict]:
+ code, content = self.client.get_model_jobs(project_id=project_id, keyword=name)
+ assert code == HTTPStatus.OK
+ for job in content['data']:
+ if job['name'] == name:
+ return job
+ return None
+
+ def get_latest_model_job(self, project_id: int, group_id: int) -> Optional[Dict]:
+ code, content = self.client.get_model_job_group(project_id=project_id, group_id=group_id)
+ assert code == HTTPStatus.OK
+ if len(content['data']['model_jobs']) == 0:
+ return None
+ model_job_id = content['data']['model_jobs'][0]['id']
+ code, content = self.client.get_model_job(project_id=project_id, model_job_id=model_job_id)
+ if code != HTTPStatus.OK:
+ raise Exception(f'get job {model_job_id} failed with details {content}')
+ return content['data']
+
+ def get_model_by_name(self, project_id: int, name: str) -> Optional[Dict]:
+ code, content = self.client.get_models(project_id=project_id, keyword=name)
+ assert code == HTTPStatus.OK
+ for model in content['data']:
+ if model['name'] == name:
+ return model
+ return None
+
+ def get_data_source_by_name(self, name: str, project_id: int) -> Optional[Dict]:
+ code, content = self.client.get_data_sources(project_id=project_id)
+ assert code == HTTPStatus.OK
+ for data_source in content['data']:
+ if data_source['name'] == name:
+ return data_source
+ return None
+
+ def get_dataset_by_name(self, name: str, project_id: int) -> Optional[Dict]:
+ code, content = self.client.get_datasets(project_id=project_id, keyword=name)
+ assert code == HTTPStatus.OK
+ for dataset in content['data']:
+ if dataset['name'] == name:
+ return dataset
+ return None
+
+ def get_domain_name(self) -> str:
+ code, content = self.client.get_system_info()
+ assert code == HTTPStatus.OK
+ return content['data']['domain_name']
+
+ def get_participant_domain_name(self, name: str) -> str:
+ project = self.get_project_by_name(name=name)
+ assert project is not None
+ return project['participants'][0]['domain_name']
+
+ def get_participant_dataset_by_name(self, name: str, project_id: int, kind: DatasetKind) -> Optional[Dict]:
+ code, content = self.client.get_participant_datasets(project_id=project_id, kind=kind)
+ assert code == HTTPStatus.OK
+ for participant_dataset in content['data']:
+ if participant_dataset['name'] == name:
+ return participant_dataset
+ return None
+
+ def check_dataset_ready(self, name: str, project_id: int, log_interval: int = 50) -> Dict:
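+ """Polls once a minute and blocks until the dataset named `name` is SUCCEEDED
+ and published; a waiting message is logged at most every `log_interval` seconds."""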
+ last_log_time = 0
+ while True:
+ dataset = self.get_dataset_by_name(name=name, project_id=project_id)
+ if dataset is not None and dataset['state_frontend'] == 'SUCCEEDED' and dataset['is_published']:
+ return dataset
+ current_time = time.time()
+ if current_time - last_log_time > log_interval:
+ logging.info(f'[check_dataset_ready]: still waiting for dataset {name} to be ready')
+ last_log_time = current_time
+ sleep(60)
+
+ def check_participant_dataset_ready(self,
+ name: str,
+ project_id: int,
+ kind: DatasetKind,
+ log_interval: int = 50) -> Dict:
+ last_log_time = 0
+ while True:
+ participant_dataset = self.get_participant_dataset_by_name(name=name, project_id=project_id, kind=kind)
+ if participant_dataset is not None:
+ return participant_dataset
+ current_time = time.time()
+ if current_time - last_log_time > log_interval:
+ logging.info(f'[check_participant_dataset_ready]: still waiting for participant dataset {name} to be ready')
+ last_log_time = current_time
+ sleep(60)
+
+ def get_algorithm_by_path(self, project_id: int, path: str):
+ code, content = self.client.get_algorithms(project_id=project_id)
+ assert code == HTTPStatus.OK
+ for algorithm in content['data']:
+ if algorithm['path'] == path:
+ return algorithm
+ return None
+
+ def get_groups_by_prefix(self, project_name: str, prefix: str) -> List[Dict]:
+ project_id = self.get_project_id_by_name(name=project_name)
+ code, content = self.client.get_model_job_groups(project_id=project_id)
+ assert code == HTTPStatus.OK
+ return [group for group in content['data'] if group['name'].startswith(prefix)]
diff --git a/sgx_network_simulation/Dockerfile b/sgx_network_simulation/Dockerfile
new file mode 100644
index 000000000..224d5f11e
--- /dev/null
+++ b/sgx_network_simulation/Dockerfile
@@ -0,0 +1,30 @@
+FROM golang:1.16 AS go
+
+RUN apt-get update && \
+ apt-get install -y make g++ libgmp-dev libglib2.0-dev libssl-dev && \
+ apt-get install -y protobuf-compiler && \
+ apt-get clean
+
+WORKDIR /app
+COPY tools/tcp_grpc_proxy ./
+RUN make build
+
+FROM python:3.6.8
+
+RUN echo "deb http://archive.debian.org/debian stretch main contrib non-free" > /etc/apt/sources.list
+
+RUN apt-get update && \
+ apt-get install -y curl vim make nginx && \
+ apt-get clean
+
+# upgrade nginx
+RUN echo "deb http://nginx.org/packages/mainline/debian/ stretch nginx deb-src http://nginx.org/packages/mainline/debian/ stretch nginx" > /etc/apt/sources.list.d/nginx.list
+RUN wget -qO - https://nginx.org/keys/nginx_signing.key | apt-key add -
+RUN apt update && \
+ apt remove nginx-common -y && \
+ apt install nginx
+
+COPY sgx_network_simulation/ /app/
+WORKDIR /app
+COPY --from=go /app/tcp2grpc ./
+COPY --from=go /app/grpc2tcp ./
diff --git a/sgx_network_simulation/nginx/sidecar.conf b/sgx_network_simulation/nginx/sidecar.conf
new file mode 100644
index 000000000..2586392d2
--- /dev/null
+++ b/sgx_network_simulation/nginx/sidecar.conf
@@ -0,0 +1,22 @@
+# Forwards all traffic to nginx controller
+server {
+ listen 32102 http2;
+
+ # No limits
+ client_max_body_size 0;
+ grpc_read_timeout 3600s;
+ grpc_send_timeout 3600s;
+ client_body_timeout 3600s;
+ # grpc_socket_keepalive is recommended but not required
+ # grpc_socket_keepalive is supported after nginx 1.15.6
+ grpc_socket_keepalive on;
+
+ grpc_set_header Authority fl-bytedance-client-auth.com;
+ grpc_set_header Host fl-bytedance-client-auth.com;
+ grpc_set_header X-Host sgx-test.fl-cmcc.com;
+
+ location / {
+ # Redirects to nginx controller
+ grpc_pass grpc://fedlearner-stack-ingress-nginx-controller.default.svc:80;
+ }
+}
diff --git a/sgx_network_simulation/sidecar.sh b/sgx_network_simulation/sidecar.sh
new file mode 100644
index 000000000..5933fc361
--- /dev/null
+++ b/sgx_network_simulation/sidecar.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+set -ex
+
+LISTEN_PORT_PATH="/pod-data/listen_port"
+while [ ! -s "$LISTEN_PORT_PATH" ]; do
+ echo "wait for $LISTEN_PORT_PATH ..."
+ sleep 1
+done
+WORKER_LISTEN_PORT=$(cat "$LISTEN_PORT_PATH")
+
+PROXY_LOCAL_PORT_PATH="/pod-data/proxy_local_port"
+while [ ! -s "$PROXY_LOCAL_PORT_PATH" ]; do
+ echo "wait for $PROXY_LOCAL_PORT_PATH ..."
+ sleep 1
+done
+PROXY_LOCAL_PORT=$(cat "$PROXY_LOCAL_PORT_PATH")
+
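+# Ports default to 32001 (gRPC server) and 32102 (nginx sidecar target); PORT0
+# and PORT1, when injected by the runtime, take precedence.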
+GRPC_SERVER_PORT=32001
+if [ -n "$PORT0" ]; then
+ GRPC_SERVER_PORT=$PORT0
+fi
+
+TARGET_GRPC_PORT=32102
+if [ -n "$PORT1" ]; then
+ TARGET_GRPC_PORT=$PORT1
+fi
+
+echo "# Forwards all traffic to nginx controller
+server {
+ listen ${TARGET_GRPC_PORT} http2;
+
+ # No limits
+ client_max_body_size 0;
+ grpc_read_timeout 3600s;
+ grpc_send_timeout 3600s;
+ client_body_timeout 3600s;
+ # grpc_socket_keepalive is recommended but not required
+ # grpc_socket_keepalive is supported after nginx 1.15.6
+ grpc_socket_keepalive on;
+
+ grpc_set_header Authority ${EGRESS_HOST};
+ grpc_set_header Host ${EGRESS_HOST};
+ grpc_set_header X-Host ${SERVICE_ID}.${EGRESS_DOMAIN};
+
+ location / {
+ # Redirects to nginx controller
+ grpc_pass grpc://fedlearner-stack-ingress-nginx-controller.default.svc:80;
+ }
+}
+" > nginx/sidecar.conf
+
+rm -rf /etc/nginx/conf.d/*
+cp nginx/sidecar.conf /etc/nginx/conf.d/
+service nginx restart
+
+# Server sidecar: gRPC to TCP; forwards to the main container's worker port
+# read from /pod-data/listen_port above
+echo "Starting server sidecar"
+./grpc2tcp --grpc_server_port=$GRPC_SERVER_PORT \
+ --target_tcp_address="localhost:$WORKER_LISTEN_PORT" &
+
+echo "Starting client sidecar"
+./tcp2grpc --tcp_server_port="$PROXY_LOCAL_PORT" \
+ --target_grpc_address="localhost:$TARGET_GRPC_PORT" &
+
+echo "===========Sidecar started!!============="
+
+while true
+do
+ if [[ -f "/pod-data/main-terminated" ]]
+ then
+ exit 0
+ fi
+ sleep 5
+done
diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel
new file mode 100644
index 000000000..66c16f2d6
--- /dev/null
+++ b/tools/BUILD.bazel
@@ -0,0 +1,4 @@
+package_group(
+ name = "tools_package",
+ packages = ["//tools/..."],
+)
diff --git a/tools/tcp_grpc_proxy/Dockerfile b/tools/tcp_grpc_proxy/Dockerfile
new file mode 100644
index 000000000..5e95dab47
--- /dev/null
+++ b/tools/tcp_grpc_proxy/Dockerfile
@@ -0,0 +1,26 @@
+FROM golang:1.16
+
+RUN apt-get update && \
+ apt install -y curl git vim && \
+ apt-get install -y make nginx g++ libgmp-dev libglib2.0-dev libssl-dev && \
+ apt-get install -y protobuf-compiler && \
+ apt-get clean
+
+WORKDIR /app
+COPY . /app/tcp_grpc_proxy
+
+# Copies PSI lib
+RUN git clone --recursive https://github.com/encryptogroup/PSI
+
+WORKDIR /app/PSI
+RUN make
+
+WORKDIR /app/tcp_grpc_proxy
+RUN make build
+
+# upgrade nginx
+RUN echo "deb http://nginx.org/packages/mainline/debian/ stretch nginx deb-src http://nginx.org/packages/mainline/debian/ stretch nginx" > /etc/apt/sources.list.d/nginx.list
+RUN wget -qO - https://nginx.org/keys/nginx_signing.key | apt-key add -
+RUN apt update && \
+ apt remove nginx-common -y && \
+ apt install nginx
diff --git a/tools/tcp_grpc_proxy/Makefile b/tools/tcp_grpc_proxy/Makefile
new file mode 100644
index 000000000..67e1889f9
--- /dev/null
+++ b/tools/tcp_grpc_proxy/Makefile
@@ -0,0 +1,13 @@
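+# `make build` fetches dependencies, generates the Go protobuf/gRPC stubs from
+# proto/*.proto, and then builds the tcp2grpc and grpc2tcp binaries.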
+install:
+ go get tcp_grpc_proxy
+ go mod download
+
+protobuf: install
+ go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
+ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
+ PATH="${PATH}:$(shell go env GOPATH)/bin" \
+ protoc -I=proto --go_out=. --go-grpc_out=. proto/*.proto
+
+build: protobuf
+ go build -o tcp2grpc cmd/tcp2grpc/main.go
+ go build -o grpc2tcp cmd/grpc2tcp/main.go
diff --git a/tools/tcp_grpc_proxy/cmd/grpc2tcp/BUILD.bazel b/tools/tcp_grpc_proxy/cmd/grpc2tcp/BUILD.bazel
new file mode 100644
index 000000000..06fc581c5
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/grpc2tcp/BUILD.bazel
@@ -0,0 +1,15 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+
+go_library(
+ name = "grpc2tcp_lib",
+ srcs = ["main.go"],
+ importpath = "fedlearner.net/tools/tcp_grpc_proxy/cmd/grpc2tcp",
+ visibility = ["//tools:tools_package"],
+ deps = ["//tools/tcp_grpc_proxy/pkg/proxy"],
+)
+
+go_binary(
+ name = "grpc2tcp",
+ embed = [":grpc2tcp_lib"],
+ visibility = ["//tools:tools_package"],
+)
diff --git a/tools/tcp_grpc_proxy/cmd/grpc2tcp/main.go b/tools/tcp_grpc_proxy/cmd/grpc2tcp/main.go
new file mode 100644
index 000000000..2b04343bb
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/grpc2tcp/main.go
@@ -0,0 +1,19 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "tcp_grpc_proxy/proxy"
+)
+
+func main() {
+ var grpcServerPort int
+ var targetTCPAddress string
+ flag.IntVar(&grpcServerPort, "grpc_server_port", 7766, "gRPC server port")
+ flag.StringVar(&targetTCPAddress, "target_tcp_address", "127.0.0.1:17766", "The target TCP server")
+ flag.Parse()
+ grpcServerAddress := fmt.Sprintf("0.0.0.0:%d", grpcServerPort)
+
+ grpc2tcpServer := proxy.NewGrpc2TCPServer(grpcServerAddress, targetTCPAddress)
+ grpc2tcpServer.Run()
+}
diff --git a/tools/tcp_grpc_proxy/cmd/grpcclient/main.go b/tools/tcp_grpc_proxy/cmd/grpcclient/main.go
new file mode 100644
index 000000000..670a89f02
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/grpcclient/main.go
@@ -0,0 +1,51 @@
+package main
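+
+// Manual test client: opens a Tunnel stream to a local gRPC tunnel server,
+// prints every chunk it receives, and sends a "Hello World" payload every
+// two seconds.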
+
+import (
+ "bytes"
+ "context"
+ "os"
+ "time"
+
+ "tcp_grpc_proxy/proto"
+
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+)
+
+func main() {
+ // Set up a connection to the server.
+ grpcServer := "127.0.0.1:7766"
+ conn, err := grpc.Dial(grpcServer, grpc.WithInsecure())
+ if err != nil {
+ logrus.Fatalf("did not connect: %v", err)
+ }
+ defer conn.Close()
+ tsc := proto.NewTunnelServiceClient(conn)
+
+ tc, err := tsc.Tunnel(context.Background())
+ if err != nil {
+ logrus.Fatalln(err)
+ }
+
+ sendPacket := func(data []byte) error {
+ return tc.Send(&proto.Chunk{Data: data})
+ }
+
+ go func() {
+ for {
+ chunk, err := tc.Recv()
+ if err != nil {
+ logrus.Println("Recv terminated:", err)
+ os.Exit(0)
+ }
+ logrus.Println(string(chunk.Data))
+ }
+
+ }()
+
+ for {
+ time.Sleep(time.Duration(2) * time.Second)
+ buf := bytes.NewBufferString("************Hello World**********").Bytes()
+ sendPacket(buf)
+ }
+}
diff --git a/tools/tcp_grpc_proxy/cmd/grpcserver/main.go b/tools/tcp_grpc_proxy/cmd/grpcserver/main.go
new file mode 100644
index 000000000..b17e4432f
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/grpcserver/main.go
@@ -0,0 +1,11 @@
+package main
+
+import (
+ "tcp_grpc_proxy/proxy"
+)
+
+func main() {
+ // There is no standalone grpc2tcp package in this module; the Grpc2TCP server
+ // lives in the proxy package.
+ grpcServerAddress := "0.0.0.0:7766"
+ targetTCPAddress := "127.0.0.1:17766"
+ grpc2tcpServer := proxy.NewGrpc2TCPServer(grpcServerAddress, targetTCPAddress)
+ grpc2tcpServer.Run()
+}
diff --git a/tools/tcp_grpc_proxy/cmd/tcp2grpc/BUILD.bazel b/tools/tcp_grpc_proxy/cmd/tcp2grpc/BUILD.bazel
new file mode 100644
index 000000000..130eb9169
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/tcp2grpc/BUILD.bazel
@@ -0,0 +1,15 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")
+
+go_library(
+ name = "tcp2grpc_lib",
+ srcs = ["main.go"],
+ importpath = "fedlearner.net/tools/tcp_grpc_proxy/cmd/tcp2grpc",
+ visibility = ["//tools:tools_package"],
+ deps = ["//tools/tcp_grpc_proxy/pkg/proxy"],
+)
+
+go_binary(
+ name = "tcp2grpc",
+ embed = [":tcp2grpc_lib"],
+ visibility = ["//tools:tools_package"],
+)
diff --git a/tools/tcp_grpc_proxy/cmd/tcp2grpc/main.go b/tools/tcp_grpc_proxy/cmd/tcp2grpc/main.go
new file mode 100644
index 000000000..fee88a884
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/tcp2grpc/main.go
@@ -0,0 +1,57 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "tcp_grpc_proxy/proxy"
+)
+
+func test() {
+ client, err := net.Dial("tcp", "127.0.0.1:17767")
+ if err != nil {
+ fmt.Println("err:", err)
+ return
+ }
+ defer client.Close()
+
+ go func() {
+ input := make([]byte, 1024)
+ for {
+ n, err := os.Stdin.Read(input)
+ if err != nil {
+ fmt.Println("input err:", err)
+ continue
+ }
+ client.Write([]byte(input[:n]))
+ }
+ }()
+
+ buf := make([]byte, 1024)
+ for {
+ n, err := client.Read(buf)
+ if err != nil {
+ if err == io.EOF {
+ return
+ }
+ fmt.Println("read err:", err)
+ continue
+ }
+ fmt.Println(string(buf[:n]))
+
+ }
+}
+
+func main() {
+ var tcpServerPort int
+ var targetGrpcAddress string
+ flag.IntVar(&tcpServerPort, "tcp_server_port", 17767, "TCP server port")
+ flag.StringVar(&targetGrpcAddress, "target_grpc_address", "127.0.0.1:7766", "The target gRPC server")
+ flag.Parse()
+ tcpServerAddress := fmt.Sprintf("0.0.0.0:%d", tcpServerPort)
+
+ tcp2grpcServer := proxy.NewTCP2GrpcServer(tcpServerAddress, targetGrpcAddress)
+ tcp2grpcServer.Run()
+}
diff --git a/tools/tcp_grpc_proxy/cmd/tcpclient/main.go b/tools/tcp_grpc_proxy/cmd/tcpclient/main.go
new file mode 100644
index 000000000..7e0c97467
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/tcpclient/main.go
@@ -0,0 +1,38 @@
+package main
+
+import (
+ "flag"
+ "net"
+ "time"
+
+ "github.com/sirupsen/logrus"
+)
+
+func main() {
+ var tcpServerAddress string
+ flag.StringVar(&tcpServerAddress, "tcp_server_address", "127.0.0.1:17767",
+ "TCP server address which the client connects to.")
+
+ conn, err := net.Dial("tcp", tcpServerAddress)
+ if err != nil {
+ logrus.Fatalf("Dail to tcp target %s error: %v", tcpServerAddress, err)
+ }
+ logrus.Infoln("Connected to", tcpServerAddress)
+ // Makes sure the connection gets closed
+ defer conn.Close()
+ defer logrus.Infoln("Connection closed to ", tcpServerAddress)
+
+ for {
+ conn.Write([]byte("hello world"))
+ logrus.Infof("Sent 'hello world' to server %s", tcpServerAddress)
+
+ tcpData := make([]byte, 64*1024)
+ _, err := conn.Read(tcpData)
+ if err != nil {
+ logrus.Fatalln("Read from tcp error: ", err)
+ }
+ logrus.Infof("Received '%s' from server", string(tcpData))
+
+ time.Sleep(time.Duration(5) * time.Second)
+ }
+}
diff --git a/tools/tcp_grpc_proxy/cmd/tcpserver/main.go b/tools/tcp_grpc_proxy/cmd/tcpserver/main.go
new file mode 100644
index 000000000..592c7b6bd
--- /dev/null
+++ b/tools/tcp_grpc_proxy/cmd/tcpserver/main.go
@@ -0,0 +1,46 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "net"
+
+ "github.com/sirupsen/logrus"
+)
+
+func handleTCPConn(conn net.Conn) {
+ for {
+ tcpData := make([]byte, 64*1024)
+ bytesRead, err := conn.Read(tcpData)
+ if err != nil {
+ // Do not take the whole server down when a single client disconnects.
+ logrus.Errorln("Read from tcp error: ", err)
+ return
+ }
+ logrus.Infof("TCP server got %d bytes", bytesRead)
+ conn.Write([]byte("This is a string from TCP server"))
+ }
+}
+
+func main() {
+ var tcpServerPort int
+ flag.IntVar(&tcpServerPort, "tcp_server_port", 17766, "TCP server port")
+ flag.Parse()
+ tcpServerAddress := fmt.Sprintf("0.0.0.0:%d", tcpServerPort)
+
+ listener, err := net.Listen("tcp", tcpServerAddress)
+ if err != nil {
+ logrus.Fatalln("Listen TCP error: ", err)
+ }
+ defer listener.Close()
+ logrus.Infoln("Run TCPServer at ", tcpServerAddress)
+
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ logrus.Errorln("TCP listener error:", err)
+ continue
+ }
+
+ logrus.Infoln("Got tcp connection")
+ go handleTCPConn(conn)
+ }
+}
diff --git a/tools/tcp_grpc_proxy/go.mod b/tools/tcp_grpc_proxy/go.mod
new file mode 100644
index 000000000..7507c284b
--- /dev/null
+++ b/tools/tcp_grpc_proxy/go.mod
@@ -0,0 +1,12 @@
+module tcp_grpc_proxy
+
+go 1.16
+
+require (
+ github.com/golang/protobuf v1.5.2 // indirect
+ github.com/sirupsen/logrus v1.8.1
+ golang.org/x/net v0.0.0-20210525063256-abc453219eb5 // indirect
+ google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 // indirect
+ google.golang.org/grpc v1.38.0
+ google.golang.org/protobuf v1.26.0
+)
diff --git a/tools/tcp_grpc_proxy/go.sum b/tools/tcp_grpc_proxy/go.sum
new file mode 100644
index 000000000..a372202d1
--- /dev/null
+++ b/tools/tcp_grpc_proxy/go.sum
@@ -0,0 +1,106 @@
+cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
+github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
+github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
+github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
+github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
+github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
+github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
+github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
+github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
+github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
+github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
+github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
+github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
+github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
+github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
+github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
+github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
+github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
+github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
+golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
+golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
+golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20210525063256-abc453219eb5 h1:wjuX4b5yYQnEQHzd+CBcrcC6OVR2J1CN6mUy0oSxIPo=
+golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
+golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
+golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
+google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
+google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
+google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
+google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28=
+google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
+google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
+google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
+google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
+google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
+google.golang.org/grpc v1.38.0 h1:/9BgsAsa5nWe26HqOlvlgJnqBuktYOLCgjCPqsa56W0=
+google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
+google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
+google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
+google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
+google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
+google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
+google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4=
+google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
+google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
diff --git a/tools/tcp_grpc_proxy/pkg/proto/BUILD.bazel b/tools/tcp_grpc_proxy/pkg/proto/BUILD.bazel
new file mode 100644
index 000000000..727149066
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proto/BUILD.bazel
@@ -0,0 +1,28 @@
+load("@rules_proto//proto:defs.bzl", "proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+
+# gazelle:go_generate_proto true
+
+proto_library(
+ name = "proto_proto",
+ srcs = ["tunnel.proto"],
+ visibility = ["//tools:tools_package"],
+)
+
+# keep
+go_proto_library(
+ name = "proto_go_proto",
+ compilers = ["@io_bazel_rules_go//proto:go_grpc"],
+ importpath = "fedlearner.net/tools/tcp_grpc_proxy/pkg/proto",
+ proto = ":proto_proto",
+ visibility = ["//tools:tools_package"],
+)
+
+go_library(
+ name = "proto",
+ srcs = ["proto.go"],
+ embed = [":proto_go_proto"], # keep
+ importpath = "fedlearner.net/tools/tcp_grpc_proxy/pkg/proto",
+ visibility = ["//tools:tools_package"],
+)
diff --git a/tools/tcp_grpc_proxy/pkg/proto/proto.go b/tools/tcp_grpc_proxy/pkg/proto/proto.go
new file mode 100644
index 000000000..92256db4b
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proto/proto.go
@@ -0,0 +1 @@
+package proto
diff --git a/tools/tcp_grpc_proxy/pkg/proto/tunnel.proto b/tools/tcp_grpc_proxy/pkg/proto/tunnel.proto
new file mode 100644
index 000000000..ce5987254
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proto/tunnel.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+package proto;
+option go_package = "fedlearner.net/tools/tcp_grpc_proxy/pkg/proto";
+
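+// TunnelService relays opaque byte chunks in both directions over a single
+// bidirectional stream, so an arbitrary TCP connection can be tunneled over gRPC.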
+service TunnelService {
+ rpc Tunnel (stream Chunk) returns (stream Chunk);
+}
+
+message Chunk {
+ bytes data = 1;
+}
diff --git a/tools/tcp_grpc_proxy/pkg/proxy/BUILD.bazel b/tools/tcp_grpc_proxy/pkg/proxy/BUILD.bazel
new file mode 100644
index 000000000..7af01f33e
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proxy/BUILD.bazel
@@ -0,0 +1,32 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "proxy",
+ srcs = [
+ "grpc2tcp.go",
+ "tcp2grpc.go",
+ ],
+ importpath = "fedlearner.net/tools/tcp_grpc_proxy/pkg/proxy",
+ visibility = ["//tools:tools_package"],
+ deps = [
+ "//tools/tcp_grpc_proxy/pkg/proto",
+ "@com_github_sirupsen_logrus//:logrus",
+ "@org_golang_google_grpc//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "proxy_test",
+ srcs = [
+ "grpc2tcp_test.go",
+ "tcp2grpc_test.go",
+ ],
+ embed = [":proxy"],
+ visibility = ["//tools:tools_package"],
+ deps = [
+ "//tools/tcp_grpc_proxy/pkg/proto",
+ "@com_github_sirupsen_logrus//:logrus",
+ "@com_github_stretchr_testify//assert",
+ "@org_golang_google_grpc//:go_default_library",
+ ],
+)
diff --git a/tools/tcp_grpc_proxy/pkg/proxy/grpc2tcp.go b/tools/tcp_grpc_proxy/pkg/proxy/grpc2tcp.go
new file mode 100644
index 000000000..b9cb5c452
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proxy/grpc2tcp.go
@@ -0,0 +1,134 @@
+package proxy
+
+import (
+ "fmt"
+ "io"
+ "net"
+
+ "fedlearner.net/tools/tcp_grpc_proxy/pkg/proto"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+)
+
+// Grpc2TcpServer A server to proxy grpc traffic to TCP
+type Grpc2TcpServer struct {
+ proto.UnimplementedTunnelServiceServer
+ grpcServerAddress string
+ targetTcpAddress string
+}
+
+// Tunnel the implementation of gRPC Tunnel service
+func (s *Grpc2TcpServer) Tunnel(stream proto.TunnelService_TunnelServer) error {
+ tcpConnection, err := net.Dial("tcp", s.targetTcpAddress)
+ if err != nil {
+ logrus.Errorf("[GRPC2TCP] Dail to tcp target %s error: %v", s.targetTcpAddress, err)
+ return err
+ }
+ contextLogger := logrus.WithFields(logrus.Fields{
+ "prefix": "[GRPC2TCP]",
+ "tcp_client_addr": tcpConnection.LocalAddr().String(),
+ })
+ contextLogger.Infoln("Connected to", s.targetTcpAddress)
+ // Makes sure the connection gets closed
+ defer tcpConnection.Close()
+ defer contextLogger.Infoln("Connection closed to", s.targetTcpAddress)
+
+ errChan := make(chan error)
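+
+ // The handler blocks until one of the pump goroutines pushes a result onto
+ // errChan: nil when the remote TCP side closes cleanly, an error otherwise.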
+
+ // Gets data from gRPC client and proxy to remote TCP server
+ go func() {
+ tcpSentBytes := 0
+ grpcReceivedBytes := 0
+ defer func() {
+ contextLogger.Infof("gRPC received %d bytes, TCP sent %d byte", grpcReceivedBytes, tcpSentBytes)
+ }()
+
+ for {
+ chunk, err := stream.Recv()
+ if err == io.EOF {
+ contextLogger.Infoln("gRpc client EOF")
+ return
+ }
+ if err != nil {
+ errChan <- fmt.Errorf("error while receiving gRPC data: %v", err)
+ return
+ }
+ data := chunk.Data
+ grpcReceivedBytes += len(data)
+
+ contextLogger.Debugln("Sending %d bytes to tcp server", len(data))
+ _, err = tcpConnection.Write(data)
+ if err != nil {
+ errChan <- fmt.Errorf("error while sending TCP data: %v", err)
+ return
+ } else {
+ tcpSentBytes += len(data)
+ }
+ }
+ }()
+
+ // Gets data from remote TCP server and proxy to gRPC client
+ go func() {
+ tcpReceivedBytes := 0
+ grpcSentBytes := 0
+ defer func() {
+ contextLogger.Infof("Tcp received %d bytes, gRPC sent %d bytes", tcpReceivedBytes, grpcSentBytes)
+ }()
+
+ buff := make([]byte, 64*1024)
+ for {
+ bytesRead, err := tcpConnection.Read(buff)
+ if err == io.EOF {
+ contextLogger.Infoln("Remote TCP connection closed")
+ errChan <- nil
+ return
+ }
+ if err != nil {
+ errChan <- fmt.Errorf("error while receiving TCP data: %v", err)
+ return
+ }
+ tcpReceivedBytes += bytesRead
+
+ contextLogger.Debugf("Sending %d bytes to gRPC client\n", bytesRead)
+ err = stream.Send(&proto.Chunk{Data: buff[0:bytesRead]})
+ if err != nil {
+ errChan <- fmt.Errorf("error while sending gRPC data: %v", err)
+ return
+ } else {
+ grpcSentBytes += bytesRead
+ }
+ }
+ }()
+
+ // Blocking read
+ returnedError := <-errChan
+ if returnedError != nil {
+ contextLogger.Errorln(returnedError)
+ }
+ return returnedError
+}
+
+// NewGrpc2TcpServer constructs a Grpc2TCP server
+func NewGrpc2TcpServer(grpcServerAddress, targetTcpAddress string) *Grpc2TcpServer {
+ return &Grpc2TcpServer{
+ grpcServerAddress: grpcServerAddress,
+ targetTcpAddress: targetTcpAddress,
+ }
+}
+
+// Run starts the Grpc2TCP server
+func (s *Grpc2TcpServer) Run() {
+ listener, err := net.Listen("tcp", s.grpcServerAddress)
+ if err != nil {
+ logrus.Fatalln("Failed to listen: ", err)
+ }
+ defer listener.Close()
+
+ // Starts a gRPC server and register services
+ grpcServer := grpc.NewServer()
+ proto.RegisterTunnelServiceServer(grpcServer, s)
+ logrus.Infof("Starting gRPC server at: %s, target to %s", s.grpcServerAddress, s.targetTcpAddress)
+ if err := grpcServer.Serve(listener); err != nil {
+ logrus.Fatalln("Unable to start gRPC serve:", err)
+ }
+}
diff --git a/tools/tcp_grpc_proxy/pkg/proxy/grpc2tcp_test.go b/tools/tcp_grpc_proxy/pkg/proxy/grpc2tcp_test.go
new file mode 100644
index 000000000..6166916ee
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proxy/grpc2tcp_test.go
@@ -0,0 +1,84 @@
+package proxy
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "testing"
+ "time"
+
+ "fedlearner.net/tools/tcp_grpc_proxy/pkg/proto"
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "google.golang.org/grpc"
+)
+
+func runFakeTcpServer(listener net.Listener) {
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ logrus.Infoln("Intended TCP listener error:", err)
+ return
+ }
+
+ go func(conn net.Conn) {
+ defer conn.Close()
+ for {
+ request := make([]byte, 64*1024)
+ bytesRead, err := conn.Read(request)
+ if err == io.EOF {
+ logrus.Infoln("[TCP server] Connection finished")
+ return
+ }
+ if err != nil {
+ logrus.Errorln("[TCP seerver] Error:", err)
+ return
+ }
+ response := fmt.Sprintf("[Proxy] %s", string(request[0:bytesRead]))
+ conn.Write([]byte(response))
+ }
+ }(conn)
+ }
+}
+
+func TestGrpc2Tcp(t *testing.T) {
+ grpcServerAddress := "localhost:13001"
+ targetTcpAddress := "localhost:13002"
+
+ // Sets up a fake TCP server
+ listener, err := net.Listen("tcp", targetTcpAddress)
+ if err != nil {
+ assert.Fail(t, "Failed to listen")
+ }
+ go runFakeTcpServer(listener)
+
+ // Starts the proxy
+ tcp2grpcServer := NewGrpc2TcpServer(grpcServerAddress, targetTcpAddress)
+ go tcp2grpcServer.Run()
+ time.Sleep(1 * time.Second)
+
+ // Sends data by grpc connection and gets response in the same channel
+ responseChan := make(chan string)
+ for i := 0; i < 3; i++ {
+ go func(message string) {
+ grpcConn, _ := grpc.Dial(grpcServerAddress, grpc.WithInsecure())
+ grpcClient := proto.NewTunnelServiceClient(grpcConn)
+ stream, _ := grpcClient.Tunnel(context.Background())
+
+ stream.Send(&proto.Chunk{Data: []byte(message)})
+ stream.CloseSend()
+ chunk, _ := stream.Recv()
+ responseChan <- string(chunk.Data)
+ grpcConn.Close()
+ }(fmt.Sprintf("hello %d", i))
+ }
+
+ responses := make([]string, 0)
+ for i := 0; i < 3; i++ {
+ r := <-responseChan
+ responses = append(responses, r)
+ }
+ assert.ElementsMatch(t, responses,
+ []string{"[Proxy] hello 0", "[Proxy] hello 1", "[Proxy] hello 2"})
+}
diff --git a/tools/tcp_grpc_proxy/pkg/proxy/tcp2grpc.go b/tools/tcp_grpc_proxy/pkg/proxy/tcp2grpc.go
new file mode 100644
index 000000000..ae8baf8c3
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proxy/tcp2grpc.go
@@ -0,0 +1,149 @@
+package proxy
+
+import (
+ "context"
+ "io"
+ "net"
+ "sync"
+
+ "fedlearner.net/tools/tcp_grpc_proxy/pkg/proto"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+)
+
+// Tcp2GrpcServer to proxy TCP traffic to gRPC
+type Tcp2GrpcServer struct {
+ tcpServerAddress string
+ targetGrpcAddress string
+}
+
+// NewTcp2GrpcServer constructs a TCP2GrpcServer
+func NewTcp2GrpcServer(tcpServerAddress, targetGrpcAddress string) *Tcp2GrpcServer {
+ return &Tcp2GrpcServer{
+ tcpServerAddress: tcpServerAddress,
+ targetGrpcAddress: targetGrpcAddress,
+ }
+}
+
+func handleTcpConnection(tcpConn net.Conn, targetGrpcAddress string) {
+ contextLogger := logrus.WithFields(logrus.Fields{
+ "prefix": "[TCP2GRPC]",
+ "tcp_client_addr": tcpConn.RemoteAddr().String(),
+ })
+
+ contextLogger.Infoln("Handle tcp connection, target to:", targetGrpcAddress)
+ defer tcpConn.Close()
+
+ grpcConn, err := grpc.Dial(targetGrpcAddress, grpc.WithInsecure())
+ if err != nil {
+ contextLogger.Errorf("Failed to connect to grpc %s: %v\n", targetGrpcAddress, err)
+ return
+ }
+ defer grpcConn.Close()
+
+ grpcClient := proto.NewTunnelServiceClient(grpcConn)
+ stream, err := grpcClient.Tunnel(context.Background())
+ if err != nil {
+ contextLogger.Errorln("Error of tunnel service:", err)
+ return
+ }
+
+ var wg sync.WaitGroup
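+
+ // Two goroutines pump data in opposite directions; wg.Wait() keeps this handler
+ // and the TCP connection alive until both directions have finished.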
+
+ // Gets data from remote gRPC server and proxy to TCP client
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ tcpSentBytes := 0
+ grpcReceivedBytes := 0
+ defer func() {
+ contextLogger.Infof("gRPC received %d bytes, TCP sent %d byte", grpcReceivedBytes, tcpSentBytes)
+ }()
+
+ for {
+ chunk, err := stream.Recv()
+ if err == io.EOF {
+ contextLogger.Infoln("gRpc server EOF")
+ tcpConn.Close()
+ return
+ }
+ if err != nil {
+ contextLogger.Errorf("Recv from grpc target %s terminated: %v", targetGrpcAddress, err)
+ tcpConn.Close()
+ return
+ }
+ grpcReceivedBytes += len(chunk.Data)
+
+ contextLogger.Debugln("Sending %d bytes to TCP client", len(chunk.Data))
+ _, err = tcpConn.Write(chunk.Data)
+ if err != nil {
+ contextLogger.Errorln("Failed to send data to TCP client:", err)
+ return
+ } else {
+ tcpSentBytes += len(chunk.Data)
+ }
+ }
+ }()
+
+ // Gets data from TCP client and proxy to remote gRPC server
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ tcpReceivedBytes := 0
+ grpcSentBytes := 0
+ defer func() {
+ contextLogger.Infof("TCP received %d bytes, gRPC sent %d bytes", tcpReceivedBytes, grpcSentBytes)
+ }()
+
+ tcpData := make([]byte, 64*1024)
+ for {
+ bytesRead, err := tcpConn.Read(tcpData)
+
+ if err == io.EOF {
+ contextLogger.Infoln("Connection finished")
+ stream.CloseSend()
+ return
+ }
+ if err != nil {
+ contextLogger.Errorln("Read from tcp error:", err)
+ stream.CloseSend()
+ return
+ }
+ tcpReceivedBytes += bytesRead
+
+ contextLogger.Debugln("Sending %d bytes to gRPC server", bytesRead)
+ err = stream.Send(&proto.Chunk{Data: tcpData[0:bytesRead]})
+ if err != nil {
+ contextLogger.Errorln("Failed to send gRPC data:", err)
+ return
+ } else {
+ grpcSentBytes += bytesRead
+ }
+ }
+ }()
+
+ wg.Wait()
+}
+
+// Run Starts the server
+func (s *Tcp2GrpcServer) Run() {
+ listener, err := net.Listen("tcp", s.tcpServerAddress)
+ if err != nil {
+ logrus.Fatalln("Listen TCP error: ", err)
+ }
+ defer listener.Close()
+ logrus.Infoln("Run TCPServer at ", s.tcpServerAddress)
+
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ logrus.Errorln("TCP listener error:", err)
+ continue
+ }
+
+ logrus.Infoln("Got tcp connection")
+ go handleTcpConnection(conn, s.targetGrpcAddress)
+ }
+}
diff --git a/tools/tcp_grpc_proxy/pkg/proxy/tcp2grpc_test.go b/tools/tcp_grpc_proxy/pkg/proxy/tcp2grpc_test.go
new file mode 100644
index 000000000..3eb98bd4c
--- /dev/null
+++ b/tools/tcp_grpc_proxy/pkg/proxy/tcp2grpc_test.go
@@ -0,0 +1,85 @@
+package proxy
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "testing"
+ "time"
+
+ "fedlearner.net/tools/tcp_grpc_proxy/pkg/proto"
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+ "google.golang.org/grpc"
+)
+
+// mockGrpc2TcpServer is used to mock TunnelServer
+type mockTunnelServer struct {
+ proto.UnimplementedTunnelServiceServer
+}
+
+func (s *mockTunnelServer) Tunnel(stream proto.TunnelService_TunnelServer) error {
+ for {
+ chunk, err := stream.Recv()
+ if err == io.EOF {
+ logrus.Infoln("[gRPC server] Stream EOF")
+ return nil
+ }
+ if err != nil {
+ logrus.Errorln("[gRPC server] error:", err)
+ return err
+ }
+ response := fmt.Sprintf("[Proxy] %s", string(chunk.Data))
+ if err = stream.Send(&proto.Chunk{Data: []byte(response)}); err != nil {
+ return err
+ }
+ }
+}
+
+func runFakeGrpcServer(listener net.Listener) {
+ // Starts a gRPC server and register services
+ grpcServer := grpc.NewServer()
+ proto.RegisterTunnelServiceServer(grpcServer, &mockTunnelServer{})
+ if err := grpcServer.Serve(listener); err != nil {
+ logrus.Fatalln("Unable to start gRPC serve:", err)
+ }
+}
+
+func TestTcp2Grpc(t *testing.T) {
+ tcpServerAddress := "localhost:12001"
+ targetGrpcAddress := "localhost:12002"
+
+ // Sets up a fake gRPC server
+ listener, err := net.Listen("tcp", targetGrpcAddress)
+ if err != nil {
+ assert.Fail(t, "Failed to listen")
+ }
+ go runFakeGrpcServer(listener)
+
+ // Starts the proxy
+ tcp2grpcServer := NewTcp2GrpcServer(tcpServerAddress, targetGrpcAddress)
+ go tcp2grpcServer.Run()
+ time.Sleep(1 * time.Second)
+
+ // Sends data by tcp connection and gets response in the same channel
+ responseChan := make(chan string)
+ for i := 0; i < 3; i++ {
+ go func(message string) {
+ tcpConnection, _ := net.Dial("tcp", tcpServerAddress)
+ tcpConnection.Write([]byte(message))
+ response := make([]byte, 64*1024)
+ bytesRead, _ := tcpConnection.Read(response)
+ responseChan <- string(response[0:bytesRead])
+ tcpConnection.Close()
+ }(fmt.Sprintf("hello %d", i))
+ }
+
+ responses := make([]string, 0)
+ for i := 0; i < 3; i++ {
+ r := <-responseChan
+ responses = append(responses, r)
+ }
+
+ assert.ElementsMatch(t, responses,
+ []string{"[Proxy] hello 0", "[Proxy] hello 1", "[Proxy] hello 2"})
+}
diff --git a/tools/tcp_grpc_proxy/proto/tunnel.proto b/tools/tcp_grpc_proxy/proto/tunnel.proto
new file mode 100644
index 000000000..22ce1080b
--- /dev/null
+++ b/tools/tcp_grpc_proxy/proto/tunnel.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+package proto;
+option go_package = "proxy/proto";
+
+service TunnelService {
+ rpc Tunnel (stream Chunk) returns (stream Chunk);
+}
+
+message Chunk {
+ bytes data = 1;
+}
diff --git a/tools/tcp_grpc_proxy/proxy/grpc2tcp.go b/tools/tcp_grpc_proxy/proxy/grpc2tcp.go
new file mode 100644
index 000000000..a9c5f598d
--- /dev/null
+++ b/tools/tcp_grpc_proxy/proxy/grpc2tcp.go
@@ -0,0 +1,106 @@
+package proxy
+
+import (
+ "fmt"
+ "io"
+ "net"
+
+ "tcp_grpc_proxy/proxy/proto"
+
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+)
+
+// Grpc2TCPServer A server to proxy grpc traffic to TCP
+type Grpc2TCPServer struct {
+ proto.UnimplementedTunnelServiceServer
+ grpcServerAddress string
+ targetTCPAddress string
+}
+
+// Tunnel the implementation of gRPC Tunnel service
+func (s *Grpc2TCPServer) Tunnel(stream proto.TunnelService_TunnelServer) error {
+ tcpConnection, err := net.Dial("tcp", s.targetTCPAddress)
+ if err != nil {
+ logrus.Errorf("Dail to tcp target %s error: %v", s.targetTCPAddress, err)
+ return err
+ }
+ logrus.Infoln("Connected to", s.targetTCPAddress)
+ // Makes sure the connection gets closed
+ defer tcpConnection.Close()
+ defer logrus.Infoln("Connection closed to ", s.targetTCPAddress)
+
+ errChan := make(chan error)
+
+ // Gets data from gRPC client and proxy to remote TCP server
+ go func() {
+ for {
+ chunk, err := stream.Recv()
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ errChan <- fmt.Errorf("error while receiving gRPC data: %v", err)
+ return
+ }
+
+ data := chunk.Data
+ logrus.Infof("Sending %d bytes to tcp server", len(data))
+ _, err = tcpConnection.Write(data)
+ if err != nil {
+ errChan <- fmt.Errorf("error while sending TCP data: %v", err)
+ return
+ }
+ }
+ }()
+
+ // Gets data from remote TCP server and proxy to gRPC client
+ go func() {
+ buff := make([]byte, 64*1024)
+ for {
+ bytesRead, err := tcpConnection.Read(buff)
+ if err == io.EOF {
+ logrus.Infoln("Remote TCP connection closed")
+ return
+ }
+ if err != nil {
+ errChan <- fmt.Errorf("error while receiving TCP data: %v", err)
+ return
+ }
+
+ logrus.Infof("Sending %d bytes to gRPC client", bytesRead)
+ if err = stream.Send(&proto.Chunk{Data: buff[0:bytesRead]}); err != nil {
+ errChan <- fmt.Errorf("Error while sending gRPC data: %v", err)
+ return
+ }
+ }
+ }()
+
+ // Blocking read
+ returnedError := <-errChan
+ return returnedError
+}
+
+// NewGrpc2TCPServer constructs a Grpc2TCP server
+func NewGrpc2TCPServer(grpcServerAddress, targetTCPAddress string) *Grpc2TCPServer {
+ return &Grpc2TCPServer{
+ grpcServerAddress: grpcServerAddress,
+ targetTCPAddress: targetTCPAddress,
+ }
+}
+
+// Run starts the Grpc2TCP server
+func (s *Grpc2TCPServer) Run() {
+ listener, err := net.Listen("tcp", s.grpcServerAddress)
+ if err != nil {
+ logrus.Errorf("Failed to listen: ", err)
+ }
+
+ // Starts a gRPC server and register services
+ grpcServer := grpc.NewServer()
+ proto.RegisterTunnelServiceServer(grpcServer, s)
+ logrus.Infof("Starting gRPC server at: %s, target to %s", s.grpcServerAddress, s.targetTCPAddress)
+ if err := grpcServer.Serve(listener); err != nil {
+ logrus.Errorln("Unable to start gRPC serve:", err)
+ }
+}
diff --git a/tools/tcp_grpc_proxy/proxy/proto/tunnel.pb.go b/tools/tcp_grpc_proxy/proxy/proto/tunnel.pb.go
new file mode 100644
index 000000000..79602bc44
--- /dev/null
+++ b/tools/tcp_grpc_proxy/proxy/proto/tunnel.pb.go
@@ -0,0 +1,147 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// protoc-gen-go v1.26.0
+// protoc v3.17.3
+// source: tunnel.proto
+
+package proto
+
+import (
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ reflect "reflect"
+ sync "sync"
+)
+
+const (
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+type Chunk struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
+}
+
+func (x *Chunk) Reset() {
+ *x = Chunk{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_tunnel_proto_msgTypes[0]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *Chunk) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Chunk) ProtoMessage() {}
+
+func (x *Chunk) ProtoReflect() protoreflect.Message {
+ mi := &file_tunnel_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use Chunk.ProtoReflect.Descriptor instead.
+func (*Chunk) Descriptor() ([]byte, []int) {
+ return file_tunnel_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *Chunk) GetData() []byte {
+ if x != nil {
+ return x.Data
+ }
+ return nil
+}
+
+var File_tunnel_proto protoreflect.FileDescriptor
+
+var file_tunnel_proto_rawDesc = []byte{
+ 0x0a, 0x0c, 0x74, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05,
+ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x1b, 0x0a, 0x05, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x12, 0x12,
+ 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61,
+ 0x74, 0x61, 0x32, 0x39, 0x0a, 0x0d, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x53, 0x65, 0x72, 0x76,
+ 0x69, 0x63, 0x65, 0x12, 0x28, 0x0a, 0x06, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c, 0x12, 0x0c, 0x2e,
+ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x1a, 0x0c, 0x2e, 0x70, 0x72,
+ 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x28, 0x01, 0x30, 0x01, 0x42, 0x0d, 0x5a,
+ 0x0b, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
+ 0x6f, 0x74, 0x6f, 0x33,
+}
+
+var (
+ file_tunnel_proto_rawDescOnce sync.Once
+ file_tunnel_proto_rawDescData = file_tunnel_proto_rawDesc
+)
+
+func file_tunnel_proto_rawDescGZIP() []byte {
+ file_tunnel_proto_rawDescOnce.Do(func() {
+ file_tunnel_proto_rawDescData = protoimpl.X.CompressGZIP(file_tunnel_proto_rawDescData)
+ })
+ return file_tunnel_proto_rawDescData
+}
+
+var file_tunnel_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_tunnel_proto_goTypes = []interface{}{
+ (*Chunk)(nil), // 0: proto.Chunk
+}
+var file_tunnel_proto_depIdxs = []int32{
+ 0, // 0: proto.TunnelService.Tunnel:input_type -> proto.Chunk
+ 0, // 1: proto.TunnelService.Tunnel:output_type -> proto.Chunk
+ 1, // [1:2] is the sub-list for method output_type
+ 0, // [0:1] is the sub-list for method input_type
+ 0, // [0:0] is the sub-list for extension type_name
+ 0, // [0:0] is the sub-list for extension extendee
+ 0, // [0:0] is the sub-list for field type_name
+}
+
+func init() { file_tunnel_proto_init() }
+func file_tunnel_proto_init() {
+ if File_tunnel_proto != nil {
+ return
+ }
+ if !protoimpl.UnsafeEnabled {
+ file_tunnel_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*Chunk); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ type x struct{}
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
+ RawDescriptor: file_tunnel_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 1,
+ NumExtensions: 0,
+ NumServices: 1,
+ },
+ GoTypes: file_tunnel_proto_goTypes,
+ DependencyIndexes: file_tunnel_proto_depIdxs,
+ MessageInfos: file_tunnel_proto_msgTypes,
+ }.Build()
+ File_tunnel_proto = out.File
+ file_tunnel_proto_rawDesc = nil
+ file_tunnel_proto_goTypes = nil
+ file_tunnel_proto_depIdxs = nil
+}
diff --git a/tools/tcp_grpc_proxy/proxy/proto/tunnel_grpc.pb.go b/tools/tcp_grpc_proxy/proxy/proto/tunnel_grpc.pb.go
new file mode 100644
index 000000000..f60817673
--- /dev/null
+++ b/tools/tcp_grpc_proxy/proxy/proto/tunnel_grpc.pb.go
@@ -0,0 +1,133 @@
+// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
+
+package proto
+
+import (
+ context "context"
+ grpc "google.golang.org/grpc"
+ codes "google.golang.org/grpc/codes"
+ status "google.golang.org/grpc/status"
+)
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+// Requires gRPC-Go v1.32.0 or later.
+const _ = grpc.SupportPackageIsVersion7
+
+// TunnelServiceClient is the client API for TunnelService service.
+//
+// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
+type TunnelServiceClient interface {
+ Tunnel(ctx context.Context, opts ...grpc.CallOption) (TunnelService_TunnelClient, error)
+}
+
+type tunnelServiceClient struct {
+ cc grpc.ClientConnInterface
+}
+
+func NewTunnelServiceClient(cc grpc.ClientConnInterface) TunnelServiceClient {
+ return &tunnelServiceClient{cc}
+}
+
+func (c *tunnelServiceClient) Tunnel(ctx context.Context, opts ...grpc.CallOption) (TunnelService_TunnelClient, error) {
+ stream, err := c.cc.NewStream(ctx, &TunnelService_ServiceDesc.Streams[0], "/proto.TunnelService/Tunnel", opts...)
+ if err != nil {
+ return nil, err
+ }
+ x := &tunnelServiceTunnelClient{stream}
+ return x, nil
+}
+
+type TunnelService_TunnelClient interface {
+ Send(*Chunk) error
+ Recv() (*Chunk, error)
+ grpc.ClientStream
+}
+
+type tunnelServiceTunnelClient struct {
+ grpc.ClientStream
+}
+
+func (x *tunnelServiceTunnelClient) Send(m *Chunk) error {
+ return x.ClientStream.SendMsg(m)
+}
+
+func (x *tunnelServiceTunnelClient) Recv() (*Chunk, error) {
+ m := new(Chunk)
+ if err := x.ClientStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+// TunnelServiceServer is the server API for TunnelService service.
+// All implementations must embed UnimplementedTunnelServiceServer
+// for forward compatibility
+type TunnelServiceServer interface {
+ Tunnel(TunnelService_TunnelServer) error
+ mustEmbedUnimplementedTunnelServiceServer()
+}
+
+// UnimplementedTunnelServiceServer must be embedded to have forward compatible implementations.
+type UnimplementedTunnelServiceServer struct {
+}
+
+func (UnimplementedTunnelServiceServer) Tunnel(TunnelService_TunnelServer) error {
+ return status.Errorf(codes.Unimplemented, "method Tunnel not implemented")
+}
+func (UnimplementedTunnelServiceServer) mustEmbedUnimplementedTunnelServiceServer() {}
+
+// UnsafeTunnelServiceServer may be embedded to opt out of forward compatibility for this service.
+// Use of this interface is not recommended, as added methods to TunnelServiceServer will
+// result in compilation errors.
+type UnsafeTunnelServiceServer interface {
+ mustEmbedUnimplementedTunnelServiceServer()
+}
+
+func RegisterTunnelServiceServer(s grpc.ServiceRegistrar, srv TunnelServiceServer) {
+ s.RegisterService(&TunnelService_ServiceDesc, srv)
+}
+
+func _TunnelService_Tunnel_Handler(srv interface{}, stream grpc.ServerStream) error {
+ return srv.(TunnelServiceServer).Tunnel(&tunnelServiceTunnelServer{stream})
+}
+
+type TunnelService_TunnelServer interface {
+ Send(*Chunk) error
+ Recv() (*Chunk, error)
+ grpc.ServerStream
+}
+
+type tunnelServiceTunnelServer struct {
+ grpc.ServerStream
+}
+
+func (x *tunnelServiceTunnelServer) Send(m *Chunk) error {
+ return x.ServerStream.SendMsg(m)
+}
+
+func (x *tunnelServiceTunnelServer) Recv() (*Chunk, error) {
+ m := new(Chunk)
+ if err := x.ServerStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+// TunnelService_ServiceDesc is the grpc.ServiceDesc for TunnelService service.
+// It's only intended for direct use with grpc.RegisterService,
+// and not to be introspected or modified (even as a copy)
+var TunnelService_ServiceDesc = grpc.ServiceDesc{
+ ServiceName: "proto.TunnelService",
+ HandlerType: (*TunnelServiceServer)(nil),
+ Methods: []grpc.MethodDesc{},
+ Streams: []grpc.StreamDesc{
+ {
+ StreamName: "Tunnel",
+ Handler: _TunnelService_Tunnel_Handler,
+ ServerStreams: true,
+ ClientStreams: true,
+ },
+ },
+ Metadata: "tunnel.proto",
+}
diff --git a/tools/tcp_grpc_proxy/proxy/tcp2grpc.go b/tools/tcp_grpc_proxy/proxy/tcp2grpc.go
new file mode 100644
index 000000000..63b5586b8
--- /dev/null
+++ b/tools/tcp_grpc_proxy/proxy/tcp2grpc.go
@@ -0,0 +1,104 @@
+package proxy
+
+import (
+ "context"
+ "io"
+ "net"
+ "tcp_grpc_proxy/proxy/proto"
+
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+)
+
+// TCP2GrpcServer to proxy TCP traffic to gRPC
+type TCP2GrpcServer struct {
+ tcpServerAddress string
+ targetGrpcAddress string
+}
+
+// NewTCP2GrpcServer constructs a TCP2GrpcServer
+func NewTCP2GrpcServer(tcpServerAddress, targetGrpcAddress string) *TCP2GrpcServer {
+ return &TCP2GrpcServer{
+ tcpServerAddress: tcpServerAddress,
+ targetGrpcAddress: targetGrpcAddress,
+ }
+}
+
+func handleTCPConn(tcpConn net.Conn, targetGrpcAddress string) {
+ logrus.Infoln("Handle tcp connection, target to:", targetGrpcAddress)
+ defer tcpConn.Close()
+
+ grpcConn, err := grpc.Dial(targetGrpcAddress, grpc.WithInsecure())
+ if err != nil {
+ logrus.Errorf("Error during connect to grpc %s: %v", targetGrpcAddress, err)
+ return
+ }
+ defer grpcConn.Close()
+
+ grpcClient := proto.NewTunnelServiceClient(grpcConn)
+ stream, err := grpcClient.Tunnel(context.Background())
+ if err != nil {
+ logrus.Errorf("Error of tunnel service: %v", err)
+ return
+ }
+
+ // Gets data from remote gRPC server and proxy to TCP client
+ go func() {
+ for {
+ chunk, err := stream.Recv()
+ if err != nil {
+ logrus.Errorf("Recv from grpc target %s terminated: %v", targetGrpcAddress, err)
+ return
+ }
+			logrus.Infof("Sending %d bytes to TCP client", len(chunk.Data))
+			if _, err := tcpConn.Write(chunk.Data); err != nil {
+				logrus.Errorf("Write to TCP client error: %v", err)
+				return
+			}
+ }
+ }()
+
+ // Gets data from TCP client and proxy to remote gRPC server
+ func() {
+ for {
+ tcpData := make([]byte, 64*1024)
+ bytesRead, err := tcpConn.Read(tcpData)
+
+ if err == io.EOF {
+ logrus.Infoln("Connection finished")
+ return
+ }
+ if err != nil {
+ logrus.Errorf("Read from tcp error: %v", err)
+ return
+ }
+ logrus.Infof("Sending %d bytes to gRPC server", bytesRead)
+ if err := stream.Send(&proto.Chunk{Data: tcpData[0:bytesRead]}); err != nil {
+ logrus.Errorf("Failed to send gRPC data: %v", err)
+ return
+ }
+ }
+ }()
+
+	// The TCP side is done; close the send direction of the gRPC stream.
+ stream.CloseSend()
+ return
+}
+
+// Run Starts the server
+func (s *TCP2GrpcServer) Run() {
+ listener, err := net.Listen("tcp", s.tcpServerAddress)
+ if err != nil {
+ logrus.Fatalln("Listen TCP error: ", err)
+ }
+ defer listener.Close()
+ logrus.Infoln("Run TCPServer at ", s.tcpServerAddress)
+
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ logrus.Errorln("TCP listener error:", err)
+ continue
+ }
+
+ logrus.Infoln("Got tcp connection")
+ go handleTCPConn(conn, s.targetGrpcAddress)
+ }
+}
diff --git a/tools/tcp_grpc_proxy/start_proxy.sh b/tools/tcp_grpc_proxy/start_proxy.sh
new file mode 100644
index 000000000..6fb20a114
--- /dev/null
+++ b/tools/tcp_grpc_proxy/start_proxy.sh
@@ -0,0 +1,60 @@
+#! /bin/bash
+set -ex
+
+# The chain of the traffic:
+# TCP client -> out TCP server -> out gRPC server -> Nginx ->
+# network -> remote grpc server (Nginx) -> in gRPC server -> in TCP server
+OUT_TCP_SERVER_PORT=17767
+OUT_GRPC_SERVER_PORT=17768
+IN_GRPC_SERVER_PORT=17769
+IN_TCP_SERVER_PORT=7766
+
+REMOTE_GRPC_SERVER_HOST=1.1.1.1
+REMOTE_GRPC_SERVER_PORT=17771
+
+echo "
+upstream remote_grpc_server {
+ server ${REMOTE_GRPC_SERVER_HOST}:${REMOTE_GRPC_SERVER_PORT};
+}
+
+# Proxies to remote grpc server
+server {
+ listen ${OUT_GRPC_SERVER_PORT} http2;
+
+ # No limits
+ client_max_body_size 0;
+ grpc_read_timeout 3600s;
+ grpc_send_timeout 3600s;
+ client_body_timeout 3600s;
+ # grpc_socket_keepalive is recommended but not required
+ # grpc_socket_keepalive is supported after nginx 1.15.6
+ grpc_socket_keepalive on;
+ location / {
+ # change grpc to grpcs if ssl is used
+ grpc_pass grpc://remote_grpc_server;
+ }
+}
+
+# Listens for gRPC traffic; this port should be public
+server {
+ listen ${REMOTE_GRPC_SERVER_PORT} http2;
+
+ # No limits
+ client_max_body_size 0;
+ grpc_read_timeout 3600s;
+ grpc_send_timeout 3600s;
+ client_body_timeout 3600s;
+ grpc_socket_keepalive on;
+ location / {
+ grpc_pass grpc://localhost:${IN_GRPC_SERVER_PORT};
+ }
+}
+" > nginx.conf
+cp nginx.conf /etc/nginx/conf.d/nginx.conf
+service nginx restart
+
+./tcp2grpc --tcp_server_port="$OUT_TCP_SERVER_PORT" \
+ --target_grpc_address="localhost:$OUT_GRPC_SERVER_PORT" &
+
+./grpc2tcp --grpc_server_port="$IN_GRPC_SERVER_PORT" \
+ --target_tcp_address="localhost:$IN_TCP_SERVER_PORT" &
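+
+# Smoke test (illustrative): with the chain above running, start a listener on the
+# remote host, e.g. "nc -l $IN_TCP_SERVER_PORT", then locally run
+# "echo hello | nc localhost $OUT_TCP_SERVER_PORT"; the bytes should travel through
+# the gRPC tunnel to the remote TCP listener.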
diff --git a/web_console_v2/.dockerignore b/web_console_v2/.dockerignore
index f6edae1e9..e74466cc6 100644
--- a/web_console_v2/.dockerignore
+++ b/web_console_v2/.dockerignore
@@ -4,5 +4,5 @@ Dockerfile
# Tests
client/tests
-api/test
+api/tests
api/testing
diff --git a/web_console_v2/BUILD.bazel b/web_console_v2/BUILD.bazel
new file mode 100644
index 000000000..e4c69406a
--- /dev/null
+++ b/web_console_v2/BUILD.bazel
@@ -0,0 +1,13 @@
+filegroup(
+ name = "srcs",
+ srcs = [
+ ".dockerignore",
+ ".gitignore",
+ "Dockerfile",
+ "README.md",
+ "nginx.conf",
+ "run_dev.sh",
+ "run_prod.sh",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/web_console_v2/Dockerfile b/web_console_v2/Dockerfile
index 0254897a4..6d856b5b4 100644
--- a/web_console_v2/Dockerfile
+++ b/web_console_v2/Dockerfile
@@ -1,11 +1,13 @@
-FROM python:3.7
+FROM python:3.6.8
RUN apt-get update && \
apt install -y curl && \
# For nodejs PA
curl -sL https://deb.nodesource.com/setup_14.x | bash && \
+ # For krb5-user installation
+ export DEBIAN_FRONTEND=noninteractive && \
# Install dependencies
- apt-get install -y make nodejs nginx && \
+ apt-get install -y make nodejs nginx krb5-user cron && \
apt-get clean
WORKDIR /app
@@ -14,14 +16,15 @@ COPY . .
# Builds frontend
WORKDIR /app/client
-RUN npx pnpm install && npx pnpm build && rm -rf node_modules
+RUN npx pnpm@6.4.0 install && npx pnpm@6.4.0 build && rm -rf node_modules
# Builds backend
WORKDIR /app/api
RUN pip3 install --no-cache-dir -r requirements.txt && make protobuf
+WORKDIR /app
# Nginx configuration
-COPY nginx.conf /etc/nginx/conf.d/nginx.conf
+RUN cp nginx.conf /etc/nginx/conf.d/nginx.conf
# Port for webconsole http server
EXPOSE 1989
@@ -29,19 +32,7 @@ EXPOSE 1989
# This should not be exposed in PROD
EXPOSE 1990
-# Install vscode
-RUN curl -fOL https://github.com/cdr/code-server/releases/download/v3.8.0/code-server_3.8.0_amd64.deb && \
- dpkg -i code-server_3.8.0_amd64.deb && \
- rm code-server_3.8.0_amd64.deb && \
- mkdir -p ~/.config/code-server/ && \
- echo 'bind-addr: 0.0.0.0:1992\n\
-auth: password\n\
-password: fedlearner\n\
-cert: false\n' >> ~/.config/code-server/config.yaml
-
-# Port for VScode
-EXPOSE 1992
ENV TZ="Asia/Shanghai"
WORKDIR /app
-CMD sh run_prod.sh
+CMD bash run_prod.sh
diff --git a/web_console_v2/Makefile b/web_console_v2/Makefile
deleted file mode 100644
index 9fb779f40..000000000
--- a/web_console_v2/Makefile
+++ /dev/null
@@ -1,8 +0,0 @@
-api-test:
- cd api && \
- make protobuf && \
- make lint && \
- make test
-
-docker-spark:
- cd ./docker/spark && docker build . -t spark-tfrecord:latest
\ No newline at end of file
diff --git a/web_console_v2/api/.gitignore b/web_console_v2/api/.gitignore
index b42ddd478..9f55327c4 100644
--- a/web_console_v2/api/.gitignore
+++ b/web_console_v2/api/.gitignore
@@ -4,9 +4,10 @@
# Generated proto python code
fedlearner_webconsole/proto/*.py
fedlearner_webconsole/proto/*.pyi
+fedlearner_webconsole/proto/testing/
# Coverage generated
.coverage_html_report/
.coverage*
-root.log.*
\ No newline at end of file
+root.log.*
diff --git a/web_console_v2/api/.style.yapf b/web_console_v2/api/.style.yapf
new file mode 100644
index 000000000..b3d849f2d
--- /dev/null
+++ b/web_console_v2/api/.style.yapf
@@ -0,0 +1,3 @@
+[style]
+based_on_style = google
+column_limit = 120
diff --git a/web_console_v2/api/.yapfignore b/web_console_v2/api/.yapfignore
new file mode 100644
index 000000000..eefeb4275
--- /dev/null
+++ b/web_console_v2/api/.yapfignore
@@ -0,0 +1,2 @@
+migrations/
+fedlearner_webconsole/proto/
diff --git a/web_console_v2/api/BUILD.bazel b/web_console_v2/api/BUILD.bazel
new file mode 100644
index 000000000..f20b92936
--- /dev/null
+++ b/web_console_v2/api/BUILD.bazel
@@ -0,0 +1,150 @@
+load("@rules_python//python:defs.bzl", "py_binary", "py_library")
+
+package(default_visibility = [":console_api_package"])
+
+package_group(
+ name = "console_api_package",
+ packages = ["//web_console_v2/api/..."],
+)
+
+filegroup(
+ name = "srcs",
+ srcs = [
+ "Makefile",
+ "README.md",
+ ],
+)
+
+py_library(
+ name = "checks_lib",
+ srcs = ["checks.py"],
+ imports = ["."],
+ deps = [":envs_lib"],
+)
+
+py_library(
+ name = "command_lib",
+ srcs = [
+ "command.py",
+ "es_configuration.py",
+ ],
+ imports = ["."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:app_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:es_lib",
+ "//web_console_v2/api/tools:lib",
+ "@common_click//:pkg",
+ "@common_elasticsearch//:pkg",
+ "@common_flask_migrate//:pkg",
+ "@common_requests//:pkg",
+ ],
+)
+
+py_library(
+ name = "config_lib",
+ srcs = ["config.py"],
+ imports = ["."],
+ deps = [
+ ":envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ ],
+)
+
+py_library(
+ name = "envs_lib",
+ srcs = ["envs.py"],
+ imports = ["."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_pytz//:pkg",
+ ],
+)
+
+py_test(
+ name = "envs_lib_test",
+ size = "small",
+ srcs = [
+ "envs_test.py",
+ ],
+ imports = [".."],
+ main = "envs_test.py",
+ deps = [
+ ":envs_lib",
+ ],
+)
+
+py_library(
+ name = "logging_config_lib",
+ srcs = ["logging_config.py"],
+ imports = ["."],
+ deps = [
+ ":envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:log_filter_lib",
+ ],
+)
+
+py_binary(
+ name = "rpc_server_bin",
+ srcs = ["rpc_server.py"],
+ imports = ["."],
+ main = "rpc_server.py",
+ deps = [
+ ":envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:hooks_lib",
+ ],
+)
+
+py_binary(
+ name = "composer_bin",
+ srcs = ["composer.py"],
+ imports = ["."],
+ main = "composer.py",
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:hooks_lib",
+ ],
+)
+
+py_library(
+ name = "server_lib",
+ srcs = ["server.py"],
+ imports = ["."],
+ deps = [
+ ":checks_lib",
+ ":config_lib",
+ ":envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:app_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:middlewares_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:hooks_lib",
+ "@common_flask//:pkg",
+ "@common_gunicorn//:pkg",
+ ],
+)
+
+filegroup(
+ name = "gunicorn_config",
+ srcs = [
+ "gunicorn_config.py",
+ ],
+)
+
+py_binary(
+ name = "entrypoint_bin",
+ srcs = ["entrypoint.py"],
+ imports = ["."],
+ main = "entrypoint.py",
+ deps = [
+ ":envs_lib",
+ ":server_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:hooks_lib",
+ "@common_flask//:pkg",
+ "@common_gunicorn//:pkg",
+ ],
+)
diff --git a/web_console_v2/api/Makefile b/web_console_v2/api/Makefile
index 997047b5f..a306d70e7 100644
--- a/web_console_v2/api/Makefile
+++ b/web_console_v2/api/Makefile
@@ -1,36 +1,16 @@
export PYTHONPATH:=${PWD}:$(PYTHONPATH)
-.PHONY: test unit-test-all unit-test protobuf
-
-lint:
- pylint --rcfile ./ci/pylintrc --load-plugins pylint_quotes fedlearner_webconsole
+clean:
+ rm -f err.out && \
+ find ./ -type f \( -name "*.db" -o -name "*.log" \) -exec rm -f {} \;
protobuf:
+ PATH=${PATH}:${PWD}/bin/$(shell uname) \
python -m grpc_tools.protoc -I protocols \
--python_out=. \
--grpc_python_out=. \
--mypy_out=. \
- protocols/fedlearner_webconsole/proto/*.proto
-
-UNIT_TEST_SCRIPTS := $(shell find test/ -type f -name "*_test.py")
-UNIT_TEST_SCRIPTS_REGEX := $(shell find test/$(FOLDER) -type f -name "$(REG)*.py")
-UNIT_TESTS := $(UNIT_TEST_SCRIPTS:%.py=%.phony)
-UNIT_TESTS_REGEX := $(UNIT_TEST_SCRIPTS_REGEX:%.py=%.phony)
-
-test/%.phony: test/%.py
- python $^
-
-unit-test-all: protobuf $(UNIT_TESTS)
-
-# run unit test with optional $FOLDER and $REG parameter to limit the number of
-# running tests.
-# Sample: make unit-test FOLDER="/fedlearner_webconsole/utils" REG="file*"
-unit-test: protobuf $(UNIT_TESTS_REGEX)
-
-cli-test:
- FLASK_APP=command:app flask routes
-
-test: unit-test-all cli-test
-
-clean:
- find ./ -type f \( -name "*.db" -o -name "*.log" \) -exec rm -f {} \;
+ --jsonschema_out=prefix_schema_files_with_package,disallow_additional_properties:fedlearner_webconsole/proto/jsonschemas \
+ protocols/fedlearner_webconsole/proto/*.proto \
+ protocols/fedlearner_webconsole/proto/**/*.proto \
+ protocols/fedlearner_webconsole/proto/rpc/v2/*.proto
diff --git a/web_console_v2/api/README.md b/web_console_v2/api/README.md
index c46e853dd..30a737fec 100644
--- a/web_console_v2/api/README.md
+++ b/web_console_v2/api/README.md
@@ -2,73 +2,126 @@
## Prerequisites
-* GNU Make
-* Python3
+* Bazel
* MySQL 8.0
+* Docker
## Get started
-```
-python3 -m venv
-source /bin/activate
-pip3 install -r requirements.txt
+Start development using a fake k8s cluster (no actual data).
+
+Start all the processes:
-# Generates python code for proto
-make protobuf
+```bash
+bazelisk run //web_console_v2/api/cmds:run_dev
+```
-# Use MySQL, please create database in advance, then set
-# SQLALCHEMY_DATABASE_URI, for example as follows
-export SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@localhost:33600/fedlearner
+Optionally, to stop or restart one of the processes:
-# Creates schemas for DB
-FLASK_APP=command:app flask db upgrade
+```bash
+bazelisk run //web_console_v2/api/cmds:supervisorctl_cli_bin -- -s unix:///tmp/supervisor.sock
+```
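+
+For example (assuming the dev stack above is running; the program names come from
+`supervisord_dev.conf`), standard supervisorctl subcommands can be appended:
+
+```bash
+# show the status of all managed processes
+bazelisk run //web_console_v2/api/cmds:supervisorctl_cli_bin -- -s unix:///tmp/supervisor.sock status
+# restart only the rpc process
+bazelisk run //web_console_v2/api/cmds:supervisorctl_cli_bin -- -s unix:///tmp/supervisor.sock restart rpc
+```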
-# Creates initial user
-FLASK_APP=command:app flask create-initial-data
+## Develop with remote k8s cluster
-# Starts the server
-export FLASK_ENV=development
-flask run
+```bash
+# Changes configs in tools/local_runner/app_a.py or app_b.py
+bash tools/local_runner/run_a.sh
+bash tools/local_runner/run_b.sh
```
## Tests
### Unit tests
-```
-cd
-make unit-test
+```bash
+bazelisk test //web_console_v2/api/... --config lint
```
## Helpers
### Gets all routes
-```
-FLASK_APP=command:app flask routes
+
+```bash
+FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ bazelisk run //web_console_v2/api/cmds:flask_cli_bin -- routes
```
### Add migration files
-```
-FLASK_APP=command:app flask db migrate -m "Whats' changed"
+```bash
+FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+  bazelisk run //web_console_v2/api/cmds:flask_cli_bin -- db migrate -m "What's changed" -d web_console_v2/api/migrations
+
# like dry-run mode, preview auto-generated SQL
-FLASK_APP=command:app flask db upgrade --sql
+FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ bazelisk run //web_console_v2/api/cmds:flask_cli_bin -- db upgrade --sql -d web_console_v2/api/migrations
+
# update database actually
-FLASK_APP=command:app flask db upgrade
+FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ bazelisk run //web_console_v2/api/cmds:flask_cli_bin -- db upgrade -d web_console_v2/api/migrations
```
### Reset migration files
Delete migrations folder first.
+
+```bash
+FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ bazelisk run //web_console_v2/api/cmds:flask_cli_bin -- db init -d web_console_v2/api/migrations
+
+FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ bazelisk run //web_console_v2/api/cmds:flask_cli_bin -- db migrate -m "Initial migration." -d web_console_v2/api/migrations
+```
+
+### Cleanup project
+
+```bash
+FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ bazelisk run //web_console_v2/api/cmds:flask_cli_bin -- cleanup-project
```
-FLASK_APP=command:app flask db init
-FLASK_APP=command:app flask db migrate -m "Initial migration."
+
+## Conventions & style
+
+### [Style guide](docs/style_guide.md)
+
+### Code formatter
+
+We use [yapf](https://github.com/google/yapf) to format our code; the style is defined in `.style.yapf`.
+
+To check the format, please run:
+
+```bash
+bazelisk test --config lint
```
-## [Style guide](docs/style_guide.md)
-## [Best practices](docs/best_practices.md)
+To fix the errors, please run:
+
+```bash
+bazelisk test --config fix
+```
+
+### [gRPC](docs/grpc.md)
+
+## Best practices
+
+### [Database best practices](docs/best_practices/db.md)
+
+### [API layer best practices](docs/best_practices.md)
+
+### [Client-server model best practices](docs/best_practices/client_server.md)
+
+### [Multiprocessing best practices](docs/best_practices/multiprocess.md)
## References
### Default date time in sqlalchemy
+
https://stackoverflow.com/questions/13370317/sqlalchemy-default-datetime/33532154#33532154
diff --git a/web_console_v2/api/bin/Darwin/protoc-gen-jsonschema b/web_console_v2/api/bin/Darwin/protoc-gen-jsonschema
new file mode 100755
index 000000000..7e38eee5a
Binary files /dev/null and b/web_console_v2/api/bin/Darwin/protoc-gen-jsonschema differ
diff --git a/web_console_v2/api/bin/Linux/protoc-gen-jsonschema b/web_console_v2/api/bin/Linux/protoc-gen-jsonschema
new file mode 100755
index 000000000..b8935f7fb
Binary files /dev/null and b/web_console_v2/api/bin/Linux/protoc-gen-jsonschema differ
diff --git a/web_console_v2/api/checks.py b/web_console_v2/api/checks.py
new file mode 100644
index 000000000..ae90cac86
--- /dev/null
+++ b/web_console_v2/api/checks.py
@@ -0,0 +1,24 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+from envs import Envs
+
+
+def validity_check():
+ error_msg = Envs.check()
+ if error_msg:
+ print(f'Validity check failed: {error_msg}')
+ sys.exit(1)
diff --git a/web_console_v2/api/ci/pylintrc b/web_console_v2/api/ci/pylintrc
deleted file mode 100644
index 13af12998..000000000
--- a/web_console_v2/api/ci/pylintrc
+++ /dev/null
@@ -1,434 +0,0 @@
-[MASTER]
-
-# Specify a configuration file.
-#rcfile=
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Add files or directories to the blacklist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Add files or directories matching the regex patterns to the blacklist. The
-# regex matches against base names, not paths.
-ignore-patterns=.*pb2.*
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Use multiple processes to speed up Pylint.
-jobs=1
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code
-extension-pkg-whitelist=
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
-confidence=
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time. See also the "--disable" option for examples.
-#enable=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-#
-# -----------------------------------------------------------------------
-# 2015-01-12 - What follows is the list of all disabled items necessary
-# to get a clean run of lint across CourseBuilder. These are separated
-# into three tiers:
-#
-# - Fix-worthy. This includes:
-# - Probable bugs
-# - Easily-addressed hygiene issues,
-# - Real warnings which we may mark as suppressed on a case-by-case basis.
-# - Items that are questionable practice, but not necessarily economical to fix.
-# - Items that we intend to ignore, as we do not consider them bad practice.
-#
-# Warning messages are documented at http://docs.pylint.org/features.html
-#
-# ----------------------------------------------------------------------
-# Fix-worthy:
-#
-# ---- Possible bugs:
-# disable=super-on-old-class
-# disable=arguments-differ (# of arguments to overriding/overridden method)
-# disable=signature-differs
-# disable=method-hidden
-# disable=abstract-method (Abstract method not overridden in derived class)
-# disable=no-member (self.foo used when foo not declared in class)
-#
-# ---- Easy-to-fix and improves readability, cleanliness:
-# disable=relative-import
-#
-# ---- Probably legitimate, but needs markup to indicate intentionality
-# disable=no-init (Class does not have __init__, nor do ancestor classes)
-# disable=import-error
-# disable=attribute-defined-outside-init
-#
-# ----------------------------------------------------------------------
-# Fix when economical:
-#
-# ---- Minor code cleanliness problems; fix when encountered.
-# disable=unused-argument
-# disable=unused-variable
-# disable=invalid-name (Variable name does not meet coding standard)
-# disable=duplicate-code
-#
-# ---- Laundry list of tunable parameters for when things are too big/small
-# disable=abstract-class-little-used
-# disable=too-few-public-methods
-# disable=too-many-instance-attributes
-# disable=too-many-ancestors
-# disable=too-many-return-statements
-# disable=too-many-lines
-# disable=too-many-locals
-# disable=too-many-function-args
-# disable=too-many-public-methods
-# disable=too-many-arguments
-#
-# ----------------------------------------------------------------------
-# Ignored; OK by our coding standard:
-#
-# disable=bad-continuation (Bad whitespace on following line)
-# disable=no-self-use (Member function never uses 'self' parameter)
-# disable=missing-docstring
-# disable=fixme
-# disable=star-args
-# disable=locally-disabled (Notes local suppression of warning)
-# disable=locally-enabled (Notes re-enable of suppressed warning)
-# disable=bad-option-value (Notes suppression of unknown warning)
-# disable=abstract-class-not-used (Warns when not used in same file)
-#
-# Unfortunately, since the options parsing does not support multi-line entries
-# nor line continuation, all of the above items are redundantly specified here
-# in a way that pylint is willing to parse.
-disable=super-on-old-class,arguments-differ,signature-differs,method-hidden,abstract-method,no-member,relative-import,no-init,import-error,attribute-defined-outside-init,abstract-class-not-used,unused-argument,unused-variable,invalid-name,duplicate-code,abstract-class-little-used,too-few-public-methods,too-many-instance-attributes,too-many-ancestors,too-many-return-statements,too-many-lines,too-many-locals,too-many-function-args,too-many-public-methods,too-many-arguments,bad-continuation,no-self-use,missing-docstring,fixme,star-args,locally-disabled,locally-enabled,bad-option-value,useless-object-inheritance,logging-format-interpolation
-
-[REPORTS]
-
-# Set the output format. Available formats are text, parseable, colorized, msvs
-# (visual studio) and html. You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Put messages in a separate file for each module / package specified on the
-# command line instead of printing them on stdout. Reports (if any) will be
-# written in a file name "pylint_global.[txt|html]".
-files-output=no
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-
-[SPELLING]
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package.
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[SIMILARITIES]
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-
-[LOGGING]
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format
-logging-modules=logging
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,XXX,TODO
-
-
-[VARIABLES]
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# A regular expression matching the name of dummy variables (i.e. expectedly
-# not used).
-dummy-variables-rgx=_$|dummy
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,_cb
-
-
-[TYPECHECK]
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis
-ignored-modules=
-
-# List of classes names for which member attributes should not be checked
-# (useful for classes with attributes dynamically set).
-ignored-classes=SQLObject
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E0201 when accessed. Python regular
-# expressions are accepted.
-generated-members=REQUEST,acl_users,aq_parent
-
-
-[BASIC]
-
-# List of builtins function names that should not be used, separated by a comma
-bad-functions=map,filter,input
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=i,j,k,ex,Run,_
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=foo,bar,baz,toto,tutu,tata
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# Regular expression matching correct function names
-function-rgx=[a-z_][a-z0-9_]{2,50}$
-
-# Naming hint for function names
-function-name-hint=[a-z_][a-z0-9_]{2,50}$
-
-# Regular expression matching correct variable names
-variable-rgx=[a-z_][a-z0-9_]{1,30}$
-
-# Naming hint for variable names
-variable-name-hint=[a-z_][a-z0-9_]{2,30}$
-
-# Regular expression matching correct constant names
-const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
-
-# Naming hint for constant names
-const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
-
-# Regular expression matching correct attribute names
-attr-rgx=[a-z_][a-z0-9_]{2,30}$
-
-# Naming hint for attribute names
-attr-name-hint=[a-z_][a-z0-9_]{2,30}$
-
-# Regular expression matching correct argument names
-argument-rgx=[a-z_][a-z0-9_]{2,30}$
-
-# Naming hint for argument names
-argument-name-hint=[a-z_][a-z0-9_]{2,30}$
-
-# Regular expression matching correct class attribute names
-class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
-
-# Naming hint for class attribute names
-class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
-
-# Regular expression matching correct inline iteration names
-inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
-
-# Naming hint for inline iteration names
-inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
-
-# Regular expression matching correct class names
-class-rgx=[A-Z_][a-zA-Z0-9]+$
-
-# Naming hint for class names
-class-name-hint=[A-Z_][a-zA-Z0-9]+$
-
-# Regular expression matching correct module names
-module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
-
-# Naming hint for module names
-module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
-
-# Regular expression matching correct method names
-method-rgx=[a-z_][a-z0-9_]{2,30}$
-
-# Naming hint for method names
-method-name-hint=[a-z_][a-z0-9_]{2,30}$
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=__.*__
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-
-[FORMAT]
-
-# Maximum number of characters on a single line.
-max-line-length=80
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )??$
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-# List of optional constructs for which whitespace checking is disabled
-no-space-check=trailing-comma,dict-separator
-
-# Maximum number of lines in a module
-max-module-lines=2000
-
-# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
-# tab).
-indent-string=' '
-
-# Number of spaces of indent required inside a hanging or continued line.
-indent-after-paren=4
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-
-[IMPORTS]
-
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules=regsub,TERMIOS,Bastion,rexec
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled)
-import-graph=
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled)
-ext-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled)
-int-import-graph=
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,__new__,setUp
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,_fields,_replace,_source,_make
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method
-max-args=12
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore
-ignored-argument-names=_.*
-
-# Maximum number of locals for function / method body
-max-locals=25
-
-# Maximum number of return / yield for function / method body
-max-returns=6
-
-# Maximum number of branch for function / method body
-max-branches=40
-
-# Maximum number of statements in function / method body
-max-statements=105
-
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=50
-
-# Set the linting for string quotes
-string-quote=single
-triple-quote=double
-docstring-quote=double
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "Exception"
-overgeneral-exceptions=Exception
diff --git a/web_console_v2/api/cmds/BUILD.bazel b/web_console_v2/api/cmds/BUILD.bazel
new file mode 100644
index 000000000..3fae43d8e
--- /dev/null
+++ b/web_console_v2/api/cmds/BUILD.bazel
@@ -0,0 +1,100 @@
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_binary(
+ name = "flask_cli_bin",
+ srcs = [
+ "flask_cli.py",
+ ],
+ data = [
+ "//web_console_v2/api/migrations",
+ ],
+ main = "flask_cli.py",
+ deps = [
+ # THIS IS A HACK!!!
+ # Although `//web_console_v2/api:command_lib` is not directly used in `flask_cli.py`, we have to `deps` it for discovering python dependencies at runtime.
+ "//web_console_v2/api:command_lib",
+ "@common_flask//:pkg",
+ ],
+)
+
+py_binary(
+ name = "gunicorn_cli_bin",
+ srcs = [
+ "gunicorn_cli.py",
+ ],
+ data = [
+ "//web_console_v2/api:gunicorn_config",
+ ],
+ main = "gunicorn_cli.py",
+ deps = [
+ # THIS IS A HACK!!!
+ # Although `//web_console_v2/api:server_lib"` is not directly used in `gunicorn_cli.py`, we have to `deps` it for discovering python dependencies at runtime.
+ "//web_console_v2/api:server_lib",
+ "@common_gunicorn//:pkg",
+ ],
+)
+
+py_binary(
+ name = "supervisorctl_cli_bin",
+ srcs = [
+ "supervisorctl_cli.py",
+ ],
+ data = [
+ "supervisord.conf",
+ ],
+ main = "supervisorctl_cli.py",
+ deps = [
+ "@common_supervisor//:pkg",
+ ],
+)
+
+py_binary(
+ name = "supervisord_cli_bin",
+ srcs = [
+ "supervisord_cli.py",
+ ],
+ data = [
+ "supervisord.conf",
+ ],
+ main = "supervisord_cli.py",
+ deps = [
+ "@common_supervisor//:pkg",
+ ],
+)
+
+filegroup(
+ name = "runtime_env",
+ srcs = [
+ "runtime_env.sh",
+ ],
+)
+
+sh_binary(
+ name = "run_prod",
+ srcs = [
+ "run_prod.sh",
+ ],
+ data = [
+ ":flask_cli_bin",
+ ":gunicorn_cli_bin",
+ ":runtime_env",
+ ":supervisorctl_cli_bin",
+ ":supervisord_cli_bin",
+ "//web_console_v2/api:entrypoint_bin",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+sh_binary(
+ name = "run_dev",
+ srcs = [
+ "run_dev.sh",
+ ],
+ data = [
+ "supervisord_dev.conf",
+ ":runtime_env",
+ "//web_console_v2/api:entrypoint_bin",
+ "//web_console_v2/api/cmds:flask_cli_bin",
+ "//web_console_v2/api/cmds:supervisord_cli_bin",
+ ],
+)
diff --git a/web_console_v2/api/cmds/flask_cli.py b/web_console_v2/api/cmds/flask_cli.py
new file mode 100644
index 000000000..b60819485
--- /dev/null
+++ b/web_console_v2/api/cmds/flask_cli.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+from flask.cli import main
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/web_console_v2/api/cmds/gunicorn_cli.py b/web_console_v2/api/cmds/gunicorn_cli.py
new file mode 100644
index 000000000..6c67f60d9
--- /dev/null
+++ b/web_console_v2/api/cmds/gunicorn_cli.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+from gunicorn.app.wsgiapp import run
+
+if __name__ == '__main__':
+ sys.exit(run())
diff --git a/web_console_v2/api/cmds/run_dev.sh b/web_console_v2/api/cmds/run_dev.sh
new file mode 100755
index 000000000..c7c6cde27
--- /dev/null
+++ b/web_console_v2/api/cmds/run_dev.sh
@@ -0,0 +1,43 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#!/bin/bash
+set -e
+
+# This script is designed to be triggered by `bazel run`,
+# so it can only be executed from the root of our workspace.
+[[ ! ${PWD} =~ .*privacy_computing_platform ]] && { echo "this script should be executed in the root of the workspace"; exit 1; }
+
+function flask_command {
+ FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ web_console_v2/api/cmds/flask_cli_bin \
+ $*
+}
+
+export SYSTEM_INFO="{\"domain_name\": \"dev.fedlearner.net\", \"name\": \"Dev\"}"
+export APM_SERVER_ENDPOINT=stdout
+export FLASK_ENV=development
+
+# set runtime env
+source web_console_v2/api/cmds/runtime_env.sh
+
+# Migrates DB schemas
+flask_command create-db
+# Loads initial data
+flask_command create-initial-data
+# Runs Api Composer and gRPC service
+web_console_v2/api/cmds/supervisord_cli_bin \
+ -c web_console_v2/api/cmds/supervisord_dev.conf --nodaemon
diff --git a/web_console_v2/api/cmds/run_prod.sh b/web_console_v2/api/cmds/run_prod.sh
new file mode 100755
index 000000000..cdd6fcf92
--- /dev/null
+++ b/web_console_v2/api/cmds/run_prod.sh
@@ -0,0 +1,72 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#!/bin/bash
+set -e
+
+# Add hook into pythonpath
+export PYTHONPATH=$PYTHONPATH:$PWD
+# Check whether $0.runfiles exists, i.e. whether the runtime files are present.
+[[ -d "$0.runfiles" ]] && cd $0.runfiles/privacy_computing_platform
+
+# link all deps inside runfiles to $workspace/external
+if [[ ! -d "external" ]]
+then
+ echo "linking symbolic into external..."
+ mkdir external
+ ls ../ | grep -v privacy_computing_platform | xargs -IX ln -s $PWD/../X $PWD/external/X
+fi
+
+function flask_command {
+ FLASK_APP=web_console_v2/api/command:app \
+ APM_SERVER_ENDPOINT=/dev/null \
+ web_console_v2/api/cmds/flask_cli_bin \
+ $*
+}
+
+# set runtime env
+source web_console_v2/api/cmds/runtime_env.sh
+
+# Configure ElasticSearch ILM Information
+flask_command es-configuration
+
+# Iterates arguments
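+# e.g. passing "--migrate" upgrades the DB schema before starting the services.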
+while test $# -gt 0
+do
+ case "$1" in
+ --migrate)
+ echo "Migrating DB"
+ # Migrates DB schemas
+ flask_command db upgrade \
+ --directory web_console_v2/api/migrations
+ ;;
+ esac
+ shift
+done
+
+flask_command create-initial-data
+
+export FEDLEARNER_WEBCONSOLE_LOG_DIR=/var/log/fedlearner_webconsole/
+mkdir -p $FEDLEARNER_WEBCONSOLE_LOG_DIR
+
+# This starts the supervisor daemon, which will start all processes defined in
+# supervisord.conf. The daemon starts in the background by default.
+web_console_v2/api/cmds/supervisord_cli_bin \
+ -c web_console_v2/api/cmds/supervisord.conf
+# This tails logs from all processes defined in supervisord.conf.
+# The purpose of this is to keep supervisor in the foreground so that the
+# pod/container will not be terminated.
+web_console_v2/api/cmds/supervisorctl_cli_bin \
+ -c web_console_v2/api/cmds/supervisord.conf maintail -f
diff --git a/web_console_v2/api/cmds/runtime_env.sh b/web_console_v2/api/cmds/runtime_env.sh
new file mode 100644
index 000000000..8d3f616af
--- /dev/null
+++ b/web_console_v2/api/cmds/runtime_env.sh
@@ -0,0 +1,42 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#!/bin/bash
+set -e
+
+# Adds root directory to python path to make the modules findable.
+export PYTHONPATH=$PYTHONPATH:"web_console_v2/api/"
+
+# Disable pymalloc to avoid high memory usage when many small objects are allocated in parallel.
+# Ref: https://docs.python.org/3/c-api/memory.html#the-pymalloc-allocator
+export PYTHONMALLOC=malloc
+export PYTHONUNBUFFERED=1
+
+# When HADOOP_HOME is set, also export environment variables that are useful for GFile.
+if [ ! -z $HADOOP_HOME ]
+then
+ echo "set hadoop env"
+  # This is important for compatibility with both hadoop and hadoop_current
+ if [ -f "$HADOOP_HOME/conf/hadoop-env.sh" ]
+ then
+ export HADOOP_CONF_DIR=$HADOOP_HOME/conf
+ source "$HADOOP_HOME/conf/hadoop-env.sh" &> /dev/null
+ else
+ export HADOOP_CONF_DIR=$HADOOP_HOME/etc/hadoop
+ source "$HADOOP_HOME/etc/hadoop/hadoop-env.sh" &> /dev/null
+ fi
+ export LD_LIBRARY_PATH=${HADOOP_HOME}/lib/native:${HADOOP_HOME}/lib/native/nfs:${JAVA_HOME}/jre/lib/amd64/server:${LD_LIBRARY_PATH}
+ export CLASSPATH=$($HADOOP_HOME/bin/hadoop classpath --glob)
+fi
diff --git a/web_console_v2/api/cmds/supervisorctl_cli.py b/web_console_v2/api/cmds/supervisorctl_cli.py
new file mode 100644
index 000000000..622963683
--- /dev/null
+++ b/web_console_v2/api/cmds/supervisorctl_cli.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+from supervisor.supervisorctl import main
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/web_console_v2/api/cmds/supervisord.conf b/web_console_v2/api/cmds/supervisord.conf
new file mode 100644
index 000000000..2101d641c
--- /dev/null
+++ b/web_console_v2/api/cmds/supervisord.conf
@@ -0,0 +1,27 @@
+[supervisord]
+pidfile=/run/supervisord.pid
+loglevel=debug
+
+; This section sets up an HTTP server that listens on the UNIX socket given by
+; "file", which can be used to control the daemon
+[unix_http_server]
+file=/var/run/supervisor.sock
+
+[rpcinterface:supervisor]
+supervisor.rpcinterface_factory=supervisor.rpcinterface:make_main_rpcinterface
+
+; This section tells supervisorctl how to connect to the HTTP server above
+[supervisorctl]
+serverurl=unix:///var/run/supervisor.sock
+
+[program:restful_api]
+command=web_console_v2/api/cmds/gunicorn_cli_bin --conf web_console_v2/api/gunicorn_config.py server:app
+redirect_stderr=true
+
+[program:rpc]
+command=web_console_v2/api/entrypoint_bin start-rpc
+redirect_stderr=true
+
+[program:composer]
+command=web_console_v2/api/entrypoint_bin start-composer
+redirect_stderr=true
diff --git a/web_console_v2/api/cmds/supervisord_cli.py b/web_console_v2/api/cmds/supervisord_cli.py
new file mode 100644
index 000000000..0725d2bad
--- /dev/null
+++ b/web_console_v2/api/cmds/supervisord_cli.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+from supervisor.supervisord import main
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/web_console_v2/api/cmds/supervisord_dev.conf b/web_console_v2/api/cmds/supervisord_dev.conf
new file mode 100644
index 000000000..373a93b99
--- /dev/null
+++ b/web_console_v2/api/cmds/supervisord_dev.conf
@@ -0,0 +1,28 @@
+[supervisord]
+pidfile=/tmp/supervisord.pid
+loglevel=debug
+
+; This section sets up an HTTP server that listens on the UNIX socket given by
+; "file", which can be used to control the daemon
+[unix_http_server]
+file=/tmp/supervisor.sock
+
+[rpcinterface:supervisor]
+supervisor.rpcinterface_factory=supervisor.rpcinterface:make_main_rpcinterface
+
+; This section connects to the HTTP server to control the HTTP server
+[supervisorctl]
+serverurl=unix:///tmp/supervisor.sock
+
+[program:restful_api]
+command=./web_console_v2/api/cmds/flask_cli_bin run --eager-loading --port=1991 --host=0.0.0.0
+environment=FLASK_APP=web_console_v2.api.server:app
+redirect_stderr=true
+
+[program:rpc]
+command=./web_console_v2/api/entrypoint_bin -- start-rpc
+redirect_stderr=true
+
+[program:composer]
+command=./web_console_v2/api/entrypoint_bin -- start-composer
+redirect_stderr=true
diff --git a/web_console_v2/api/command.py b/web_console_v2/api/command.py
index ca3fdc337..776408412 100644
--- a/web_console_v2/api/command.py
+++ b/web_console_v2/api/command.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,19 +13,24 @@
# limitations under the License.
# coding: utf-8
+import click
+
from config import Config
+from flask_migrate import Migrate
+from es_configuration import es_config
from fedlearner_webconsole.app import create_app
-from fedlearner_webconsole.db import db_handler as db
+from fedlearner_webconsole.db import db
from fedlearner_webconsole.initial_db import initial_db
-from flask_migrate import Migrate
-
from fedlearner_webconsole.utils.hooks import pre_start_hook
+from tools.project_cleanup import delete_project
+from tools.workflow_migration.workflow_completed_failed import migrate_workflow_completed_failed_state
+from tools.dataset_migration.dataset_job_name_migration.dataset_job_name_migration import migrate_dataset_job_name
+from tools.variable_finder import find
class CliConfig(Config):
- START_GRPC_SERVER = False
START_SCHEDULER = False
- START_COMPOSER = False
+ START_K8S_WATCHER = False
pre_start_hook()
@@ -42,3 +47,35 @@ def create_initial_data():
@app.cli.command('create-db')
def create_db():
db.create_all()
+
+
+@app.cli.command('cleanup-project')
+@click.argument('project_id')
+def cleanup_project(project_id):
+ delete_project(int(project_id))
+
+
+@app.cli.command('migrate-workflow-completed-failed-state')
+def remove_intersection_dataset():
+ migrate_workflow_completed_failed_state()
+
+
+@app.cli.command('migrate-dataset-job-name')
+def add_dataset_job_name():
+ migrate_dataset_job_name()
+
+
+@app.cli.command('migrate-connect-to-test')
+def migrate_connect_to_test():
+ migrate_connect_to_test()
+
+
+@app.cli.command('find-variable')
+@click.argument('name')
+def find_variable(name: str):
+ find(name)
+
+
+@app.cli.command('es-configuration')
+def es_configuration():
+ es_config()
diff --git a/web_console_v2/api/composer.py b/web_console_v2/api/composer.py
new file mode 100644
index 000000000..d4d4248fa
--- /dev/null
+++ b/web_console_v2/api/composer.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.composer import composer
+from fedlearner_webconsole.utils.hooks import pre_start_hook
+
+if __name__ == '__main__':
+ # TODO(wangsen.0914): refactor logging_config
+ # There's a race condition when multiple process logging to same file.
+ logging.basicConfig(level=logging.DEBUG)
+ pre_start_hook()
+ logging.info('Starting composer...')
+ composer.run(db.engine)
+ composer.wait_for_termination()
diff --git a/web_console_v2/api/config.py b/web_console_v2/api/config.py
index e58492b6e..f3d27f82d 100644
--- a/web_console_v2/api/config.py
+++ b/web_console_v2/api/config.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,9 +14,6 @@
# coding: utf-8
-import os
-import secrets
-
from fedlearner_webconsole.db import get_database_uri
from envs import Envs
@@ -28,12 +25,11 @@ class Config(object):
# For unicode strings
# Ref: https://stackoverflow.com/questions/14853694/python-jsonify-dictionary-in-utf-8
JSON_AS_ASCII = False
- JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', secrets.token_urlsafe(64))
+ JWT_SECRET_KEY = Envs.JWT_SECRET_KEY
PROPAGATE_EXCEPTIONS = True
- GRPC_LISTEN_PORT = 1990
+ GRPC_LISTEN_PORT = Envs.GRPC_LISTEN_PORT
JWT_ACCESS_TOKEN_EXPIRES = 86400
STORAGE_ROOT = Envs.STORAGE_ROOT
- START_GRPC_SERVER = True
START_SCHEDULER = True
- START_COMPOSER = os.getenv('START_COMPOSER', True)
+ START_K8S_WATCHER = True
diff --git a/web_console_v2/api/docs/BUILD.bazel b/web_console_v2/api/docs/BUILD.bazel
new file mode 100644
index 000000000..821338eb2
--- /dev/null
+++ b/web_console_v2/api/docs/BUILD.bazel
@@ -0,0 +1,5 @@
+filegroup(
+ name = "srcs",
+ srcs = glob(["**/*"]),
+ visibility = ["//visibility:public"],
+)
diff --git a/web_console_v2/api/docs/best_practices.md b/web_console_v2/api/docs/best_practices.md
index 9c5964b50..88aaaddf6 100644
--- a/web_console_v2/api/docs/best_practices.md
+++ b/web_console_v2/api/docs/best_practices.md
@@ -6,6 +6,8 @@ flask-migrate, which needs us to upgrade the migration files once schema gets
updated (inefficiently). Integers/strings make it easy to extend the enums;
the disadvantage is that we have to take care of data migrations if an enum value is deleted.
+Natively, SQLAlchemy supports an Enum type in a column. [Ref](https://docs.sqlalchemy.org/en/14/core/type_basics.html#sqlalchemy.types.Enum)
+
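+A minimal sketch of the string-backed approach (hypothetical `TaskState` enum and `task` table, for
+illustration only): the value is stored as a plain string column, so adding a new enum member does
+not require a schema migration.
+```python
+import enum
+
+from sqlalchemy import Column, Integer, String
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()
+
+
+class TaskState(enum.Enum):
+    READY = 'READY'
+    RUNNING = 'RUNNING'
+
+
+class Task(Base):
+    __tablename__ = 'task'
+    id = Column(Integer, primary_key=True)
+    # Stored as a plain string; extending TaskState needs no migration.
+    _state = Column('state', String(64), default=TaskState.READY.value)
+
+    @property
+    def state(self) -> TaskState:
+        return TaskState(self._state)
+
+    @state.setter
+    def state(self, new_state: TaskState):
+        self._state = new_state.value
+```
+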
### Index in DB
Index is not necessary if the value of column is very limited, such as enum
or boolean. Reference: https://tech.meituan.com/2014/06/30/mysql-index.html
@@ -53,3 +55,58 @@ See details [here](https://en.wikipedia.org/wiki/Representational_state_transfer
primaryjoin='Project.id == '
'foreign(Job.project_id)')
```
+
+### sqlalchemy session
+* Please keep the session/transaction scope as small as possible; otherwise it may not work as expected.
+[Ref](https://docs.sqlalchemy.org/en/14/orm/session_basics.html#when-do-i-construct-a-session-when-do-i-commit-it-and-when-do-i-close-it)
+```python
+# BAD: the whole loop runs inside one transaction that also covers the runner query, so the data may become stale.
+with db.session_scope() as session:
+ init_runners = session.query(SchedulerRunner).filter_by(
+ status=RunnerStatus.INIT.value).all()
+ for runner in init_runners:
+ # Do something with the runner
+ session.commit()
+
+# GOOD: make the transaction scope clear.
+with db.session_scope() as session:
+ running_runner_ids = session.query(SchedulerRunner.id).filter_by(
+ status=RunnerStatus.RUNNING.value).all()
+for runner_id, *_ in running_runner_ids:
+ with db.session_scope() as session:
+ runner = session.query(SchedulerRunner).get(runner_id)
+ # Do something with the runner
+ session.commit()
+```
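+
+In the second form each runner is loaded and committed in its own short transaction, so a
+long-running loop does not hold a single transaction open and operate on stale data.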
+
+### Pagination
+- Use `utils/paginate.py`, and **read the test case** as a quickstart
+- All resources are **un-paginated** by default
+- Append page metadata in your returned body in the following format:
+```json
+// your POV
+{
+ "data": pagination.get_items(),
+ "page_meta": pagination.get_metadata()
+}
+
+// frontend POV
+{
+ "data": {...},
+ "page_meta": {
+ "current_page": 1,
+ "page_size": 5,
+ "total_pages": 2,
+ "total_items": 7
+ }
+}
+```
+- **ALWAYS** return `page_meta`
+  - If your API is called with `page=...`, then paginate for the caller; return the pagination metadata as shown above (see the handler sketch after this list)
+ - If your API is called without `page=...`, then return the un-paginated data with an **empty** `page_meta` body like so:
+ ```json
+ {
+ "data": {...},
+ "page_meta": {}
+ }
+ ```
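+
+A minimal handler sketch following this convention (the `Item` model and its `to_dict()` are
+hypothetical, and it assumes `paginate()` leaves the query un-paginated with empty metadata when
+`page` is `None`):
+```python
+from marshmallow import fields
+from flask_restful import Resource
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.decorators.pp_flask import use_kwargs
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from fedlearner_webconsole.utils.paginate import paginate
+
+
+class ItemsApi(Resource):
+
+    @use_kwargs(
+        {
+            'page': fields.Integer(required=False, load_default=None),
+            'page_size': fields.Integer(required=False, load_default=None),
+        },
+        location='query')
+    def get(self, page, page_size):
+        with db.session_scope() as session:
+            # `Item` is a hypothetical model used only to illustrate the pattern.
+            query = session.query(Item).order_by(Item.created_at.desc())
+            # Assumption: paginate() returns everything and empty metadata when page is None.
+            pagination = paginate(query, page, page_size)
+            data = [item.to_dict() for item in pagination.get_items()]
+            # ALWAYS return `page_meta`, even when it is empty.
+            return make_flask_response(data=data, page_meta=pagination.get_metadata())
+```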
diff --git a/web_console_v2/api/docs/how_to_install_python362.md b/web_console_v2/api/docs/how_to_install_python362.md
new file mode 100644
index 000000000..98b137414
--- /dev/null
+++ b/web_console_v2/api/docs/how_to_install_python362.md
@@ -0,0 +1,78 @@
+# Background:
+
+When we depend on outdated software, the lack of forward compatibility prevents us from using newer versions of Python. In that case we need to switch to, and isolate, a specific Python version.
+
+In our project environment we need TensorFlow v1.15.0, which requires Python 3.6.2.
+
+
+
+# How to install & use:
+
+## 1. Get Python 3.6.2 using pyenv
+
+First, use *brew* to install *pyenv*:
+
+```shell
+brew install pyenv
+pyenv -v # check pyenv version
+```
+
+Then, check the Python 3.6 versions provided by pyenv and install the one we need:
+
+```shell
+pyenv install --list | grep 3.6
+pyenv install 3.6.2
+```
+
+
+
+###### If you see an error like this,
+
+```shell
+Last 10 log lines:
+./Modules/posixmodule.c:8210:15: error: implicit declaration of function 'sendfile' is invalid in C99 [-Werror,-Wimplicit-function-declaration]
+ret = sendfile(in, out, offset, &sbytes, &sf, flags);
+^
+./Modules/posixmodule.c:10432:5: warning: code will never be executed [-Wunreachable-code]
+Py_FatalError("abort() called from Python code didn't abort!");
+^~~~~~~~~~~~~
+1 warning and 1 error generated.
+make: *** [Modules/posixmodule.o] Error 1
+make: *** Waiting for unfinished jobs....
+1 warning generated
+```
+
+###### you can try installing Python with:
+
+```shell
+CFLAGS="-I$(brew --prefix openssl)/include -I$(brew --prefix bzip2)/include -I$(brew --prefix readline)/include -I$(xcrun --show-sdk-path)/usr/include" LDFLAGS="-L$(brew --prefix openssl)/lib -L$(brew --prefix readline)/lib -L$(brew --prefix zlib)/lib -L$(brew --prefix bzip2)/lib" pyenv install --patch 3.6.2 < <(curl -sSL https://github.com/python/cpython/commit/8ea6353.patch\?full_index\=1)
+```
+
+
+
+Check all the available Python versions in pyenv to make sure Python 3.6.2 was installed successfully:
+
+```shell
+pyenv versions
+```
+
+
+
+## 2. Create your own Python 3.6.2 virtualenv:
+
+When using multiple Python versions, we strongly recommend creating a Python virtual environment to isolate packages between versions.
+
+The Python 3.6.2 interpreter installed by *pyenv* is at ~/.pyenv/versions/3.6.2/bin/python3.6.
+
+Create/manage/use your Python 3.6.2 virtualenv with *virtualenvwrapper*:
+
+```shell
+mkvirtualenv --python=/users/bytedance/.pyenv/versions/3.6.2/bin/python3.6 $ENV_NAME
+workon $ENV_NAME
+```
+
+Alternatively, use *venv*, or simply use the virtualenv created by *PyCharm*.
+
+
+
+***Note***: please **DO NOT** install packages directly into the real Python 3.6.2 environment.
\ No newline at end of file
diff --git a/web_console_v2/api/entrypoint.py b/web_console_v2/api/entrypoint.py
new file mode 100644
index 000000000..e4dd2facf
--- /dev/null
+++ b/web_console_v2/api/entrypoint.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from click import group
+
+from envs import Envs
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.hooks import pre_start_hook
+from fedlearner_webconsole.composer.composer import composer
+from fedlearner_webconsole.rpc.server import rpc_server
+from fedlearner_webconsole.k8s.k8s_watcher import k8s_watcher
+
+
+@group('entrypoint')
+def entrypoint():
+ pass
+
+
+@entrypoint.command('start-rpc')
+def start_rpc():
+ logging.info('Starting Rpc...')
+ # Start k8s watcher in rpc server process for now.
+ k8s_watcher.start()
+ rpc_server.stop()
+ rpc_server.start(Envs.GRPC_LISTEN_PORT)
+ rpc_server.wait_for_termination()
+
+
+@entrypoint.command('start-composer')
+def start_composer():
+ # TODO(wangsen.0914): refactor logging_config
+    # There's a race condition when multiple processes log to the same file.
+ logging.basicConfig(level=logging.DEBUG)
+ logging.info('Starting composer...')
+ composer.run(db.engine)
+ composer.wait_for_termination()
+
+
+if __name__ == '__main__':
+ pre_start_hook()
+ entrypoint()
diff --git a/web_console_v2/api/envs.py b/web_console_v2/api/envs.py
index 6f6f5f39d..b4d7a4900 100644
--- a/web_console_v2/api/envs.py
+++ b/web_console_v2/api/envs.py
@@ -1,56 +1,172 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import os
-import json
+import re
+import secrets
+from typing import Optional
+from urllib.parse import unquote
+from google.protobuf.json_format import Parse, ParseError
import pytz
+from fedlearner_webconsole.proto import setting_pb2
+from fedlearner_webconsole.utils.const import API_VERSION
+
+# SQLALCHEMY_DATABASE_URI pattern dialect+driver://username:password@host:port/database
+_SQLALCHEMY_DATABASE_URI_PATTERN = re.compile(
+    r'^(?P<dialect>[^+:]+)(\+(?P<driver>[^:]+))?://'
+    r'((?P<username>[^:@]+)?:(?P<password>[^@]+)?@((?P<host>[^:/]+)(:(?P<port>[0-9]+))?)?)?'
+    r'/(?P<database>[^?]+)?')
+
+# Limit OpenBLAS to one thread to avoid spawning many threads that hang.
+# ref: https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
+os.environ['OMP_NUM_THREADS'] = '1'
+
class Envs(object):
+ SERVER_HOST = os.environ.get('SERVER_HOST', 'http://localhost:666/')
TZ = pytz.timezone(os.environ.get('TZ', 'UTC'))
- ES_HOST = os.environ.get('ES_HOST',
- 'fedlearner-stack-elasticsearch-client')
- ES_READ_HOST = os.environ.get('ES_READ_HOST', ES_HOST)
+ ES_HOST = os.environ.get('ES_HOST', 'fedlearner-stack-elasticsearch-client')
ES_PORT = os.environ.get('ES_PORT', 9200)
ES_USERNAME = os.environ.get('ES_USERNAME', 'elastic')
ES_PASSWORD = os.environ.get('ES_PASSWORD', 'Fedlearner123')
+    # apm-server service address, which is used to collect traces and custom metrics
+ APM_SERVER_ENDPOINT = os.environ.get('APM_SERVER_ENDPOINT', 'http://fedlearner-stack-apm-server:8200')
# addr to Kibana in pod/cluster
- KIBANA_SERVICE_ADDRESS = os.environ.get(
- 'KIBANA_SERVICE_ADDRESS', 'http://fedlearner-stack-kibana:443')
+ KIBANA_SERVICE_ADDRESS = os.environ.get('KIBANA_SERVICE_ADDRESS', 'http://fedlearner-stack-kibana:443')
# addr to Kibana outside cluster, typically comply with port-forward
KIBANA_ADDRESS = os.environ.get('KIBANA_ADDRESS', 'localhost:1993')
# What fields are allowed in peer query.
- KIBANA_ALLOWED_FIELDS = set(
- f for f in os.environ.get('KIBANA_ALLOWED_FIELDS', '*').split(',')
- if f)
- OPERATOR_LOG_MATCH_PHRASE = os.environ.get('OPERATOR_LOG_MATCH_PHRASE',
- None)
- # Whether to use the real jwt_required decorator or fake one
+ KIBANA_ALLOWED_FIELDS = set(f for f in os.environ.get('KIBANA_ALLOWED_FIELDS', '*').split(',') if f)
+    # Kibana dashboard list: dashboard information consisting of [`name`, `uuid`] pairs in JSON format
+ KIBANA_DASHBOARD_LIST = os.environ.get('KIBANA_DASHBOARD_LIST', '[]')
+ OPERATOR_LOG_MATCH_PHRASE = os.environ.get('OPERATOR_LOG_MATCH_PHRASE', None)
+ # Whether to use the real credentials_required decorator or fake one
DEBUG = os.environ.get('DEBUG', False)
+ SWAGGER_URL_PREFIX = os.environ.get('SWAGGER_URL_PREFIX', API_VERSION)
+ # grpc client can use this GRPC_SERVER_URL when DEBUG is True
+ GRPC_SERVER_URL = os.environ.get('GRPC_SERVER_URL', None)
+ GRPC_LISTEN_PORT = int(os.environ.get('GRPC_LISTEN_PORT', 1990))
+ RESTFUL_LISTEN_PORT = int(os.environ.get('RESTFUL_LISTEN_PORT', 1991))
+ # composer server listen port for health checking service
+ COMPOSER_LISTEN_PORT = int(os.environ.get('COMPOSER_LISTEN_PORT', 1992))
ES_INDEX = os.environ.get('ES_INDEX', 'filebeat-*')
# Indicates which k8s namespace fedlearner pods belong to
K8S_NAMESPACE = os.environ.get('K8S_NAMESPACE', 'default')
K8S_CONFIG_PATH = os.environ.get('K8S_CONFIG_PATH', None)
- # additional info for k8s.metadata.labels
- K8S_LABEL_INFO = json.loads(os.environ.get('K8S_LABEL_INFO', '{}'))
- FEDLEARNER_WEBCONSOLE_LOG_DIR = os.environ.get(
- 'FEDLEARNER_WEBCONSOLE_LOG_DIR', '.')
+ K8S_HOOK_MODULE_PATH = os.environ.get('K8S_HOOK_MODULE_PATH', None)
+ FEDLEARNER_WEBCONSOLE_LOG_DIR = os.environ.get('FEDLEARNER_WEBCONSOLE_LOG_DIR', '.')
+ LOG_LEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
FLASK_ENV = os.environ.get('FLASK_ENV', 'development')
+ CLUSTER = os.environ.get('CLUSTER', 'default')
+ JWT_SECRET_KEY = os.environ.get('JWT_SECRET_KEY', secrets.token_urlsafe(64))
# In seconds
- GRPC_CLIENT_TIMEOUT = os.environ.get('GRPC_CLIENT_TIMEOUT', 5)
+ GRPC_CLIENT_TIMEOUT = int(os.environ.get('GRPC_CLIENT_TIMEOUT', 5))
+ # In seconds
+ GRPC_STREAM_CLIENT_TIMEOUT = int(os.environ.get('GRPC_STREAM_CLIENT_TIMEOUT', 10))
# storage filesystem
STORAGE_ROOT = os.getenv('STORAGE_ROOT', '/data')
# BASE_DIR
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
- # spark on k8s image url
- SPARKAPP_IMAGE_URL = os.getenv('SPARKAPP_IMAGE_URL', None)
- SPARKAPP_FILES_PATH = os.getenv('SPARKAPP_FILES_PATH', None)
- SPARKAPP_VOLUMES = os.getenv('SPARKAPP_VOLUMES', None)
- SPARKAPP_VOLUME_MOUNTS = os.getenv('SPARKAPP_VOLUME_MOUNTS', None)
# Hooks
PRE_START_HOOK = os.environ.get('PRE_START_HOOK', None)
+ # Flags
+ FLAGS = os.environ.get('FLAGS', '{}')
+
+ # Third party SSO, see the example in test_sso.json
+ SSO_INFOS = os.environ.get('SSO_INFOS', '[]')
+
+ # Audit module storage setting
+ AUDIT_STORAGE = os.environ.get('AUDIT_STORAGE', 'db')
+
+    # system info, including name, domain name and IP
+ SYSTEM_INFO = os.environ.get('SYSTEM_INFO', '{}')
+
+ CUSTOMIZED_FILE_MANAGER = os.environ.get('CUSTOMIZED_FILE_MANAGER')
+ SCHEDULER_POLLING_INTERVAL = os.environ.get('FEDLEARNER_WEBCONSOLE_POLLING_INTERVAL', 60)
+
+ # DB related
+ SQLALCHEMY_DATABASE_URI = os.environ.get('SQLALCHEMY_DATABASE_URI')
+ DB_HOST = os.environ.get('DB_HOST')
+ DB_PORT = os.environ.get('DB_PORT')
+ DB_DATABASE = os.environ.get('DB_DATABASE')
+ DB_USERNAME = os.environ.get('DB_USERNAME')
+ DB_PASSWORD = os.environ.get('DB_PASSWORD')
+
+ # Fedlearner related
+ KVSTORE_TYPE = os.environ.get('KVSTORE_TYPE')
+ ETCD_NAME = os.environ.get('ETCD_NAME')
+ ETCD_ADDR = os.environ.get('ETCD_ADDR')
+ ETCD_BASE_DIR = os.environ.get('ETCD_BASE_DIR')
+ ROBOT_USERNAME = os.environ.get('ROBOT_USERNAME')
+ ROBOT_PWD = os.environ.get('ROBOT_PWD')
+ WEB_CONSOLE_V2_ENDPOINT = os.environ.get('WEB_CONSOLE_V2_ENDPOINT')
+ HADOOP_HOME = os.environ.get('HADOOP_HOME')
+ JAVA_HOME = os.environ.get('JAVA_HOME')
+
+ @staticmethod
+ def _decode_url_codec(codec: str) -> str:
+ if not codec:
+ return codec
+ return unquote(codec)
+
+ @classmethod
+ def _check_db_envs(cls) -> Optional[str]:
+ # Checks if DB related envs are matched
+ if cls.SQLALCHEMY_DATABASE_URI:
+ matches = _SQLALCHEMY_DATABASE_URI_PATTERN.match(cls.SQLALCHEMY_DATABASE_URI)
+ if not matches:
+ return 'Invalid SQLALCHEMY_DATABASE_URI'
+ if cls.DB_HOST:
+ # Other DB_* envs should be set together
+ db_host = cls._decode_url_codec(matches.group('host'))
+ if cls.DB_HOST != db_host:
+ return 'DB_HOST does not match'
+ db_port = cls._decode_url_codec(matches.group('port'))
+ if cls.DB_PORT != db_port:
+ return 'DB_PORT does not match'
+ db_database = cls._decode_url_codec(matches.group('database'))
+ if cls.DB_DATABASE != db_database:
+                return 'DB_DATABASE does not match'
+ db_username = cls._decode_url_codec(matches.group('username'))
+ if cls.DB_USERNAME != db_username:
+ return 'DB_USERNAME does not match'
+ db_password = cls._decode_url_codec(matches.group('password'))
+ if cls.DB_PASSWORD != db_password:
+ return 'DB_PASSWORD does not match'
+ return None
+
+ @classmethod
+ def _check_system_info_envs(cls) -> Optional[str]:
+ try:
+ system_info = Parse(Envs.SYSTEM_INFO, setting_pb2.SystemInfo())
+ except ParseError as err:
+ return f'failed to parse SYSTEM_INFO {err}'
+ if system_info.domain_name == '' or system_info.name == '':
+ return 'domain_name or name is not set into SYSTEM_INFO'
+ return None
-class Features(object):
- FEATURE_MODEL_K8S_HOOK = os.getenv('FEATURE_MODEL_K8S_HOOK')
- FEATURE_MODEL_WORKFLOW_HOOK = os.getenv('FEATURE_MODEL_WORKFLOW_HOOK')
- DATA_MODULE_BETA = os.getenv('DATA_MODULE_BETA', None)
+ @classmethod
+ def check(cls) -> Optional[str]:
+ db_envs_error = cls._check_db_envs()
+ if db_envs_error:
+ return db_envs_error
+ system_info_envs_error = cls._check_system_info_envs()
+ if system_info_envs_error:
+ return system_info_envs_error
+ return None
diff --git a/web_console_v2/api/envs_test.py b/web_console_v2/api/envs_test.py
new file mode 100644
index 000000000..5dd724374
--- /dev/null
+++ b/web_console_v2/api/envs_test.py
@@ -0,0 +1,108 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch
+
+from envs import _SQLALCHEMY_DATABASE_URI_PATTERN, Envs
+
+
+class EnvsTest(unittest.TestCase):
+
+ def test_sqlalchemy_database_uri_pattern(self):
+ # Sqlite
+ matches = _SQLALCHEMY_DATABASE_URI_PATTERN.match('sqlite:///app.db')
+ self.assertIsNotNone(matches)
+ self.assertEqual(matches.group('dialect'), 'sqlite')
+ self.assertIsNone(matches.group('driver'))
+ self.assertIsNone(matches.group('username'))
+ self.assertIsNone(matches.group('password'))
+ self.assertIsNone(matches.group('host'))
+ self.assertIsNone(matches.group('port'))
+ self.assertEqual(matches.group('database'), 'app.db')
+ # MySQL
+ matches = _SQLALCHEMY_DATABASE_URI_PATTERN.match('mysql+pymysql://root:root@localhost:33600/fedlearner')
+ self.assertIsNotNone(matches)
+ self.assertEqual(matches.group('dialect'), 'mysql')
+ self.assertEqual(matches.group('driver'), 'pymysql')
+ self.assertEqual(matches.group('username'), 'root')
+ self.assertEqual(matches.group('password'), 'root')
+ self.assertEqual(matches.group('host'), 'localhost')
+ self.assertEqual(matches.group('port'), '33600')
+ self.assertEqual(matches.group('database'), 'fedlearner')
+ # MySQL with socket
+ matches = _SQLALCHEMY_DATABASE_URI_PATTERN.match('mysql+pymysql://:@/?charset=utf8mb4db_psm=mysql.fedlearner')
+ self.assertIsNotNone(matches)
+ self.assertEqual(matches.group('dialect'), 'mysql')
+ self.assertEqual(matches.group('driver'), 'pymysql')
+ self.assertIsNone(matches.group('username'))
+ self.assertIsNone(matches.group('password'))
+ self.assertIsNone(matches.group('host'))
+ self.assertIsNone(matches.group('port'))
+ self.assertIsNone(matches.group('database'))
+ # Invalid ones
+ matches = _SQLALCHEMY_DATABASE_URI_PATTERN.match('mysql+pymysql://root_33600/fedlearner')
+ self.assertIsNone(matches)
+ matches = _SQLALCHEMY_DATABASE_URI_PATTERN.match('sqlite://hello')
+ self.assertIsNone(matches)
+
+ def test_check_db_envs_valid(self):
+ with patch('envs.Envs.SQLALCHEMY_DATABASE_URI', 'mysql+pymysql://root:proot@localhost:33600/fedlearner'), \
+ patch('envs.Envs.DB_HOST', 'localhost'), \
+ patch('envs.Envs.DB_PORT', '33600'), \
+ patch('envs.Envs.DB_DATABASE', 'fedlearner'), \
+ patch('envs.Envs.DB_USERNAME', 'root'), \
+ patch('envs.Envs.DB_PASSWORD', 'proot'):
+ self.assertIsNone(Envs._check_db_envs()) # pylint: disable=protected-access
+ # DB_HOST is not set
+ with patch('envs.Envs.SQLALCHEMY_DATABASE_URI', 'mysql+pymysql://:@/?charset=utf8mb4db_psm=mysql.fedlearner'):
+ self.assertIsNone(Envs._check_db_envs()) # pylint: disable=protected-access
+ # DB_PASSWORD with some encodings
+ with patch('envs.Envs.SQLALCHEMY_DATABASE_URI', 'mysql+pymysql://root:fl%4012345@localhost:33600/fedlearner'), \
+ patch('envs.Envs.DB_HOST', 'localhost'), \
+ patch('envs.Envs.DB_PORT', '33600'), \
+ patch('envs.Envs.DB_DATABASE', 'fedlearner'), \
+ patch('envs.Envs.DB_USERNAME', 'root'), \
+ patch('envs.Envs.DB_PASSWORD', 'fl@12345'):
+ self.assertIsNone(Envs._check_db_envs()) # pylint: disable=protected-access
+
+ def test_check_db_envs_invalid(self):
+ with patch('envs.Envs.SQLALCHEMY_DATABASE_URI', 'mysql+pymysql://root:proot@localhost:33600/fedlearner'), \
+ patch('envs.Envs.DB_HOST', 'localhost'), \
+ patch('envs.Envs.DB_PORT', '336'):
+ self.assertEqual(Envs._check_db_envs(), 'DB_PORT does not match') # pylint: disable=protected-access
+ with patch('envs.Envs.SQLALCHEMY_DATABASE_URI', 'mysql+pymysql://:@/?charset=utf8mb4db_psm=mysql.fedlearner'), \
+ patch('envs.Envs.DB_HOST', 'localhost'):
+ self.assertEqual(Envs._check_db_envs(), 'DB_HOST does not match') # pylint: disable=protected-access
+
+ def test_decode_url_codec(self):
+ self.assertIsNone(Envs._decode_url_codec(None)) # pylint: disable=protected-access
+ self.assertEqual(Envs._decode_url_codec('hahaha'), 'hahaha') # pylint: disable=protected-access
+ self.assertEqual(Envs._decode_url_codec('%20%40'), ' @') # pylint: disable=protected-access
+
+ def test_system_info_valid(self):
+ with patch('envs.Envs.SYSTEM_INFO', '{"domain_name": "aaa.fedlearner.net", "name": "aaa.Inc"}'):
+ self.assertIsNone(Envs._check_system_info_envs()) # pylint: disable=protected-access
+
+ def test_system_info_invalid(self):
+ with patch('envs.Envs.SYSTEM_INFO', '{"domain_name": "aaa.fedlearner.net"}'):
+ self.assertEqual('domain_name or name is not set into SYSTEM_INFO', Envs._check_system_info_envs()) # pylint: disable=protected-access
+
+ with patch('envs.Envs.SYSTEM_INFO', '{"domain_name": "aaa.fedlearner.net"'):
+ self.assertIn('failed to parse SYSTEM_INFO', Envs._check_system_info_envs()) # pylint: disable=protected-access
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/es_configuration.py b/web_console_v2/api/es_configuration.py
index a77694eec..96081f7f2 100644
--- a/web_console_v2/api/es_configuration.py
+++ b/web_console_v2/api/es_configuration.py
@@ -1,3 +1,18 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
import requests
from elasticsearch import Elasticsearch, exceptions
@@ -14,42 +29,46 @@ def _configure_index_alias(es, alias_name):
es.indices.create(
# resolves to alias_name-yyyy.mm.dd-000001 in ES
f'<{alias_name}-{{now/d}}-000001>',
- body={"aliases": {alias_name: {"is_write_index": True}}}
- )
+ body={'aliases': {
+ alias_name: {
+ 'is_write_index': True
+ }
+ }})
def _configure_kibana_index_patterns(kibana_addr, index_type):
if not kibana_addr:
- requests.post(
- url='{}/api/saved_objects/index-pattern/{}'
- .format(kibana_addr, ALIAS_NAME[index_type]),
- json={'attributes': {
- 'title': ALIAS_NAME[index_type] + '*',
- 'timeFieldName': 'tags.process_time'
- if index_type == 'metrics' else 'tags.event_time'}},
- headers={'kbn-xsrf': 'true',
- 'Content-Type': 'application/json'},
- params={'overwrite': True}
- )
+ requests.post(url=f'{kibana_addr}/api/saved_objects/index-pattern/{ALIAS_NAME[index_type]}',
+ json={
+ 'attributes': {
+ 'title': ALIAS_NAME[index_type] + '*',
+ 'timeFieldName': 'tags.process_time' if index_type == 'metrics' else 'tags.event_time'
+ }
+ },
+ headers={
+ 'kbn-xsrf': 'true',
+ 'Content-Type': 'application/json'
+ },
+ params={'overwrite': True})
def put_ilm(es, ilm_name, hot_size='50gb', hot_age='10d', delete_age='30d'):
ilm_body = {
- "policy": {
- "phases": {
- "hot": {
- "min_age": "0ms",
- "actions": {
- "rollover": {
- "max_size": hot_size,
- "max_age": hot_age
+ 'policy': {
+ 'phases': {
+ 'hot': {
+ 'min_age': '0ms',
+ 'actions': {
+ 'rollover': {
+ 'max_size': hot_size,
+ 'max_age': hot_age
}
}
},
- "delete": {
- "min_age": delete_age,
- "actions": {
- "delete": {}
+ 'delete': {
+ 'min_age': delete_age,
+ 'actions': {
+ 'delete': {}
}
}
}
@@ -64,28 +83,25 @@ def _put_index_template(es, index_type, shards):
es.indices.put_template(template_name, template_body)
-if __name__ == '__main__':
- es = Elasticsearch([{'host': Envs.ES_HOST, 'port': Envs.ES_PORT}],
- http_auth=(Envs.ES_USERNAME, Envs.ES_PASSWORD))
+def es_config():
+ es = Elasticsearch([{'host': Envs.ES_HOST, 'port': Envs.ES_PORT}], http_auth=(Envs.ES_USERNAME, Envs.ES_PASSWORD))
if int(es.info()['version']['number'].split('.')[0]) == 7:
es.ilm.start()
for index_type, alias_name in ALIAS_NAME.items():
- put_ilm(es, 'fedlearner_{}_ilm'.format(index_type))
+ put_ilm(es, f'fedlearner_{index_type}_ilm')
_put_index_template(es, index_type, shards=1)
_configure_index_alias(es, alias_name)
# Kibana index-patterns initialization
- _configure_kibana_index_patterns(
- Envs.KIBANA_SERVICE_ADDRESS, index_type
- )
+ _configure_kibana_index_patterns(Envs.KIBANA_SERVICE_ADDRESS, index_type)
# Filebeat's built-in ilm does not contain delete phase. Below will
# add a delete phase to the existing policy.
# NOTE: Due to compatibility, should put policy only when policy exists,
# but no method to check existence. So use try-except to do the trick.
- for filebeat_name in ('filebeat-7.7.1', 'filebeat-7.0.1'):
- try:
- es.ilm.get_lifecycle(policy=filebeat_name)
- except exceptions.NotFoundError:
- pass
- else:
- put_ilm(es, filebeat_name, hot_age='1d')
+ filebeat_name = 'filebeat'
+ try:
+ es.ilm.get_lifecycle(policy=filebeat_name)
+ except exceptions.NotFoundError:
+ pass
+ else:
+ put_ilm(es, filebeat_name, hot_age='1d')
# Filebeat template and indices should be deployed during deployment.
diff --git a/web_console_v2/api/fedlearner_webconsole/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/BUILD.bazel
new file mode 100644
index 000000000..66220d6a0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/BUILD.bazel
@@ -0,0 +1,157 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "db_lib",
+ srcs = ["db.py"],
+ imports = [".."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "@common_pymysql//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "db_lib_test",
+ size = "small",
+ srcs = [
+ "db_test.py",
+ ],
+ imports = [".."],
+ main = "db_test.py",
+ deps = [
+ ":db_lib",
+ ],
+)
+
+py_library(
+ name = "initial_db_lib",
+ srcs = [
+ "initial_db.py",
+ ],
+ data = [
+ "//web_console_v2/api/fedlearner_webconsole/sys_preset_templates",
+ ],
+ imports = [".."],
+ deps = [
+ ":db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "initial_db_lib_test",
+ size = "small",
+ srcs = [
+ "initial_db_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/fedlearner_webconsole/sys_preset_templates",
+ ],
+ imports = [".."],
+ main = "initial_db_test.py",
+ deps = [
+ ":db_lib",
+ ":initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "exceptions_lib",
+ srcs = ["exceptions.py"],
+ imports = [".."],
+ deps = ["@common_flask//:pkg"],
+)
+
+py_test(
+ name = "exceptions_lib_test",
+ size = "small",
+ srcs = [
+ "exceptions_test.py",
+ ],
+ imports = [".."],
+ main = "exceptions_test.py",
+ deps = [
+ ":exceptions_lib",
+ ],
+)
+
+py_library(
+ name = "app_lib",
+ srcs = [
+ "app.py",
+ ],
+ imports = [".."],
+ deps = [
+ ":db_lib",
+ ":exceptions_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api:logging_config_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/cleanup:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/debug:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/e2e:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/file:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_watcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:middlewares_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/scheduler:scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/serving:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/sparkapp:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:swagger_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:apis_lib",
+ "@common_apispec_webframeworks//:pkg",
+ "@common_flasgger//:pkg",
+ "@common_flask//:pkg",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_sqlalchemy//:pkg",
+ "@common_webargs//:pkg",
+ "@common_werkzeug//:pkg",
+ ],
+)
+
+py_test(
+ name = "app_test",
+ size = "medium",
+ srcs = ["app_test.py"],
+ imports = ["../.."],
+ main = "app_test.py",
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/__init__.py b/web_console_v2/api/fedlearner_webconsole/__init__.py
deleted file mode 100644
index cd7504799..000000000
--- a/web_console_v2/api/fedlearner_webconsole/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-
-from fedlearner_webconsole import auth
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/algorithm/BUILD.bazel
new file mode 100644
index 000000000..ecd438982
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/BUILD.bazel
@@ -0,0 +1,195 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ # TODO(gezhengqiang): tunes the perf
+ size = "medium",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "service_lib",
+ srcs = ["service.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "service_lib_test",
+ srcs = [
+ "service_test.py",
+ ],
+ imports = ["../.."],
+ main = "service_test.py",
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "fetcher_lib",
+ srcs = ["fetcher.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm/transmit",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:resource_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "fetcher_lib_test",
+ srcs = [
+ "fetcher_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data/algorithm",
+ ],
+ imports = ["../.."],
+ main = "fetcher_test.py",
+ deps = [
+ ":fetcher_lib",
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm/transmit",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "@common_python_slugify//:pkg",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":fetcher_lib",
+ ":models_lib",
+ ":service_lib",
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms:preset_algorithm_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:resource_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:sorting_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask//:pkg",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_sqlalchemy//:pkg",
+ "@common_werkzeug//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "large",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/__init__.py b/web_console_v2/api/fedlearner_webconsole/algorithm/__init__.py
similarity index 100%
rename from web_console_v2/api/fedlearner_webconsole/proto/__init__.py
rename to web_console_v2/api/fedlearner_webconsole/algorithm/__init__.py
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/apis.py b/web_console_v2/api/fedlearner_webconsole/algorithm/apis.py
new file mode 100644
index 000000000..cd7a5b869
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/apis.py
@@ -0,0 +1,1720 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import json
+import tempfile
+import grpc
+from flask import request, send_file
+from typing import Optional
+from http import HTTPStatus
+from envs import Envs
+from sqlalchemy import Column
+from sqlalchemy.orm import Session
+from sqlalchemy.sql.elements import ColumnElement
+from werkzeug.utils import secure_filename
+from flask_restful import Resource
+from google.protobuf.json_format import ParseDict
+from marshmallow import Schema, post_load, fields, validate
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required, input_validator
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.file_tree import FileTreeBuilder
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder, SimpleExpression
+from fedlearner_webconsole.utils.sorting import SorterBuilder, SortExpression, parse_expression
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmParameter
+from fedlearner_webconsole.proto.filtering_pb2 import FilterOp
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.algorithm.utils import algorithm_project_path, algorithm_path, check_algorithm_file
+from fedlearner_webconsole.algorithm.preset_algorithms.preset_algorithm_service \
+ import create_algorithm_if_not_exists
+from fedlearner_webconsole.algorithm.service import AlgorithmProjectService, PendingAlgorithmService, AlgorithmService
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.utils.decorators.pp_flask import use_args, use_kwargs
+from fedlearner_webconsole.utils.flask_utils import make_flask_response, get_current_user, FilterExpField
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, ReleaseStatus, Source, AlgorithmType, \
+ PendingAlgorithm, normalize_path
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.exceptions import NoAccessException, NotFoundException, InvalidArgumentException, \
+ UnauthorizedException, ResourceConflictException
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.rpc.v2.resource_service_client import ResourceServiceClient
+
+file_manager = FileManager()
+file_operator = FileOperator()
+
+
+class UploadAlgorithmFile(Schema):
+ path = fields.Str(required=True)
+ filename = fields.Str(required=True)
+ is_directory = fields.Boolean(required=False, load_default=False)
+ file = fields.Raw(required=False, type='file', load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ return data
+
+
+def _validate_parameter(parameter):
+ try:
+ ParseDict(json.loads(parameter), AlgorithmParameter())
+ except: # pylint: disable=bare-except
+ return False
+ return True
+
+
+class CreateAlgorithmProjectParams(Schema):
+ name = fields.Str(required=True)
+ type = fields.Str(required=True, validate=validate.OneOf([a.name for a in AlgorithmType]))
+ parameter = fields.Str(required=False, load_default='{}', validate=_validate_parameter)
+ comment = fields.Str(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['parameter'] = ParseDict(json.loads(data['parameter']), AlgorithmParameter())
+ return data
+
+
+class GetAlgorithmProjectParams(Schema):
+ name = fields.Str(required=False, load_default=None)
+ sources = fields.List(fields.Str(required=False,
+ load_default=None,
+ validate=validate.OneOf(
+ [Source.PRESET.name, Source.USER.name, Source.THIRD_PARTY.name])),
+ load_default=None)
+ type = fields.Str(required=False, load_default=None, validate=validate.OneOf([a.name for a in AlgorithmType]))
+ keyword = fields.Str(required=False, load_default=None)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
+ filter_exp = FilterExpField(data_key='filter', required=False, load_default=None)
+ sorter_exp = fields.String(data_key='order_by', required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ return data
+
+
+class PatchAlgorithmProjectParams(Schema):
+ parameter = fields.Dict(required=False, load_default=None)
+ comment = fields.Str(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ return data
+
+
+def _get_project(project_id: int, session: Session) -> Project:
+ project = session.query(Project).get(project_id)
+ if project is None:
+ raise NotFoundException(f'project {project_id} is not found')
+ return project
+
+
+def _get_participant(participant_id: int, session: Session) -> Participant:
+ participant = session.query(Participant).get(participant_id)
+ if participant is None:
+ raise NotFoundException(f'participant {participant_id} is not found')
+ return participant
+
+
+def _get_algorithm(algo_id: int, session: Session, project_id: Optional[int] = None) -> Algorithm:
+ if project_id:
+ algo = session.query(Algorithm).filter_by(id=algo_id, project_id=project_id).first()
+ else:
+ algo = session.query(Algorithm).get(algo_id)
+ if algo is None:
+ raise NotFoundException(f'algorithm {algo_id} is not found')
+ return algo
+
+
+def _get_algorithm_project(algo_project_id: int,
+ session: Session,
+ project_id: Optional[int] = None) -> AlgorithmProject:
+ if project_id:
+ algo_project = session.query(AlgorithmProject).filter_by(id=algo_project_id, project_id=project_id).first()
+ else:
+ algo_project = session.query(AlgorithmProject).get(algo_project_id)
+ if algo_project is None:
+ raise NotFoundException(f'algorithm project {algo_project_id} is not found')
+ return algo_project
+
+
+def _get_pending_algorithm(pending_algorithm_id: int, session: Session) -> PendingAlgorithm:
+ pending_algo = session.query(PendingAlgorithm).get(pending_algorithm_id)
+ if pending_algo is None:
+ raise NotFoundException(f'pending algorithm {pending_algorithm_id} is not found')
+ return pending_algo
+
+
+class AlgorithmApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'download': fields.Bool(required=False, load_default=False)}, location='query')
+ def get(self, download: Optional[bool], algo_id: int):
+ """Get the algorithm by id
+ ---
+ tags:
+ - algorithm
+ description: get the algorithm by id
+ parameters:
+ - in: path
+ name: algo_id
+ schema:
+ type: integer
+ - in: query
+ name: download
+ schema:
+ type: boolean
+ responses:
+ 200:
+ description: detail of the algorithm
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ with db.session_scope() as session:
+ algo = _get_algorithm(algo_id, session)
+ # TODO(gezhengqiang): split download out for swagger
+ if not download:
+ return make_flask_response(algo.to_proto())
+ files = file_manager.ls(algo.path, include_directory=True)
+ if len(files) == 0:
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+ with tempfile.NamedTemporaryFile(suffix='.tar') as temp_file:
+ file_operator.archive_to([file.path for file in files], temp_file.name)
+ target_file_name = os.path.join(os.path.dirname(temp_file.name), f'{algo.name}.tar')
+ file_manager.copy(temp_file.name, target_file_name)
+ return send_file(filename_or_fp=target_file_name, mimetype='application/x-tar', as_attachment=True)
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM, op_type=Event.OperationType.DELETE)
+ def delete(self, algo_id: int):
+ """Delete the model
+ ---
+ tags:
+ - algorithm
+        description: delete the algorithm
+ parameters:
+ - in: path
+ name: algo_id
+ schema:
+ type: integer
+ responses:
+ 204:
+                description: delete the algorithm successfully
+ """
+ with db.session_scope() as session:
+ algo = _get_algorithm(algo_id, session)
+ AlgorithmService(session).delete(algo)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM, op_type=Event.OperationType.UPDATE)
+ @use_kwargs({'comment': fields.Str(required=False, load_default=None)}, location='json')
+ def patch(self, comment: Optional[str], algo_id: int):
+ """Update an algorithm
+ ---
+ tags:
+ - algorithm
+ description: update an algorithm
+ parameters:
+ - in: path
+ name: algo_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ comment:
+ type: string
+ responses:
+ 200:
+ description: detail of the algorithm
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ with db.session_scope() as session:
+ algo = _get_algorithm(algo_id, session)
+ if comment:
+ algo.comment = comment
+ session.commit()
+ return make_flask_response(algo.to_proto())
+
+
+class AlgorithmsApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'algo_project_id': fields.Integer(required=False, load_default=None)}, location='query')
+ def get(self, project_id: int, algo_project_id: Optional[int]):
+ """Get the algorithms by algo_project_id
+ ---
+ tags:
+ - algorithm
+ description: get the algorithms by algo_project_id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: query
+ name: algo_project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: detail of the algorithms
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ with db.session_scope() as session:
+ query = session.query(Algorithm)
+ if project_id: # It means not to filter projects when project_id is 0
+ query = query.filter_by(project_id=project_id)
+ if algo_project_id:
+ query = query.filter_by(algorithm_project_id=algo_project_id)
+ query = query.order_by(Algorithm.created_at.desc())
+ algos = query.all()
+ results = [algo.to_proto() for algo in algos]
+ return make_flask_response(results)
+
+
+class AlgorithmTreeApi(Resource):
+
+ @credentials_required
+ def get(self, algo_id: int):
+ """Get the algorithm tree
+ ---
+ tags:
+ - algorithm
+ description: get the algorithm tree
+ parameters:
+ - in: path
+ name: algo_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: the file tree of the algorithm
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ name: FileTreeNode
+ type: object
+ properties:
+ filename:
+ type: string
+ path:
+ type: string
+ size:
+ type: integer
+ mtime:
+ type: integer
+ is_directory:
+ type: boolean
+ files:
+ type: array
+ items:
+ type: object
+ description: FileTreeNode
+ """
+ with db.session_scope() as session:
+ algo = _get_algorithm(algo_id, session)
+ # relative path is used in returned file tree
+ file_trees = FileTreeBuilder(algo.path, relpath=True).build()
+ return make_flask_response(file_trees)
+
+
+class AlgorithmFilesApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'path': fields.Str(required=True)}, location='query')
+ def get(self, path: str, algo_id: int):
+ """Get the algorithm file
+ ---
+ tags:
+ - algorithm
+ description: get the algorithm file
+ parameters:
+ - in: path
+ name: algo_id
+ schema:
+ type: integer
+ - in: query
+ name: path
+ schema:
+ type: string
+ responses:
+ 200:
+ description: content and path of the algorithm file
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ content:
+ type: string
+ path:
+ type: string
+ 400:
+ description: error exists when reading the file
+ 401:
+ description: unauthorized path under the algorithm
+ """
+ with db.session_scope() as session:
+ algo = _get_algorithm(algo_id, session)
+ path = normalize_path(os.path.join(algo.path, path))
+ if not algo.is_path_accessible(path):
+ raise UnauthorizedException(f'Unauthorized path {path} under algorithm {algo_id}')
+ try:
+ text = file_manager.read(path)
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ relpath = os.path.relpath(path, algo.path)
+ return make_flask_response({'content': text, 'path': relpath})
+
+
+def _build_release_status_query(exp: SimpleExpression) -> ColumnElement:
+ col: Column = getattr(AlgorithmProject, '_release_status')
+ return col.in_(exp.list_value.string_list)
+
+
+class AlgorithmProjectsApi(Resource):
+
+ FILTER_FIELDS = {
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'release_status': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: _build_release_status_query}),
+ 'type': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: None}),
+ }
+
+ SORTER_FIELDS = ['created_at', 'updated_at']
+
+ def __init__(self):
+ self._filter_builder = FilterBuilder(model_class=AlgorithmProject, supported_fields=self.FILTER_FIELDS)
+ self._sorter_builder = SorterBuilder(model_class=AlgorithmProject, supported_fields=self.SORTER_FIELDS)
+
+ @credentials_required
+ @use_args(GetAlgorithmProjectParams(), location='query')
+ def get(self, params: dict, project_id: int):
+ """Get the list of the algorithm project
+ ---
+ tags:
+ - algorithm
+ description: get the list of the algorithm project
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: query
+ name: name
+ schema:
+ type: string
+ - in: query
+ name: sources
+ schema:
+ type: array
+ items:
+ type: string
+ - in: query
+ name: type
+ schema:
+ type: string
+ - in: query
+ name: keyword
+ schema:
+ type: string
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: filter
+ schema:
+ type: string
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of the algorithm projects
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmProjectPb'
+ """
+ with db.session_scope() as session:
+ query = session.query(AlgorithmProject)
+ if params['name']:
+ query = query.filter_by(name=params['name'])
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ if params['type']:
+ query = query.filter_by(type=AlgorithmType[params['type']])
+ if params['sources']:
+ sources = [Source[n] for n in params['sources']]
+ query = query.filter(AlgorithmProject.source.in_(sources))
+ if params['keyword']:
+ query = query.filter(AlgorithmProject.name.like(f'%{params["keyword"]}%'))
+ if params['filter_exp']:
+ try:
+ query = self._filter_builder.build_query(query, params['filter_exp'])
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ try:
+ if params['sorter_exp'] is not None:
+ sorter_exp = parse_expression(params['sorter_exp'])
+ else:
+ sorter_exp = SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ except ValueError as e:
+                raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ pagination = paginate(query, params['page'], params['page_size'])
+ data = [d.to_proto() for d in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+ @input_validator
+ @credentials_required
+ @use_args(CreateAlgorithmProjectParams(), location='form')
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.CREATE)
+ def post(self, param: dict, project_id: int):
+ """Create an algorithm project
+ ---
+ tags:
+ - algorithm
+ description: create an algorithm project
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/CreateAlgorithmProjectParams'
+ responses:
+ 201:
+ description: detail of the algorithm project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmProjectPb'
+ 400:
+ description: the project does not exist
+ 403:
+ description: the algorithm project is forbidden to create
+ 409:
+ description: the algorithm project already exists
+ """
+ # TODO(hangweiqiang): clear the file if error in subsequent operation
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value and param['type'] == AlgorithmType.TRUSTED_COMPUTING.name:
+ raise NoAccessException(message='trusted computing is not enabled')
+ file = None
+ if 'file' in request.files:
+ file = request.files['file']
+ user = get_current_user()
+ path = algorithm_project_path(Envs.STORAGE_ROOT, param['name'])
+ with db.session_scope() as session, check_algorithm_file(path):
+ project = session.query(Project).get(project_id)
+ if project is None:
+                raise InvalidArgumentException(details=f'project {project_id} does not exist')
+ algo_project = session.query(AlgorithmProject).filter_by(name=param['name'],
+ source=Source.USER,
+ project_id=project_id).first()
+ if algo_project is not None:
+ raise ResourceConflictException(message=f'algorithm project {param["name"]} already exists')
+ file_manager.mkdir(path)
+ algo_project = AlgorithmProjectService(session).create_algorithm_project(
+ name=param['name'],
+ project_id=project_id,
+ algorithm_type=AlgorithmType[param['type']],
+ username=user.username,
+ parameter=param['parameter'],
+ comment=param['comment'],
+ file=file,
+ path=path)
+ session.commit()
+ return make_flask_response(algo_project.to_proto(), status=HTTPStatus.CREATED)
+
+
+class AlgorithmProjectApi(Resource):
+
+ @credentials_required
+ def get(self, algo_project_id: int):
+ """Get the algorithm project by id
+ ---
+ tags:
+ - algorithm
+ description: get the algorithm project by id
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: detail of the algorithm project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmProjectPb'
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ result = algo_project.to_proto()
+ return make_flask_response(result)
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.UPDATE)
+ @use_args(PatchAlgorithmProjectParams(), location='json')
+ def patch(self, params: dict, algo_project_id: int):
+ """Update the algorithm project
+ ---
+ tags:
+ - algorithm
+ description: update the algorithm project
+ parameters:
+ - in: path
+          name: algo_project_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/PatchAlgorithmProjectParams'
+ responses:
+ 200:
+ description: detail of the algorithm project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmProjectPb'
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ if algo_project.source == Source.THIRD_PARTY:
+ raise NoAccessException(message='algo_project from THIRD_PARTY can not be edited')
+ if params['comment']:
+ algo_project.comment = params['comment']
+ if params['parameter']:
+ parameter = ParseDict(params['parameter'], AlgorithmParameter())
+ algo_project.set_parameter(parameter)
+ algo_project.release_status = ReleaseStatus.UNRELEASED
+ session.commit()
+ return make_flask_response(algo_project.to_proto())
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.DELETE)
+ def delete(self, algo_project_id: int):
+ """Delete the algorithm project
+ ---
+ tags:
+ - algorithm
+ description: delete the algorithm project
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ responses:
+ 204:
+ description: delete the algorithm project successfully
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ AlgorithmProjectService(session).delete(algo_project)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class AlgorithmProjectTreeApi(Resource):
+
+ @credentials_required
+ def get(self, algo_project_id: int):
+ """Get the algorithm project file tree
+ ---
+ tags:
+ - algorithm
+ description: get the algorithm project file tree
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: the file tree of the algorithm project
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ name: FileTreeNode
+ type: object
+ properties:
+ filename:
+ type: string
+ path:
+ type: string
+ size:
+ type: integer
+ mtime:
+ type: integer
+ is_directory:
+ type: boolean
+ files:
+ type: array
+ items:
+ type: object
+ description: FileTreeNode
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ # relative path is used in returned file tree
+ # TODO(gezhengqiang): change to return empty array
+ if algo_project.path is None:
+ return make_flask_response([])
+ file_trees = FileTreeBuilder(algo_project.path, relpath=True).build()
+ return make_flask_response(file_trees)
+
+
+class AlgorithmProjectFilesApi(Resource):
+
+ @staticmethod
+ def _mark_algorithm_project_unreleased(algo_project_id: int):
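+        # any change to the project files invalidates the released version, so the project is marked UNRELEASED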
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ algo_project.release_status = ReleaseStatus.UNRELEASED
+ session.commit()
+
+ @credentials_required
+ @use_kwargs({'path': fields.Str(required=True)}, location='query')
+ def get(self, path: str, algo_project_id: int):
+ """Get the files of the algorithm project
+ ---
+ tags:
+ - algorithm
+ description: get the files of the algorithm project
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ - in: query
+ name: path
+ schema:
+ type: string
+ responses:
+ 200:
+ description: content and path of the algorithm file
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ content:
+ type: string
+ path:
+ type: string
+ 400:
+ description: error exists when reading the file
+ 401:
+            description: unauthorized path under the algorithm project
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
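+            # join and normalize the user-supplied path, then make sure it stays inside the project directory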
+ path = normalize_path(os.path.join(algo_project.path, path))
+ if not algo_project.is_path_accessible(path):
+ raise UnauthorizedException(f'Unauthorized path {path} under algorithm {algo_project_id}')
+ try:
+ content = file_manager.read(path)
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ relpath = os.path.relpath(path, algo_project.path)
+ return make_flask_response({'content': content, 'path': relpath})
+
+ @credentials_required
+ @use_kwargs({'path': fields.Str(required=True), 'filename': fields.Str(required=True)}, location='form')
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.UPDATE)
+ def post(self, path: str, filename: str, algo_project_id: int):
+ """Upload the algorithm project file
+ ---
+ tags:
+ - algorithm
+ description: upload the algorithm project file
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ - in: form
+ name: path
+ schema:
+ type: string
+ - in: form
+ name: filename
+ schema:
+ type: string
+ responses:
+ 200:
+ description: filename and path of the algorithm project file
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ path:
+ type: string
+ filename:
+ type: string
+ 400:
+            description: file does not exist or is not a directory
+ 401:
+ description: unauthorized path under the algorithm project
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ # TODO(hangweiqiang): check algorithm file accessibility in decorator
+ path = normalize_path(os.path.join(algo_project.path, path))
+ if not algo_project.is_path_accessible(path):
+ raise UnauthorizedException(f'Unauthorized path {path} under algorithm project {algo_project_id}')
+ if not file_manager.isdir(path):
+                raise InvalidArgumentException(details=f'file {str(path)} does not exist or is not a directory')
+ secure_file_name = secure_filename(filename)
+ file_path = normalize_path(os.path.join(path, secure_file_name))
+ file = request.files['file']
+ file_content = file.read()
+ file_manager.write(file_path, file_content)
+ self._mark_algorithm_project_unreleased(algo_project_id)
+ relpath = os.path.relpath(path, algo_project.path)
+ return make_flask_response({'path': relpath, 'filename': secure_file_name})
+
+ @credentials_required
+ @use_args(UploadAlgorithmFile(), location='form')
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.UPDATE)
+ def put(self, param: dict, algo_project_id: int):
+        """Put the algorithm project file
+ ---
+ tags:
+ - algorithm
+ description: put the algorithm project file
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/UploadAlgorithmFile'
+ responses:
+ 200:
+ description: content, path and filename of the algorithm project
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ content:
+ type: string
+ path:
+ type: string
+ filename:
+ type: string
+ 400:
+            description: file does not exist, is not a directory, or the file path already exists
+ 401:
+ description: unauthorized path under the algorithm project
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ path = normalize_path(os.path.join(algo_project.path, param['path']))
+ if not algo_project.is_path_accessible(path):
+ raise UnauthorizedException(f'Unauthorized path {path} under algorithm project {algo_project_id}')
+ if not file_manager.isdir(path):
+                raise InvalidArgumentException(
+                    details=f'file {str(param["path"])} does not exist or is not a directory')
+ secure_file_name = secure_filename(param['filename'])
+ file_path = os.path.join(path, secure_file_name)
+ file_content = None
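+            # 'is_directory' creates an empty directory; otherwise the uploaded content is written to a file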
+ if param['is_directory']:
+ if file_manager.exists(file_path):
+ raise InvalidArgumentException(details=f'file {str(param["path"])} already exists')
+ file_manager.mkdir(file_path)
+ else:
+ file_content = param['file']
+ file_manager.write(file_path, file_content or '')
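+            # uploaded content may be bytes; decode it so it can be returned in the JSON response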
+ if isinstance(file_content, bytes):
+ file_content = file_content.decode('utf-8')
+ self._mark_algorithm_project_unreleased(algo_project_id)
+ relpath = os.path.relpath(path, algo_project.path)
+ return make_flask_response({'content': file_content, 'path': relpath, 'filename': secure_file_name})
+
+ @credentials_required
+ @use_kwargs({'path': fields.Str(required=True)}, location='query')
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.UPDATE)
+ def delete(self, path: str, algo_project_id: int):
+ """Delete the algorithm project file
+ ---
+ tags:
+ - algorithm
+ description: delete the algorithm project file
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ - in: query
+ name: path
+ schema:
+ type: string
+ responses:
+ 204:
+ description: delete the algorithm project file successfully
+ 400:
+ description: error exists when removing the file
+ 401:
+ description: unauthorized path under the algorithm project
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ path = normalize_path(os.path.join(algo_project.path, path))
+ if not algo_project.is_path_accessible(path):
+ raise UnauthorizedException(f'Unauthorized path {path} under algorithm project {algo_project_id}')
+ try:
+ file_manager.remove(path)
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ self._mark_algorithm_project_unreleased(algo_project_id)
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+ @credentials_required
+ @use_kwargs({'path': fields.Str(required=True), 'dest': fields.Str(required=True)}, location='json')
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.UPDATE)
+ def patch(self, path: str, dest: str, algo_project_id: int):
+        """Move the algorithm project file
+ ---
+ tags:
+ - algorithm
+        description: move the algorithm project file from path to dest
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ path:
+ type: string
+ dest:
+ type: string
+ responses:
+ 204:
+ description: patch the algorithm project file successfully
+          401:
+            description: unauthorized path or dest under the algorithm project
+ """
+ with db.session_scope() as session:
+ algo_project = _get_algorithm_project(algo_project_id, session)
+ path = normalize_path(os.path.join(algo_project.path, path))
+ dest = normalize_path(os.path.join(algo_project.path, dest))
+ if not algo_project.is_path_accessible(path):
+ raise UnauthorizedException(f'Unauthorized path {path} under algorithm project {algo_project_id}')
+ if not algo_project.is_path_accessible(dest):
+ raise UnauthorizedException(f'Unauthorized dest {dest} under algorithm project {algo_project_id}')
+ try:
+ file_manager.rename(path, dest)
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ self._mark_algorithm_project_unreleased(algo_project_id)
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class ParticipantAlgorithmProjectsApi(Resource):
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'filter_exp': FilterExpField(data_key='filter', required=False, load_default=None),
+ 'sorter_exp': fields.String(data_key='order_by', required=False, load_default=None)
+ },
+ location='query')
+ def get(self, project_id: int, participant_id: int, filter_exp: Optional[str], sorter_exp: Optional[str]):
+        """Get the list of the participant algorithm projects
+ ---
+ tags:
+ - algorithm
+        description: get the list of the participant algorithm projects
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: query
+ name: filter_exp
+ schema:
+ type: string
+ - in: query
+ name: sorter_exp
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of the participant algorithm projects
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmProjectPb'
+ """
+ with db.session_scope() as session:
+ project = _get_project(project_id, session)
+ participants = project.participants
+ if participant_id:
+ participants = [_get_participant(participant_id, session)]
+ algorithm_projects = []
+ for participant in participants:
+ try:
+ client = ResourceServiceClient.from_project_and_participant(participant.domain_name, project.name)
+ participant_algorithm_projects = client.list_algorithm_projects(
+ filter_exp=filter_exp).algorithm_projects
+ for algo_project in participant_algorithm_projects:
+ algo_project.participant_id = participant.id
+ algorithm_projects.extend(participant_algorithm_projects)
+ except grpc.RpcError as e:
+ logging.warning(f'[algorithm] failed to get {participant.type} participant {participant.id}\'s '
+ f'algorithm projects with grpc code {e.code()} and details {e.details()}')
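+        # results from all participants are merged locally, so sorting is done in memory
+        # (defaults to created_at descending when no order_by is given)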
+ if len(algorithm_projects) != 0:
+ field = 'created_at'
+ is_asc = False
+ if sorter_exp:
+ sorter_exp = parse_expression(sorter_exp)
+ field = sorter_exp.field
+ is_asc = sorter_exp.is_asc
+ try:
+ algorithm_projects = sorted(algorithm_projects, key=lambda x: getattr(x, field), reverse=not is_asc)
+ except AttributeError as e:
+ raise InvalidArgumentException(details=f'Invalid sort attribute: {str(e)}') from e
+ return make_flask_response(algorithm_projects)
+
+
+class ParticipantAlgorithmProjectApi(Resource):
+
+ def get(self, project_id: int, participant_id: int, algorithm_project_uuid: str):
+ """Get the participant algorithm project by algorithm_project_uuid
+ ---
+ tags:
+ - algorithm
+ description: get the participant algorithm project by algorithm_project_uuid
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: path
+ name: algorithm_project_uuid
+ schema:
+ type: string
+ responses:
+ 200:
+ description: detail of the participant algorithm project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmProjectPb'
+
+ """
+ with db.session_scope() as session:
+ project = _get_project(project_id, session)
+ participant = _get_participant(participant_id, session)
+ client = ResourceServiceClient.from_project_and_participant(participant.domain_name, project.name)
+ algorithm_project = client.get_algorithm_project(algorithm_project_uuid=algorithm_project_uuid)
+ return make_flask_response(algorithm_project)
+
+
+class ParticipantAlgorithmsApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'algorithm_project_uuid': fields.Str(required=True)}, location='query')
+ def get(self, project_id: int, participant_id: int, algorithm_project_uuid: str):
+ """Get the participant algorithms by algorithm_project_uuid
+ ---
+ tags:
+ - algorithm
+ description: get the participant algorithms by algorithm_project_uuid
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: query
+ name: algorithm_project_uuid
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of the participant algorithms
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ with db.session_scope() as session:
+ project = _get_project(project_id, session)
+ participant = _get_participant(participant_id, session)
+ client = ResourceServiceClient.from_project_and_participant(participant.domain_name, project.name)
+ participant_algorithms = client.list_algorithms(algorithm_project_uuid).algorithms
+ for algo in participant_algorithms:
+ algo.participant_id = participant_id
+ algorithms = sorted(participant_algorithms, key=lambda x: x.created_at, reverse=True)
+ return make_flask_response(algorithms)
+
+
+class ParticipantAlgorithmApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, participant_id: int, algorithm_uuid: str):
+ """Get the participant algorithm by algorithm_uuid
+ ---
+ tags:
+ - algorithm
+ description: get the participant algorithm by algorithm_uuid
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: path
+ name: algorithm_uuid
+ schema:
+ type: string
+ responses:
+ 200:
+ description: detail of the participant algorithm
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+
+ """
+ with db.session_scope() as session:
+ project = _get_project(project_id, session)
+ participant = _get_participant(participant_id, session)
+ client = ResourceServiceClient.from_project_and_participant(participant.domain_name, project.name)
+ algorithm = client.get_algorithm(algorithm_uuid)
+ return make_flask_response(algorithm)
+
+
+class ParticipantAlgorithmTreeApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, participant_id: int, algorithm_uuid: str):
+ """Get the participant algorithm tree
+ ---
+ tags:
+ - algorithm
+ description: get the participant algorithm tree
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: path
+ name: algorithm_uuid
+ schema:
+ type: string
+ responses:
+ 200:
+ description: the file tree of the participant algorithm
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ name: FileTreeNode
+ type: object
+ properties:
+ filename:
+ type: string
+ path:
+ type: string
+ size:
+ type: integer
+ mtime:
+ type: integer
+ is_directory:
+ type: boolean
+ files:
+ type: array
+ items:
+ type: object
+ description: FileTreeNode
+ """
+ algorithm = AlgorithmFetcher(project_id=project_id).get_algorithm_from_participant(
+ algorithm_uuid=algorithm_uuid, participant_id=participant_id)
+
+ # relative path is used in returned file tree
+ file_trees = FileTreeBuilder(algorithm.path, relpath=True).build()
+ return make_flask_response(file_trees)
+
+
+class ParticipantAlgorithmFilesApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'path': fields.Str(required=True)}, location='query')
+ def get(self, project_id: int, participant_id: int, algorithm_uuid: str, path: str):
+ """Get the algorithm file
+ ---
+ tags:
+ - algorithm
+ description: get the algorithm file
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: path
+ name: algorithm_uuid
+ schema:
+ type: string
+ - in: query
+ name: path
+ schema:
+ type: string
+ responses:
+ 200:
+ description: content and path of the participant algorithm file
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ content:
+ type: string
+ path:
+ type: string
+ 400:
+ description: error exists when reading the file
+ 401:
+ description: unauthorized path under the algorithm
+ """
+ algorithm = AlgorithmFetcher(project_id=project_id).get_algorithm_from_participant(
+ algorithm_uuid=algorithm_uuid, participant_id=participant_id)
+ path = normalize_path(os.path.join(algorithm.path, path))
+ if not normalize_path(path).startswith(algorithm.path):
+ raise UnauthorizedException(f'Unauthorized path {path} under the participant algorithm {algorithm_uuid}')
+ try:
+ text = file_manager.read(path)
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ relpath = os.path.relpath(path, algorithm.path)
+ return make_flask_response({'content': text, 'path': relpath})
+
+
+class FetchAlgorithmApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, algorithm_uuid: str):
+ """Get the algorithm by uuid
+ ---
+ tags:
+ - algorithm
+ description: get the algorithm by uuid, whether it is from your own side or from a participant
+ parameters:
+ - in: path
+          name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: algorithm_uuid
+ schema:
+ type: string
+ responses:
+ 200:
+ description: detail of the algorithm
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ algorithm = AlgorithmFetcher(project_id=project_id).get_algorithm(uuid=algorithm_uuid)
+ return make_flask_response(algorithm)
+
+
+class UpdatePresetAlgorithmApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @emits_event(resource_type=Event.ResourceType.PRESET_ALGORITHM, op_type=Event.OperationType.UPDATE)
+ def post(self):
+ """Update the preset algorithm
+ ---
+ tags:
+ - algorithm
+ description: update the preset algorithm
+ responses:
+ 200:
+ description: detail of the preset algorithm projects
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmProjectPb'
+ """
+ create_algorithm_if_not_exists()
+ with db.session_scope() as session:
+ algo_projects = session.query(AlgorithmProject).filter_by(source=Source.PRESET).all()
+ results = [project.to_proto() for project in algo_projects]
+ return make_flask_response(results)
+
+
+class ReleaseAlgorithmApi(Resource):
+
+ @input_validator
+ @credentials_required
+ @use_kwargs({'comment': fields.Str(required=False, load_default=None, location='body')})
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM, op_type=Event.OperationType.CREATE)
+ def post(self, comment: Optional[str], algo_project_id: int):
+ """Release the algorithm
+ ---
+ tags:
+ - algorithm
+ description: release the algorithm
+ parameters:
+ - in: path
+ name: algo_project_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ comment:
+ type: string
+ responses:
+ 200:
+ description: detail of the algorithm
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ user = get_current_user()
+ with db.session_scope() as session:
+ algorithm_project = _get_algorithm_project(algo_project_id, session)
+ if algorithm_project.source == Source.THIRD_PARTY:
+ raise NoAccessException(message='algo_project from THIRD_PARTY can not be released')
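+            # releasing creates a new algorithm version whose files live under a per-version path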
+ version = algorithm_project.latest_version + 1
+ path = algorithm_path(Envs.STORAGE_ROOT, algorithm_project.name, version)
+ with check_algorithm_file(path):
+ algo = AlgorithmProjectService(session).release_algorithm(algorithm_project=algorithm_project,
+ username=user.username,
+ comment=comment,
+ path=path)
+ session.commit()
+ return make_flask_response(algo.to_proto())
+
+
+class PublishAlgorithmApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM, op_type=Event.OperationType.UPDATE)
+ def post(self, algorithm_id: int, project_id: int):
+ """Publish the algorithm
+ ---
+ tags:
+ - algorithm
+ description: publish the algorithm
+ parameters:
+ - in: path
+ name: algorithm_id
+ schema:
+ type: integer
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: detail of the algorithm
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ with db.session_scope() as session:
+ _get_algorithm(algorithm_id, session, project_id)
+ algorithm = AlgorithmService(session).publish_algorithm(algorithm_id, project_id)
+ session.commit()
+ return make_flask_response(algorithm.to_proto())
+
+
+class UnpublishAlgorithmApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM, op_type=Event.OperationType.UPDATE)
+ def post(self, algorithm_id: int, project_id: int):
+ """Unpublish the algorithm
+ ---
+ tags:
+ - algorithm
+ description: unpublish the algorithm
+ parameters:
+ - in: path
+ name: algorithm_id
+ schema:
+ type: integer
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: detail of the algorithm
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.AlgorithmPb'
+ """
+ with db.session_scope() as session:
+ _get_algorithm(algorithm_id, session, project_id)
+ algorithm = AlgorithmService(session).unpublish_algorithm(algorithm_id, project_id)
+ session.commit()
+ return make_flask_response(algorithm.to_proto())
+
+
+class PendingAlgorithmsApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int):
+ """Get the list of the pending algorithms
+ ---
+ tags:
+ - algorithm
+ description: get the list of the pending algorithms
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: list of the pending algorithms
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.PendingAlgorithmPb'
+ """
+ with db.session_scope() as session:
+ query = session.query(PendingAlgorithm)
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ query = query.order_by(PendingAlgorithm.created_at.desc())
+ pending_algorithms = query.all()
+ results = [algo.to_proto() for algo in pending_algorithms]
+ return make_flask_response(results)
+
+
+class AcceptPendingAlgorithmApi(Resource):
+
+ @input_validator
+ @credentials_required
+ @use_kwargs({
+ 'name': fields.Str(required=True),
+ 'comment': fields.Str(required=False, load_default=None, location='body')
+ })
+ @emits_event(resource_type=Event.ResourceType.ALGORITHM_PROJECT, op_type=Event.OperationType.CREATE)
+ def post(self, name: str, comment: Optional[str], project_id: int, pending_algorithm_id: int):
+ """Accept the pending algorithm
+ ---
+ tags:
+ - algorithm
+ description: accept the pending algorithm
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: pending_algorithm_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ name:
+ type: string
+ comment:
+ type: string
+ responses:
+ 204:
+ description: accept the pending algorithm successfully
+ """
+ del project_id
+ with db.session_scope() as session:
+ pending_algo = _get_pending_algorithm(pending_algorithm_id, session)
+ algo_project = session.query(AlgorithmProject).filter_by(
+ uuid=pending_algo.algorithm_project_uuid).filter_by(source=Source.THIRD_PARTY).first()
+ user = get_current_user()
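+            # a THIRD_PARTY algorithm project is created on first acceptance; later versions reuse the same project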
+ if algo_project is None:
+ algo_project = PendingAlgorithmService(session).create_algorithm_project(pending_algorithm=pending_algo,
+ username=user.username,
+ name=name,
+ comment=comment)
+ session.flush()
+ algo_path = algorithm_path(Envs.STORAGE_ROOT, name, pending_algo.version)
+ with check_algorithm_file(algo_path):
+ pending_algo.deleted_at = now()
+ PendingAlgorithmService(session).create_algorithm(pending_algorithm=pending_algo,
+ algorithm_project_id=algo_project.id,
+ username=user.username,
+ path=algo_path,
+ comment=comment)
+ algo_project.latest_version = pending_algo.version
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class PendingAlgorithmTreeApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, pending_algo_id: int):
+ """Get the file tree of the pending algorithm
+ ---
+ tags:
+ - algorithm
+ description: get the file tree of the pending algorithm
+ parameters:
+ - in: path
+ name: pending_algo_id
+ schema:
+ type: integer
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: the file tree of the pending algorithm
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ name: FileTreeNode
+ type: object
+ properties:
+ filename:
+ type: string
+ path:
+ type: string
+ size:
+ type: integer
+ mtime:
+ type: integer
+ is_directory:
+ type: boolean
+ files:
+ type: array
+ items:
+ type: object
+ description: FileTreeNode
+ """
+ with db.session_scope() as session:
+ pending_algo = _get_pending_algorithm(pending_algo_id, session)
+ # relative path is used in returned file tree
+ file_trees = FileTreeBuilder(pending_algo.path, relpath=True).build()
+ return make_flask_response(file_trees)
+
+
+class PendingAlgorithmFilesApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'path': fields.Str(required=True)}, location='query')
+ def get(self, path: str, project_id: int, pending_algo_id: int):
+ """Get the files of the pending algorithm
+ ---
+ tags:
+ - algorithm
+ description: get the files of the pending algorithm
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: pending_algo_id
+ schema:
+ type: integer
+ - in: query
+ name: path
+ schema:
+ type: string
+ responses:
+ 200:
+ description: content and path of the pending algorithm file
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ content:
+ type: string
+ path:
+ type: string
+ 400:
+ description: error exists when reading the file
+ 401:
+ description: unauthorized path under the pending algorithm
+ """
+ with db.session_scope() as session:
+ pending_algo = _get_pending_algorithm(pending_algo_id, session)
+ path = normalize_path(os.path.join(pending_algo.path, path))
+ if not pending_algo.is_path_accessible(path):
+ raise UnauthorizedException(f'Unauthorized path {path} under pending algorithm {pending_algo_id}')
+ try:
+ text = file_manager.read(path)
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ relpath = os.path.relpath(path, pending_algo.path)
+ return make_flask_response({'content': text, 'path': relpath})
+
+
+def initialize_algorithm_apis(api):
+ # TODO(gezhengqiang): add project in the url
+    api.add_resource(AlgorithmApi, '/algorithms/<int:algo_id>')
+    api.add_resource(AlgorithmsApi, '/projects/<int:project_id>/algorithms')
+    api.add_resource(AlgorithmTreeApi, '/algorithms/<int:algo_id>/tree')
+    api.add_resource(AlgorithmFilesApi, '/algorithms/<int:algo_id>/files')
+    api.add_resource(AlgorithmProjectsApi, '/projects/<int:project_id>/algorithm_projects')
+    api.add_resource(AlgorithmProjectApi, '/algorithm_projects/<int:algo_project_id>')
+    api.add_resource(AlgorithmProjectTreeApi, '/algorithm_projects/<int:algo_project_id>/tree')
+    api.add_resource(AlgorithmProjectFilesApi, '/algorithm_projects/<int:algo_project_id>/files')
+    api.add_resource(ParticipantAlgorithmProjectsApi,
+                     '/projects/<int:project_id>/participants/<int:participant_id>/algorithm_projects')
+    api.add_resource(
+        ParticipantAlgorithmProjectApi, '/projects/<int:project_id>/participants/<int:participant_id>/'
+        'algorithm_projects/<string:algorithm_project_uuid>')
+    api.add_resource(ParticipantAlgorithmsApi,
+                     '/projects/<int:project_id>/participants/<int:participant_id>/algorithms')
+    api.add_resource(ParticipantAlgorithmApi,
+                     '/projects/<int:project_id>/participants/<int:participant_id>/algorithms/<string:algorithm_uuid>')
+    api.add_resource(
+        ParticipantAlgorithmTreeApi,
+        '/projects/<int:project_id>/participants/<int:participant_id>/algorithms/<string:algorithm_uuid>/tree')
+    api.add_resource(
+        ParticipantAlgorithmFilesApi,
+        '/projects/<int:project_id>/participants/<int:participant_id>/algorithms/<string:algorithm_uuid>/files')
+    api.add_resource(FetchAlgorithmApi, '/projects/<int:project_id>/algorithms/<string:algorithm_uuid>')
+    # TODO(gezhengqiang): algorithm project publish has been changed to release, the api will be deleted in future
+    api.add_resource(ReleaseAlgorithmApi,
+                     '/algorithm_projects/<int:algo_project_id>:publish',
+                     endpoint='algorithm_project:publish')
+    api.add_resource(ReleaseAlgorithmApi,
+                     '/algorithm_projects/<int:algo_project_id>:release',
+                     endpoint='algorithm_project:release')
+    api.add_resource(PublishAlgorithmApi, '/projects/<int:project_id>/algorithms/<int:algorithm_id>:publish')
+    api.add_resource(UnpublishAlgorithmApi, '/projects/<int:project_id>/algorithms/<int:algorithm_id>:unpublish')
+    api.add_resource(PendingAlgorithmsApi, '/projects/<int:project_id>/pending_algorithms')
+    api.add_resource(AcceptPendingAlgorithmApi,
+                     '/projects/<int:project_id>/pending_algorithms/<int:pending_algorithm_id>:accept')
+    api.add_resource(PendingAlgorithmTreeApi,
+                     '/projects/<int:project_id>/pending_algorithms/<int:pending_algo_id>/tree')
+    api.add_resource(PendingAlgorithmFilesApi,
+                     '/projects/<int:project_id>/pending_algorithms/<int:pending_algo_id>/files')
+    api.add_resource(UpdatePresetAlgorithmApi, '/preset_algorithms:update')
+ schema_manager.append(UploadAlgorithmFile)
+ schema_manager.append(CreateAlgorithmProjectParams)
+ schema_manager.append(GetAlgorithmProjectParams)
+ schema_manager.append(PatchAlgorithmProjectParams)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/apis_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/apis_test.py
new file mode 100644
index 000000000..bb1f924a1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/apis_test.py
@@ -0,0 +1,1464 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import tarfile
+import unittest
+import tempfile
+import urllib.parse
+import grpc
+
+from envs import Envs
+from io import BytesIO
+from http import HTTPStatus
+from datetime import datetime
+from pathlib import Path
+from unittest.mock import patch
+from testing.common import BaseTestCase
+from testing.rpc.client import FakeRpcError
+from google.protobuf.json_format import ParseDict
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.filtering import parse_expression
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.algorithm.transmit.sender import AlgorithmSender
+from fedlearner_webconsole.algorithm.models import (Algorithm, AlgorithmType, Source, AlgorithmProject,
+ PendingAlgorithm, ReleaseStatus, PublishStatus)
+from fedlearner_webconsole.algorithm.utils import algorithm_project_path
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmParameter, AlgorithmVariable, AlgorithmProjectPb, \
+ AlgorithmPb
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import ListAlgorithmsResponse,\
+ ListAlgorithmProjectsResponse
+from fedlearner_webconsole.flag.models import Flag
+
+
+def generate_algorithm_files():
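+    # builds a temporary algorithm directory with leader/ and follower/ folders, each containing a main.py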
+ path = tempfile.mkdtemp()
+ path = Path(path, 'e2e_test').resolve()
+ path.mkdir()
+ path.joinpath('follower').mkdir()
+ path.joinpath('follower').joinpath('main.py').touch()
+ path.joinpath('leader').mkdir()
+ file_path = path.joinpath('leader').joinpath('main.py')
+ file_path.touch()
+ file_path.write_text('import tensorflow', encoding='utf-8')
+ return str(path)
+
+
+def _generate_tar_file():
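+    # packs the generated leader/ and follower/ directories into a gzipped tar and reopens it for reading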
+ path = generate_algorithm_files()
+ tar_path = os.path.join(tempfile.mkdtemp(), 'test.tar.gz')
+ with tarfile.open(tar_path, 'w:gz') as tar:
+ tar.add(os.path.join(path, 'leader'), arcname='leader')
+ tar.add(os.path.join(path, 'follower'), arcname='follower')
+ tar = tarfile.open(tar_path, 'r') # pylint: disable=consider-using-with
+ return tar
+
+
+class AlgorithmApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ user = User(username='user')
+ session.add(user)
+ session.flush()
+ algo_project1 = AlgorithmProject(id=1, name='test-algo-project-1', uuid='test-algo-project-1-uuid')
+ algo1 = Algorithm(name='test-algo-1',
+ version=1,
+ project_id=1,
+ algorithm_project_id=1,
+ path=generate_algorithm_files(),
+ username=user.username,
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL)
+ algo1.set_parameter(AlgorithmParameter(variables=[AlgorithmVariable(name='BATCH_SIZE', value='123')]))
+ algo_project2 = AlgorithmProject(id=2, name='test-algo-project', publish_status=PublishStatus.PUBLISHED)
+ algo2 = Algorithm(name='test-algo-2',
+ algorithm_project_id=2,
+ publish_status=PublishStatus.PUBLISHED,
+ path=tempfile.mkdtemp())
+ session.add_all([algo_project1, algo_project2, algo1, algo2])
+ session.commit()
+
+ def test_get_algorithm_by_id(self):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo-1').first()
+ response = self.get_helper(f'/api/v2/algorithms/{algo.id}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.maxDiff = None
+ self.assertResponseDataEqual(response, {
+ 'name': 'test-algo-1',
+ 'project_id': 1,
+ 'status': 'UNPUBLISHED',
+ 'version': 1,
+ 'type': 'NN_VERTICAL',
+ 'source': 'PRESET',
+ 'username': 'user',
+ 'algorithm_project_id': 1,
+ 'algorithm_project_uuid': 'test-algo-project-1-uuid',
+ 'path': algo.path,
+ 'parameter': {
+ 'variables': [{
+ 'name': 'BATCH_SIZE',
+ 'value': '123',
+ 'required': False,
+ 'display_name': '',
+ 'comment': '',
+ 'value_type': 'STRING'
+ }]
+ },
+ 'participant_id': 0,
+ 'participant_name': '',
+ 'favorite': False,
+ 'comment': ''
+ },
+ ignore_fields=['id', 'uuid', 'created_at', 'updated_at', 'deleted_at'])
+
+ def test_get_with_not_found_exception(self):
+ response = self.get_helper('/api/v2/algorithms/12')
+ self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_delete_algorithm(self):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo-1').first()
+ resp = self.delete_helper(f'/api/v2/algorithms/{algo.id}')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo-1').execution_options(
+ include_deleted=True).first()
+ self.assertIsNone(algo)
+
+ def test_download_algorithm_files(self):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo-2').first()
+ resp = self.get_helper(f'/api/v2/algorithms/{algo.id}?download=true')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo-1').first()
+ resp = self.get_helper(f'/api/v2/algorithms/{algo.id}?download=true')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(resp.headers['Content-Disposition'], 'attachment; filename=test-algo-1.tar')
+ self.assertEqual(resp.headers['Content-Type'], 'application/x-tar')
+ tar = tarfile.TarFile(fileobj=BytesIO(resp.data)) # pylint: disable=consider-using-with
+ with tempfile.TemporaryDirectory() as temp_dir:
+ tar.extractall(temp_dir)
+ self.assertEqual(['follower', 'leader'], sorted(os.listdir(temp_dir)))
+ self.assertEqual(['main.py'], os.listdir(os.path.join(temp_dir, 'follower')))
+ self.assertEqual(['main.py'], os.listdir(os.path.join(temp_dir, 'leader')))
+
+ def test_patch_algorithm(self):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo-1').first()
+ resp = self.patch_helper(f'/api/v2/algorithms/{algo.id}', data={'comment': 'test edit comment'})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ algorithm = session.query(Algorithm).filter_by(id=algo.id).first()
+ self.assertEqual(algorithm.comment, 'test edit comment')
+
+
+class AlgorithmsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='test-project')
+ algo_project = AlgorithmProject(id=1, name='test-algo-project')
+ algo1 = Algorithm(name='test-algo-1', algorithm_project_id=1, project_id=1)
+ algo2 = Algorithm(name='test-algo-2', algorithm_project_id=1, project_id=1)
+ algo3 = Algorithm(name='test-algo-3', algorithm_project_id=2, project_id=1)
+ session.add_all([project, algo_project, algo1, algo2, algo3])
+ session.commit()
+
+ def test_get_algorithms_by_algo_project_id(self):
+ resp = self.get_helper('/api/v2/projects/1/algorithms?algo_project_id=1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['name'], 'test-algo-2')
+ self.assertEqual(data[1]['name'], 'test-algo-1')
+ resp = self.get_helper('/api/v2/projects/0/algorithms?algo_project_id=1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ resp = self.get_helper('/api/v2/projects/0/algorithms')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 3)
+
+
+class AlgorithmFilesApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ algo = Algorithm(name='test-algo', path=path)
+ session.add(algo)
+ session.commit()
+
+ def test_get_algorithm_tree(self):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo').first()
+ resp = self.get_helper(f'/api/v2/algorithms/{algo.id}/tree')
+ data = self.get_response_data(resp)
+ data = sorted(data, key=lambda d: d['filename'])
+ self.assertPartiallyEqual(data[1], {
+ 'filename': 'leader',
+ 'path': 'leader',
+ 'is_directory': True
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+ self.assertPartiallyEqual(data[1]['files'][0], {
+ 'filename': 'main.py',
+ 'path': 'leader/main.py',
+ 'is_directory': False
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+
+ def test_get_algorithm_files(self):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo').first()
+ resp = self.get_helper(f'/api/v2/algorithms/{algo.id}/files?path=..')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ resp = self.get_helper(f'/api/v2/algorithms/{algo.id}/files?path=leader')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.get_helper(f'/api/v2/algorithms/{algo.id}/files?path=leader/config.py')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.get_helper(f'/api/v2/algorithms/{algo.id}/files?path=leader/main.py')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'content': 'import tensorflow', 'path': 'leader/main.py'})
+
+
+class AlgorithmFilesDownloadApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ algo = Algorithm(name='test-algo', project_id=1, path=path)
+ session.add(algo)
+ session.commit()
+
+
+class AlgorithmProjectsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test-project')
+ session.add(project)
+ session.commit()
+
+ def test_get_algorithms(self):
+ with db.session_scope() as session:
+ algo_project_1 = AlgorithmProject(name='test-algo-1', created_at=datetime(2021, 12, 1, 0, 0, 0))
+ algo_project_2 = AlgorithmProject(name='test-algo-2',
+ project_id=1,
+ created_at=datetime(2021, 12, 1, 0, 0, 1))
+ session.add_all([algo_project_1, algo_project_2])
+ session.commit()
+ # test get all
+ response = self.get_helper('/api/v2/projects/0/algorithm_projects')
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['name'], 'test-algo-2')
+ # test get by project
+ response = self.get_helper('/api/v2/projects/1/algorithm_projects')
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'test-algo-2')
+ # test get by keyword
+ response = self.get_helper('/api/v2/projects/0/algorithm_projects?keyword=algo-2')
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'test-algo-2')
+
+ def test_get_algorithms_by_source(self):
+ with db.session_scope() as session:
+ algo_project_1 = AlgorithmProject(name='test-preset-1', source=Source.PRESET)
+ algo_project_2 = AlgorithmProject(name='test-preset-2',
+ source=Source.USER,
+ created_at=datetime(2021, 12, 1, 0, 0, 0))
+ algo_project_3 = AlgorithmProject(name='test-preset-3',
+ source=Source.THIRD_PARTY,
+ created_at=datetime(2021, 12, 1, 0, 0, 1))
+ session.add_all([algo_project_1, algo_project_2, algo_project_3])
+ session.commit()
+ response = self.get_helper('/api/v2/projects/0/algorithm_projects?sources=PRESET')
+ data = self.get_response_data(response)
+ self.assertEqual(data[0]['name'], 'test-preset-1')
+ response = self.get_helper('/api/v2/projects/0/algorithm_projects?sources=USER&sources=THIRD_PARTY')
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['name'], 'test-preset-3')
+ self.assertEqual(data[1]['name'], 'test-preset-2')
+
+ def test_get_algorithm_projects_by_filter(self):
+ with db.session_scope() as session:
+ algo_project_1 = AlgorithmProject(name='test-algo-1',
+ release_status=ReleaseStatus.RELEASED,
+ type=AlgorithmType.NN_VERTICAL,
+ created_at=datetime(2021, 12, 1, 0, 0, 0),
+ updated_at=datetime(2021, 12, 5, 3, 0, 0))
+ algo_project_2 = AlgorithmProject(name='test-algo-2',
+ release_status=ReleaseStatus.UNRELEASED,
+ type=AlgorithmType.TREE_VERTICAL,
+ created_at=datetime(2021, 12, 2, 0, 0, 0),
+ updated_at=datetime(2021, 12, 5, 4, 0, 0))
+ algo_project_3 = AlgorithmProject(name='test-preset-1',
+ release_status=ReleaseStatus.RELEASED,
+ type=AlgorithmType.NN_VERTICAL,
+ created_at=datetime(2021, 12, 3, 0, 0, 0),
+ updated_at=datetime(2021, 12, 5, 2, 0, 0))
+ algo_project_4 = AlgorithmProject(name='test-preset-2',
+ release_status=ReleaseStatus.UNRELEASED,
+ type=AlgorithmType.TREE_VERTICAL,
+ created_at=datetime(2021, 12, 4, 0, 0, 0),
+ updated_at=datetime(2021, 12, 5, 1, 0, 0))
+ session.add_all([algo_project_1, algo_project_2, algo_project_3, algo_project_4])
+ session.commit()
+ resp = self.get_helper('/api/v2/projects/0/algorithm_projects')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 4)
+ filter_param = urllib.parse.quote('(type:["NN_VERTICAL"])')
+ resp = self.get_helper(f'/api/v2/projects/0/algorithm_projects?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['test-preset-1', 'test-algo-1'])
+ filter_param = urllib.parse.quote('(release_status:["UNRELEASED"])')
+ resp = self.get_helper(f'/api/v2/projects/0/algorithm_projects?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['test-preset-2', 'test-algo-2'])
+ filter_param = urllib.parse.quote('(name~="test-algo")')
+ resp = self.get_helper(f'/api/v2/projects/0/algorithm_projects?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['test-algo-2', 'test-algo-1'])
+ order_by_param = urllib.parse.quote('created_at asc')
+ resp = self.get_helper(f'/api/v2/projects/0/algorithm_projects?order_by={order_by_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['test-algo-1', 'test-algo-2', 'test-preset-1', 'test-preset-2'])
+ order_by_param = urllib.parse.quote('updated_at asc')
+ resp = self.get_helper(f'/api/v2/projects/0/algorithm_projects?order_by={order_by_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['test-preset-2', 'test-preset-1', 'test-algo-1', 'test-algo-2'])
+
+ def test_post_algorithm_project_with_wrong_parameter(self):
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(name='test-project').first()
+ parameters = {'variable': []}
+ resp = self.post_helper(f'/api/v2/projects/{project.id}/algorithm_projects',
+ data={
+ 'name': 'test-algo-project',
+ 'type': AlgorithmType.NN_VERTICAL.name,
+ 'parameter': json.dumps(parameters)
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ @patch('fedlearner_webconsole.algorithm.service.AlgorithmProject')
+ @patch('fedlearner_webconsole.algorithm.apis.algorithm_project_path')
+ def test_post_algorithm_project_with_exceptions(self, mock_algorithm_project_path, mock_algorithm_project):
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(name='test-project').first()
+ parameters = {'variables': [{'name': 'BATCH_SIZE', 'value': '128'}]}
+ file = (BytesIO(_generate_tar_file().fileobj.read()), 'test.tar.gz')
+ Envs.STORAGE_ROOT = tempfile.mkdtemp()
+ name = 'test-algo-project'
+ path = os.path.join(Envs.STORAGE_ROOT, 'algorithm_projects', name)
+ mock_algorithm_project_path.return_value = path
+ mock_algorithm_project.side_effect = Exception()
+ self.client.post(f'/api/v2/projects/{project.id}/algorithm_projects',
+ data={
+ 'name': name,
+ 'file': [file],
+ 'type': AlgorithmType.NN_VERTICAL.name,
+ 'parameter': json.dumps(parameters),
+ 'comment': 'haha'
+ },
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertFalse(os.path.exists(path))
+
+ def test_post_algorithm_project(self):
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(name='test-project').first()
+ parameters = {'variables': [{'name': 'BATCH_SIZE', 'value': '128'}]}
+ file = (BytesIO(_generate_tar_file().fileobj.read()), 'test.tar.gz')
+ Envs.STORAGE_ROOT = tempfile.mkdtemp()
+ resp = self.client.post(f'/api/v2/projects/{project.id}/algorithm_projects',
+ data={
+ 'name': 'test-algo-project',
+ 'file': [file],
+ 'type': AlgorithmType.NN_VERTICAL.name,
+ 'parameter': json.dumps(parameters),
+ 'comment': 'haha'
+ },
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ algo_project: AlgorithmProject = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ self.assertEqual(algo_project.type, AlgorithmType.NN_VERTICAL)
+ algo_parameter = ParseDict(parameters, AlgorithmParameter())
+ self.assertEqual(algo_project.get_parameter(), algo_parameter)
+ self.assertEqual(algo_project.comment, 'haha')
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'leader', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'follower', 'main.py')))
+ with open(os.path.join(algo_project.path, 'leader', 'main.py'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), 'import tensorflow')
+ with open(os.path.join(algo_project.path, 'follower', 'main.py'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), '')
+
+ def test_post_algorithm_project_with_empty_file(self):
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(name='test-project').first()
+ Envs.STORAGE_ROOT = tempfile.mkdtemp()
+ resp = self.client.post(f'/api/v2/projects/{project.id}/algorithm_projects',
+ data={
+ 'name': 'test-algo-project',
+ 'type': AlgorithmType.NN_VERTICAL.name,
+ },
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ self.assertEqual(algo_project.name, 'test-algo-project')
+ self.assertEqual(algo_project.type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(algo_project.get_parameter(), AlgorithmParameter())
+ self.assertTrue(os.path.exists(algo_project.path))
+ self.assertEqual(os.listdir(algo_project.path), [])
+
+ @patch('fedlearner_webconsole.utils.file_manager.FileManager.mkdir')
+ def test_post_algorithm_project_with_duplicate_name(self, mock_mkdir):
+ with db.session_scope() as session:
+ project1 = Project(id=2, name='test-project-1')
+ project2 = Project(id=3, name='test-project-2')
+ algo_project = AlgorithmProject(name='test-algo-project', project_id=2, source=Source.USER)
+ session.add_all([project1, project2, algo_project])
+ session.commit()
+ resp = self.client.post('/api/v2/projects/2/algorithm_projects',
+ data={
+ 'name': 'test-algo-project',
+ 'type': AlgorithmType.NN_VERTICAL.name,
+ },
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+ resp = self.client.post('/api/v2/projects/3/algorithm_projects',
+ data={
+ 'name': 'test-algo-project',
+ 'type': AlgorithmType.TREE_VERTICAL.name,
+ },
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+
+ def test_post_algorithm_project_with_trusted_computing(self):
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(name='test-project').first()
+ parameters = {'variables': [{'name': 'OUTPUT_PATH', 'value': '/output'}]}
+ file = (BytesIO(_generate_tar_file().fileobj.read()), 'test.tar.gz')
+ Envs.STORAGE_ROOT = tempfile.mkdtemp()
+ golden_data = {
+ 'name': 'test-algo-project-trust',
+ 'file': [file],
+ 'type': AlgorithmType.TRUSTED_COMPUTING.name,
+ 'parameter': json.dumps(parameters),
+ 'comment': 'comment for algorithm project with trusted computing type'
+ }
+ resp = self.client.post(f'/api/v2/projects/{project.id}/algorithm_projects',
+ data=golden_data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ algo_project: AlgorithmProject = \
+ session.query(AlgorithmProject).filter_by(name=golden_data['name']).first()
+ self.assertEqual(algo_project.type.name, golden_data['type'])
+ algo_parameter = ParseDict(parameters, AlgorithmParameter())
+ self.assertEqual(algo_project.get_parameter(), algo_parameter)
+ self.assertEqual(algo_project.comment, golden_data['comment'])
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'leader', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'follower', 'main.py')))
+ with open(os.path.join(algo_project.path, 'leader', 'main.py'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), 'import tensorflow')
+ with open(os.path.join(algo_project.path, 'follower', 'main.py'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), '')
+
+
+class AlgorithmProjectApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ user = User(username='test-user')
+ session.add(user)
+ session.flush()
+ algo_project = AlgorithmProject(name='test-algo-project',
+ type=AlgorithmType.NN_VERTICAL,
+ project_id=1,
+ username=user.username,
+ source=Source.PRESET,
+ path=generate_algorithm_files(),
+ comment='comment')
+ parameter = {
+ 'variables': [{
+ 'name': 'BATCH_SIZE',
+ 'value': '12',
+ 'display_name': 'batch_size',
+ 'required': False,
+ 'comment': '',
+ 'value_type': 'STRING'
+ }]
+ }
+ algo_parameter = ParseDict(parameter, AlgorithmParameter())
+ algo_project.set_parameter(algo_parameter)
+ session.add(algo_project)
+ algo_project = AlgorithmProject(name='test-algo-project-third-party',
+ type=AlgorithmType.NN_VERTICAL,
+ project_id=1,
+ username=user.username,
+ source=Source.THIRD_PARTY,
+ path=generate_algorithm_files(),
+ comment='comment')
+ session.add(algo_project)
+ session.commit()
+
+ def test_get_algorithm_project(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ parameter = to_dict(algo_project.get_parameter())
+ resp = self.get_helper(f'/api/v2/algorithm_projects/{algo_project.id}')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ expected_data = {
+ 'algorithms': [],
+ 'name': 'test-algo-project',
+ 'type': 'NN_VERTICAL',
+ 'project_id': 1,
+ 'username': 'test-user',
+ 'latest_version': 0,
+ 'source': 'PRESET',
+ 'participant_id': 0,
+ 'participant_name': '',
+ 'parameter': parameter,
+ 'publish_status': 'UNPUBLISHED',
+ 'release_status': 'UNRELEASED',
+ 'path': algo_project.path,
+ 'comment': 'comment'
+ }
+ self.maxDiff = None
+ self.assertResponseDataEqual(resp,
+ expected_data,
+ ignore_fields=['id', 'uuid', 'created_at', 'updated_at', 'deleted_at'])
+
+ def test_get_algorithms_from_algorithm_project(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ algo_1 = Algorithm(name='test-algo', version=1, algorithm_project_id=algo_project.id)
+ algo_2 = Algorithm(name='test-algo', version=2, algorithm_project_id=algo_project.id)
+ session.add_all([algo_1, algo_2])
+ session.commit()
+ resp = self.get_helper(f'/api/v2/algorithm_projects/{algo_project.id}')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data['algorithms']), 2)
+
+ def test_patch_algorithm_project(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ parameters = {'variables': [{'name': 'BATCH_SIZE', 'value': '128'}]}
+ resp = self.patch_helper(f'/api/v2/algorithm_projects/{algo_project.id}',
+ data={
+ 'parameter': parameters,
+ 'comment': 'comment'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).get(algo_project.id)
+ self.assertEqual(algo_project.comment, 'comment')
+ algo_parameter = ParseDict(parameters, AlgorithmParameter())
+ self.assertEqual(algo_project.get_parameter(), algo_parameter)
+
+ def test_patch_third_party_algorithm_project(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project-third-party').first()
+ comment = 'test edit comment'
+ resp = self.patch_helper(f'/api/v2/algorithm_projects/{algo_project.id}',
+ data={
+ 'parameter': None,
+ 'comment': comment
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).get(algo_project.id)
+ self.assertNotEqual(algo_project.comment, comment)
+
+ def test_delete_algorithm_project(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ algo = Algorithm(name='test-algo', algorithm_project_id=algo_project.id, path=generate_algorithm_files())
+ session.add(algo)
+ session.commit()
+            resp = self.delete_helper(f'/api/v2/algorithm_projects/{algo_project.id}')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter(
+ AlgorithmProject.name.like('%test-algo-project')).execution_options(include_deleted=True).first()
+ algo = session.query(Algorithm).filter(
+ Algorithm.name.like('%test-algo')).execution_options(include_deleted=True).first()
+ self.assertIsNone(algo_project)
+ self.assertIsNone(algo)
+
+
+class AlgorithmProjectFilesApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(name='test-algo-project', path=path)
+ session.add(algo_project)
+ session.commit()
+
+ def test_get_file_tree(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ resp = self.get_helper(f'/api/v2/algorithm_projects/{algo_project.id}/tree')
+ data = self.get_response_data(resp)
+ data = sorted(data, key=lambda d: d['filename'])
+ self.assertPartiallyEqual(data[1], {
+ 'filename': 'leader',
+ 'path': 'leader',
+ 'is_directory': True
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+ self.assertPartiallyEqual(data[1]['files'][0], {
+ 'filename': 'main.py',
+ 'path': 'leader/main.py',
+ 'is_directory': False
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+
+ def test_get_project_files(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ resp = self.get_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=leader/../..')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ resp = self.get_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=leader/config.py')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.get_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=leader')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.get_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=leader/main.py')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'content': 'import tensorflow', 'path': 'leader/main.py'})
+
+ def test_post_project_files(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+        # unauthorized: the path escapes the algorithm project directory
+ data = {'path': '..', 'filename': 'test', 'file': (BytesIO(b'abcdef'), 'test.jpg')}
+ resp = self.client.post(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+        # fail because the target path does not exist
+ data = {'path': 'test', 'filename': ',.test.jpg.', 'file': (BytesIO(b'abcdef'), 'test.jpg')}
+ resp = self.client.post(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # put file under leader directory
+ data = {'path': 'leader', 'filename': ',.test.jpg.', 'file': (BytesIO(b'abcdef'), 'test.jpg')}
+ resp = self.client.post(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'path': 'leader', 'filename': 'test.jpg'})
+ with open(os.path.join(algo_project.path, 'leader', 'test.jpg'), 'rb') as fin:
+ file_content = fin.read()
+ self.assertEqual(file_content, b'abcdef')
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+
+ def test_put_empty_file(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ # put empty file under leader directory
+ data = {'path': 'leader', 'filename': 'test'}
+ resp = self.client.put(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'content': None, 'path': 'leader', 'filename': 'test'})
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'leader', 'test')))
+
+ def test_put_file_by_content(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ # put file under leader directory by content
+ data = {'path': 'leader', 'filename': 'test', 'file': BytesIO(b'123')}
+ resp = self.client.put(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'path': 'leader', 'filename': 'test', 'content': '123'})
+ with open(os.path.join(algo_project.path, 'leader', 'test'), 'r', encoding='utf-8') as file:
+ self.assertEqual(file.read(), '123')
+
+ def test_put_directory(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+            # fail because the file already exists
+ data = {'path': '.', 'filename': 'leader', 'is_directory': True}
+ resp = self.client.put(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+        # fail because the path does not exist
+ data = {'path': 'test', 'filename': 'test', 'is_directory': True}
+ resp = self.client.put(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # create directory under leader
+ data = {'path': 'leader', 'filename': 'test', 'is_directory': True}
+ resp = self.client.put(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'content': None, 'path': 'leader', 'filename': 'test'})
+ self.assertTrue(os.path.isdir(os.path.join(algo_project.path, 'leader', 'test')))
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+
+ def test_delete_project_files(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ resp = self.delete_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=leader/../..')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ resp = self.delete_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=leader/config.py')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.delete_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=leader/main.py')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'leader')))
+ self.assertFalse(os.path.exists(os.path.join(algo_project.path, 'leader', 'main.py')))
+ resp = self.delete_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files?path=follower')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ self.assertEqual(os.listdir(os.path.join(algo_project.path)), ['leader'])
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+
+ def test_patch_algorithm_project_file_rename(self):
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ resp = self.patch_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data={
+ 'path': 'leader',
+ 'dest': 'leader1'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'leader1')))
+ self.assertFalse(os.path.exists(os.path.join(algo_project.path, 'leader')))
+ resp = self.patch_helper(f'/api/v2/algorithm_projects/{algo_project.id}/files',
+ data={
+ 'path': 'leader1/main.py',
+ 'dest': 'main.py'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ self.assertTrue(os.path.exists(os.path.join(algo_project.path, 'main.py')))
+        self.assertFalse(os.path.exists(os.path.join(algo_project.path, 'leader1', 'main.py')))
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+
+
+class ParticipantAlgorithmProjectsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-1')
+ participant_1 = Participant(id=1, name='part-1', domain_name='test-1')
+ participant_2 = Participant(id=2, name='part-2', domain_name='test-2')
+ project_participant_1 = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ project_participant_2 = ProjectParticipant(id=2, project_id=1, participant_id=2)
+ algorithm_project_1 = AlgorithmProject(id=1,
+ uuid='algo-project-uuid-1',
+ name='test-algo-project-1',
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.PRESET,
+ latest_version=1,
+ comment='comment-1',
+ created_at=datetime(2021, 12, 3, 0, 0, 0),
+ updated_at=datetime(2021, 12, 7, 2, 0, 0))
+ algorithm_project_2 = AlgorithmProject(id=2,
+ uuid='algo-project-uuid-2',
+ name='test-algo-project-2',
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ latest_version=1,
+ comment='comment-2',
+ created_at=datetime(2021, 12, 4, 0, 0, 0),
+ updated_at=datetime(2021, 12, 6, 2, 0, 0))
+ session.add_all([
+ project, participant_1, participant_2, project_participant_1, project_participant_2,
+ algorithm_project_1, algorithm_project_2
+ ])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.list_algorithm_projects')
+ def test_get_participant_algorithm_projects(self, mock_list_algorithm_projects):
+ with db.session_scope() as session:
+ algo_project_1 = session.query(AlgorithmProject).get(1)
+ algo_project_2 = session.query(AlgorithmProject).get(2)
+ participant_algorithm_projects1 = [algo_project_1.to_proto()]
+ participant_algorithm_projects2 = [algo_project_2.to_proto()]
+ mock_list_algorithm_projects.return_value = ListAlgorithmProjectsResponse(
+ algorithm_projects=participant_algorithm_projects1)
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithm_projects')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(mock_list_algorithm_projects.call_count, 1)
+ self.assertEqual(data[0]['uuid'], 'algo-project-uuid-1')
+ self.assertEqual(data[0]['latest_version'], 1)
+ self.assertEqual(data[0]['comment'], 'comment-1')
+ self.assertEqual(data[0]['participant_id'], 1)
+ mock_list_algorithm_projects.side_effect = [
+ ListAlgorithmProjectsResponse(algorithm_projects=participant_algorithm_projects1),
+ ListAlgorithmProjectsResponse(algorithm_projects=participant_algorithm_projects2)
+ ]
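+            # participant_id 0 aggregates algorithm projects from every participant in the project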
+ resp = self.get_helper('/api/v2/projects/1/participants/0/algorithm_projects')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(mock_list_algorithm_projects.call_count, 3)
+ self.assertEqual(data[0]['uuid'], 'algo-project-uuid-2')
+ self.assertEqual(data[0]['latest_version'], 1)
+ self.assertEqual(data[0]['comment'], 'comment-2')
+ self.assertEqual(data[0]['participant_id'], 2)
+ self.assertEqual(data[1]['name'], 'test-algo-project-1')
+ self.assertEqual(data[1]['source'], 'PRESET')
+ self.assertEqual(data[1]['participant_id'], 1)
+        # when a gRPC call fails, that participant's algorithm projects are skipped
+ mock_list_algorithm_projects.side_effect = [
+ FakeRpcError(grpc.StatusCode.UNIMPLEMENTED, 'rpc not implemented'),
+ ListAlgorithmProjectsResponse(algorithm_projects=participant_algorithm_projects2)
+ ]
+ resp = self.get_helper('/api/v2/projects/1/participants/0/algorithm_projects')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 1)
+ mock_list_algorithm_projects.side_effect = [
+ FakeRpcError(grpc.StatusCode.UNIMPLEMENTED, 'rpc not implemented'),
+ FakeRpcError(grpc.StatusCode.UNIMPLEMENTED, 'rpc not implemented')
+ ]
+ resp = self.get_helper('/api/v2/projects/1/participants/0/algorithm_projects')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 0)
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.list_algorithm_projects')
+ def test_get_participant_algorithm_projects_with_filter(self, mock_list_algorithm_projects):
+ with db.session_scope() as session:
+ algo_project_1 = session.query(AlgorithmProject).get(1)
+ algo_project_2 = session.query(AlgorithmProject).get(2)
+ participant_algorithm_projects1 = [algo_project_1.to_proto()]
+ participant_algorithm_projects2 = [algo_project_1.to_proto(), algo_project_2.to_proto()]
+ mock_list_algorithm_projects.return_value = ListAlgorithmProjectsResponse(
+ algorithm_projects=participant_algorithm_projects1)
+ filter_param = urllib.parse.quote('(name~="1")')
+ resp = self.get_helper(f'/api/v2/projects/1/participants/1/algorithm_projects?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['uuid'], 'algo-project-uuid-1')
+ self.assertEqual(data[0]['latest_version'], 1)
+ self.assertEqual(data[0]['comment'], 'comment-1')
+ mock_list_algorithm_projects.assert_called_with(filter_exp=parse_expression('(name~="1")'))
+ mock_list_algorithm_projects.return_value = ListAlgorithmProjectsResponse(
+ algorithm_projects=participant_algorithm_projects2)
+ order_by_param = urllib.parse.quote('updated_at asc')
+ resp = self.get_helper(f'/api/v2/projects/1/participants/1/algorithm_projects?order_by={order_by_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(data[0]['name'], 'test-algo-project-2')
+ self.assertEqual(data[1]['uuid'], 'algo-project-uuid-1')
+ mock_list_algorithm_projects.return_value = ListAlgorithmProjectsResponse(
+ algorithm_projects=participant_algorithm_projects2)
+ order_by_param = urllib.parse.quote('created_at asc')
+ resp = self.get_helper(f'/api/v2/projects/1/participants/1/algorithm_projects?order_by={order_by_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(data[0]['name'], 'test-algo-project-1')
+ self.assertEqual(data[1]['uuid'], 'algo-project-uuid-2')
+ order_by_param = urllib.parse.quote('unknown_attribute asc')
+ resp = self.get_helper(f'/api/v2/projects/1/participants/1/algorithm_projects?order_by={order_by_param}')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+
+class ParticipantAlgorithmProjectApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-1')
+ participant = Participant(id=1, name='part-1', domain_name='test')
+ session.add_all([project, participant])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm_project')
+ def test_get_algorithm_project(self, mock_get_algorithm_project):
+ participant_algorithm_project = AlgorithmProjectPb(uuid='algo-project-uuid-1',
+ name='test-algo-project-1',
+ type=AlgorithmType.NN_VERTICAL.name,
+ source=Source.USER.name,
+ latest_version=1,
+ comment='comment-1',
+ created_at=1326542405,
+ updated_at=1326542405)
+ mock_get_algorithm_project.return_value = participant_algorithm_project
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithm_projects/algo-project-uuid-1')
+ data = self.get_response_data(resp)
+ self.assertEqual(data['uuid'], 'algo-project-uuid-1')
+ self.assertEqual(data['latest_version'], 1)
+ self.assertEqual(data['comment'], 'comment-1')
+
+
+class ParticipantAlgorithmsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-1')
+ participant = Participant(id=1, name='part-1', domain_name='test')
+ session.add_all([project, participant])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.list_algorithms')
+ def test_get_participant_algorithms(self, mock_list_algorithms):
+ parameter = ParseDict({'variables': [{'name': 'BATCH_SIZE', 'value': '128'}]}, AlgorithmParameter())
+ participant_algorithms = [
+ AlgorithmPb(uuid='algo-uuid-1',
+ name='test-algo-1',
+ version=1,
+ type=AlgorithmType.NN_VERTICAL.name,
+ source=Source.USER.name,
+ parameter=parameter,
+ comment='comment-1',
+ created_at=1326542405,
+ updated_at=1326542405),
+ AlgorithmPb(uuid='algo-uuid-2',
+ name='test-algo-2',
+ version=2,
+ type=AlgorithmType.TREE_VERTICAL.name,
+ source=Source.THIRD_PARTY.name,
+ parameter=parameter,
+ comment='comment-2',
+ created_at=1326542405,
+ updated_at=1326542405)
+ ]
+ mock_list_algorithms.return_value = ListAlgorithmsResponse(algorithms=participant_algorithms)
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithms?algorithm_project_uuid=uuid')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['uuid'], 'algo-uuid-1')
+ self.assertEqual(data[0]['version'], 1)
+ self.assertEqual(data[0]['comment'], 'comment-1')
+ self.assertEqual(data[0]['participant_id'], 1)
+ self.assertEqual(data[1]['name'], 'test-algo-2')
+ self.assertEqual(data[1]['type'], 'TREE_VERTICAL')
+ self.assertEqual(data[1]['source'], 'THIRD_PARTY')
+ self.assertEqual(data[1]['participant_id'], 1)
+
+
+class ParticipantAlgorithmApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-1')
+ participant = Participant(id=1, name='part-1', domain_name='test')
+ algorithm = Algorithm(id=1,
+ uuid='algo-uuid-1',
+ name='test-algo-1',
+ version=1,
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ session.add_all([project, participant, algorithm])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm')
+ def test_get_participant_algorithm(self, mock_get_algorithm):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(1)
+ mock_get_algorithm.return_value = AlgorithmPb(uuid=algo.uuid,
+ name=algo.name,
+ version=algo.version,
+ type=algo.type.name,
+ source=algo.source.name,
+ comment=algo.comment)
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithms/algo-uuid-1')
+ data = self.get_response_data(resp)
+ self.assertEqual(data['name'], 'test-algo-1')
+ self.assertEqual(data['version'], 1)
+ self.assertEqual(data['type'], 'NN_VERTICAL')
+ self.assertEqual(data['source'], 'USER')
+
+
+class ParticipantAlgorithmFilesApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-1')
+ participant = Participant(id=1, name='part-1', domain_name='test')
+ path = generate_algorithm_files()
+ algorithm = Algorithm(id=1, uuid='algo-uuid-1', name='test-algo-1', path=path)
+ session.add_all([project, participant, algorithm])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm')
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm_files')
+ def test_get_participant_algorithm_tree(self, mock_get_algorithm_files, mock_get_algorithm):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(1)
+ mock_get_algorithm.return_value = algo.to_proto()
+ data_iterator = AlgorithmSender().make_algorithm_iterator(algo.path)
+ mock_get_algorithm_files.return_value = data_iterator
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithms/algo-uuid-1/tree')
+ data = self.get_response_data(resp)
+ data = sorted(data, key=lambda d: d['filename'])
+ self.assertPartiallyEqual(data[1], {
+ 'filename': 'leader',
+ 'path': 'leader',
+ 'is_directory': True
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+ self.assertPartiallyEqual(data[1]['files'][0], {
+ 'filename': 'main.py',
+ 'path': 'leader/main.py',
+ 'is_directory': False
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm')
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm_files')
+ def test_get_participant_algorithm_files(self, mock_get_algorithm_files, mock_get_algorithm):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(1)
+ mock_get_algorithm.return_value = algo.to_proto()
+ data_iterator = AlgorithmSender().make_algorithm_iterator(algo.path)
+ mock_get_algorithm_files.return_value = data_iterator
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithms/algo-uuid-1/files?path=..')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithms/algo-uuid-1/files?path=leader')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.get_helper(
+ '/api/v2/projects/1/participants/1/algorithms/algo-uuid-1/files?path=leader/config.py')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.get_helper('/api/v2/projects/1/participants/1/algorithms/algo-uuid-1/files?path=leader/main.py')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'content': 'import tensorflow', 'path': 'leader/main.py'})
+
+
+class FetchAlgorithmApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='part', domain_name='part-test.com')
+ project_participant = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ algorithm = Algorithm(id=1, uuid='uuid', name='algo', project_id=1, source=Source.USER)
+ session.add_all([project, participant, project_participant, algorithm])
+ session.commit()
+
+ def test_get_algorithm(self):
+ resp = self.get_helper('/api/v2/projects/1/algorithms/uuid')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['name'], 'algo')
+
+ @patch('fedlearner_webconsole.algorithm.fetcher.AlgorithmFetcher.get_algorithm_from_participant')
+ def test_get_algorithm_from_participant(self, mock_get_algorithm_from_participant):
+ mock_get_algorithm_from_participant.return_value = AlgorithmPb(name='peer-algo')
+ resp = self.get_helper('/api/v2/projects/1/algorithms/uuid-1')
+ mock_get_algorithm_from_participant.assert_called_with(algorithm_uuid='uuid-1', participant_id=1)
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['name'], 'peer-algo')
+
+
+class UpdatePresetAlgorithmApiTest(BaseTestCase):
+
+ def test_update_preset_algorithms(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ self.signin_as_admin()
+ resp = self.post_helper('/api/v2/preset_algorithms:update')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted([d['name'] for d in data]), ['e2e_test', 'horizontal_e2e_test', 'secure_boost'])
+ with db.session_scope() as session:
+ algo_project1 = session.query(AlgorithmProject).filter_by(name='e2e_test').first()
+ self.assertIsNotNone(algo_project1)
+ self.assertEqual(algo_project1.type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(algo_project1.source, Source.PRESET)
+ algo_project2 = session.query(AlgorithmProject).filter_by(name='secure_boost').first()
+ self.assertIsNotNone(algo_project2)
+ self.assertEqual(algo_project2.type, AlgorithmType.TREE_VERTICAL)
+ self.assertEqual(algo_project2.source, Source.PRESET)
+ algo1 = session.query(Algorithm).filter_by(name='e2e_test').first()
+ self.assertIsNotNone(algo1)
+ self.assertEqual(algo1.type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(algo1.source, Source.PRESET)
+ self.assertEqual(algo1.algorithm_project_id, algo_project1.id)
+ self.assertTrue(os.path.exists(os.path.join(algo1.path, 'follower/config.py')))
+ self.assertTrue(os.path.exists(os.path.join(algo1.path, 'leader/config.py')))
+ algo2 = session.query(Algorithm).filter_by(name='secure_boost').first()
+ self.assertIsNotNone(algo2)
+ self.assertEqual(algo2.type, AlgorithmType.TREE_VERTICAL)
+ self.assertEqual(algo2.source, Source.PRESET)
+ self.assertEqual(algo2.algorithm_project_id, algo_project2.id)
+
+
+class ReleaseAlgorithmApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.algorithm.apis.algorithm_path')
+ @patch('fedlearner_webconsole.algorithm.service.Algorithm')
+ def test_release_algorithm_with_exceptions(self, mock_algorithm, mock_algorithm_path):
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(name='test-algo', path=path)
+ session.add(algo_project)
+ session.commit()
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+ Envs.STORAGE_ROOT = tempfile.mkdtemp()
+ algorithm_path = os.path.join(Envs.STORAGE_ROOT, 'algorithms', 'test_with_exceptions')
+ mock_algorithm_path.return_value = algorithm_path
+ mock_algorithm.side_effect = Exception()
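+        # when creating the Algorithm record fails, the files copied for the release should be cleaned up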
+ self.post_helper(f'/api/v2/algorithm_projects/{algo_project.id}:release', {})
+ self.assertFalse(os.path.exists(algorithm_path))
+
+ def test_release_algorithm(self):
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(name='test-algo', path=path)
+ session.add(algo_project)
+ session.commit()
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+ Envs.STORAGE_ROOT = tempfile.mkdtemp()
+ resp = self.post_helper(f'/api/v2/algorithm_projects/{algo_project.id}:release', {})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).get(1)
+ algo = algo_project.algorithms[0]
+ self.assertEqual(algo_project.release_status, ReleaseStatus.RELEASED)
+ self.assertEqual(algo.name, 'test-algo')
+ self.assertEqual(algo.version, 1)
+ self.assertTrue(algo.path.startswith(Envs.STORAGE_ROOT))
+ with open(os.path.join(algo.path, 'leader', 'main.py'), 'r', encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), 'import tensorflow')
+ self.assertTrue(os.path.exists(os.path.join(algo.path, 'follower', 'main.py')))
+ resp = self.post_helper(f'/api/v2/algorithm_projects/{algo_project.id}:release', data={'comment': 'comment'})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).get(1)
+ self.assertEqual(algo_project.release_status, ReleaseStatus.RELEASED)
+ self.assertEqual(len(algo_project.algorithms), 2)
+ algo = algo_project.algorithms[0]
+ self.assertEqual(algo.name, 'test-algo')
+ self.assertEqual(algo.comment, 'comment')
+ self.assertEqual(algo.version, 2)
+
+ def test_release_algorithm_failed(self):
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(name='test-algo', path=path, source=Source.THIRD_PARTY)
+ session.add(algo_project)
+ session.commit()
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+ Envs.STORAGE_ROOT = tempfile.mkdtemp()
+ resp = self.post_helper(f'/api/v2/algorithm_projects/{algo_project.id}:release', {})
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+
+
+class PublishAlgorithmApiTest(BaseTestCase):
+
+ def test_publish_algorithm(self):
+ with db.session_scope() as session:
+ project = Project(id=1, name='test-project')
+ algo_project = AlgorithmProject(id=1,
+ project_id=1,
+ name='test-algo-project',
+ publish_status=PublishStatus.UNPUBLISHED)
+ algo = Algorithm(id=1,
+ project_id=1,
+ name='test-algo',
+ algorithm_project_id=1,
+ publish_status=PublishStatus.UNPUBLISHED)
+ session.add_all([project, algo_project, algo])
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/algorithms/1:publish')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['status'], PublishStatus.PUBLISHED.name)
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(1)
+ self.assertEqual(algo.publish_status, PublishStatus.PUBLISHED)
+ algo_project = session.query(AlgorithmProject).get(1)
+ self.assertEqual(algo_project.publish_status, PublishStatus.PUBLISHED)
+
+
+class UnpublishAlgorithmApiTest(BaseTestCase):
+
+ def test_unpublish_algorithm(self):
+ with db.session_scope() as session:
+ project = Project(id=1, name='test-project')
+ algo_project = AlgorithmProject(id=1,
+ project_id=1,
+ name='test-algo-project',
+ publish_status=PublishStatus.PUBLISHED)
+ algo1 = Algorithm(id=1,
+ project_id=1,
+ algorithm_project_id=1,
+ name='test-algo-1',
+ publish_status=PublishStatus.PUBLISHED)
+ algo2 = Algorithm(id=2,
+ project_id=1,
+ algorithm_project_id=1,
+ name='test-algo-2',
+ publish_status=PublishStatus.PUBLISHED)
+ session.add_all([project, algo_project, algo1, algo2])
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/algorithms/1:unpublish')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['status'], PublishStatus.UNPUBLISHED.name)
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(1)
+ self.assertEqual(algo.publish_status, PublishStatus.UNPUBLISHED)
+ algo_project = session.query(AlgorithmProject).get(1)
+ self.assertEqual(algo_project.publish_status, PublishStatus.PUBLISHED)
+ resp = self.post_helper('/api/v2/projects/1/algorithms/2:unpublish')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['status'], PublishStatus.UNPUBLISHED.name)
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(2)
+ self.assertEqual(algo.publish_status, PublishStatus.UNPUBLISHED)
+ algo_project = session.query(AlgorithmProject).get(1)
+ self.assertEqual(algo_project.publish_status, PublishStatus.UNPUBLISHED)
+
+
+class PendingAlgorithmsApiTest(BaseTestCase):
+
+ def test_get_pending_algorithms(self):
+ with db.session_scope() as session:
+ uuid = resource_uuid()
+ algo_project = AlgorithmProject(name='test-algo', uuid=uuid)
+ participant = Participant(name='test-part', domain_name='haha')
+ session.add(algo_project)
+ session.add(participant)
+ session.flush()
+ pending_algo_1 = PendingAlgorithm(name='test-algo-1',
+ algorithm_project_uuid=uuid,
+ project_id=1,
+ created_at=datetime(2021, 12, 2, 0, 0),
+ participant_id=participant.id)
+ pending_algo_2 = PendingAlgorithm(name='test-algo-2', project_id=2, created_at=datetime(2021, 12, 2, 0, 1))
+ pending_algo_3 = PendingAlgorithm(name='test-algo-3', project_id=2, deleted_at=datetime(2021, 12, 2))
+ session.add_all([pending_algo_1, pending_algo_2, pending_algo_3])
+ session.commit()
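+        # project_id 0 lists pending algorithms across all projects; soft-deleted records are excluded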
+ resp = self.get_helper('/api/v2/projects/0/pending_algorithms')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['name'], 'test-algo-2')
+ resp = self.get_helper('/api/v2/projects/1/pending_algorithms')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 1)
+ with db.session_scope() as session:
+ algo_project = session.query(AlgorithmProject).filter_by(uuid=uuid).first()
+ self.assertPartiallyEqual(
+ data[0], {
+ 'name': 'test-algo-1',
+ 'project_id': 1,
+ 'algorithm_project_id': algo_project.id,
+ 'version': 0,
+ 'type': 'UNSPECIFIED',
+ 'path': '',
+ 'comment': '',
+ 'participant_id': participant.id,
+ 'participant_name': 'test-part'
+ },
+ ignore_fields=['id', 'algorithm_uuid', 'algorithm_project_uuid', 'created_at', 'updated_at', 'deleted_at'])
+
+
+class AcceptPendingAlgorithmApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test')
+ session.add(project)
+ session.flush()
+ pending_algo = PendingAlgorithm(name='test-algo',
+ version=2,
+ project_id=project.id,
+ algorithm_uuid=resource_uuid(),
+ algorithm_project_uuid=resource_uuid(),
+ type=AlgorithmType.NN_VERTICAL,
+ participant_id=1)
+ pending_algo.path = generate_algorithm_files()
+ session.add(pending_algo)
+ session.commit()
+
+ @patch('fedlearner_webconsole.algorithm.service.Algorithm')
+ @patch('fedlearner_webconsole.algorithm.service.AlgorithmProject')
+ @patch('fedlearner_webconsole.algorithm.apis.algorithm_project_path')
+ @patch('fedlearner_webconsole.algorithm.apis.algorithm_path')
+    def test_accept_pending_algorithm_with_exceptions(self, mock_algorithm_path, mock_algorithm_project_path,
+                                                      mock_algorithm_project, mock_algorithm):
+ with db.session_scope() as session:
+ pending_algo = session.query(PendingAlgorithm).filter_by(name='test-algo').first()
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ name = 'test_with_exceptions'
+ algo_project_path = os.path.join(Envs.STORAGE_ROOT, 'algorithm_projects', name)
+ mock_algorithm_project_path.return_value = algo_project_path
+ algorithm_path = os.path.join(Envs.STORAGE_ROOT, 'algorithms', name)
+ mock_algorithm_path.return_value = algorithm_path
+ mock_algorithm.side_effect = Exception()
+ mock_algorithm_project.side_effect = Exception()
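+            # when persisting the models fails, the algorithm files copied under STORAGE_ROOT should be cleaned up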
+ self.post_helper(f'/api/v2/projects/1/pending_algorithms/{pending_algo.id}:accept', data={'name': 'algo-1'})
+ self.assertFalse(os.path.exists(algorithm_path))
+ self.assertFalse(os.path.exists(algo_project_path))
+
+ def test_accept_pending_algorithm(self):
+ with db.session_scope() as session:
+ pending_algo = session.query(PendingAlgorithm).filter_by(name='test-algo').first()
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ resp = self.post_helper(f'/api/v2/projects/1/pending_algorithms/{pending_algo.id}:accept',
+ data={'name': 'algo-1'})
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ pending_algo = session.query(PendingAlgorithm).execution_options(include_deleted=True).filter_by(
+ name='test-algo').first()
+ self.assertTrue(bool(pending_algo.deleted_at))
+ algo_project = session.query(AlgorithmProject).filter_by(name='algo-1').first()
+ self.assertEqual(algo_project.username, 'ada')
+ self.assertEqual(algo_project.participant_id, pending_algo.participant_id)
+ self.assertEqual(algo_project.latest_version, pending_algo.version)
+ self.assertEqual(algo_project.type, pending_algo.type)
+ self.assertEqual(algo_project.source, Source.THIRD_PARTY)
+ self.assertEqual(len(algo_project.algorithms), 1)
+ self.assertEqual(algo_project.release_status, ReleaseStatus.RELEASED)
+ self.assertEqual(algo_project.algorithms[0].participant_id, pending_algo.participant_id)
+
+ def test_accept_with_duplicate_uuid(self):
+ uuid = resource_uuid()
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ with db.session_scope() as session:
+ pending_algo = session.query(PendingAlgorithm).filter_by(name='test-algo').first()
+ algorithm_project_uuid = pending_algo.algorithm_project_uuid
+ algo_project_path = algorithm_project_path(Envs.STORAGE_ROOT, 'test-algo')
+ algo_project = AlgorithmProject(name='test-algo-project',
+ uuid=algorithm_project_uuid,
+ path=algo_project_path,
+ source=Source.THIRD_PARTY)
+ algo_project.release_status = ReleaseStatus.RELEASED
+ session.add(algo_project)
+ session.commit()
+
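+            # accepting should attach the algorithm to the existing project with the same uuid rather than creating a new one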
+ resp = self.post_helper(f'/api/v2/projects/1/pending_algorithms/{pending_algo.id}:accept',
+ data={'name': 'test-algo-project'})
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ pending_algo = session.query(PendingAlgorithm).execution_options(include_deleted=True).filter_by(
+ name='test-algo').first()
+ self.assertTrue(bool(pending_algo.deleted_at))
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project').first()
+ self.assertEqual(algo_project.source, Source.THIRD_PARTY)
+ self.assertEqual(len(algo_project.algorithms), 1)
+ self.assertEqual(algo_project.release_status, ReleaseStatus.RELEASED)
+ self.assertEqual(algo_project.uuid, pending_algo.algorithm_project_uuid)
+ self.assertEqual(algo_project.algorithms[0].participant_id, pending_algo.participant_id)
+ self.assertEqual(algo_project.algorithms[0].name, pending_algo.name)
+ self.assertEqual(algo_project.algorithms[0].parameter, pending_algo.parameter)
+ self.assertEqual(algo_project.algorithms[0].uuid, pending_algo.algorithm_uuid)
+ self.assertEqual(sorted(os.listdir(algo_project.algorithms[0].path)), ['follower', 'leader'])
+
+
+class PendingAlgorithmTreeApiTest(BaseTestCase):
+
+ def test_get_tree(self):
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ pending_algo = PendingAlgorithm(name='test-algo', path=path)
+ session.add(pending_algo)
+ session.commit()
+ resp = self.get_helper(f'/api/v2/projects/0/pending_algorithms/{pending_algo.id}/tree')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ data = sorted(data, key=lambda d: d['filename'])
+ self.assertPartiallyEqual(data[1], {
+ 'filename': 'leader',
+ 'path': 'leader',
+ 'is_directory': True
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+ self.assertPartiallyEqual(data[1]['files'][0], {
+ 'filename': 'main.py',
+ 'path': 'leader/main.py',
+ 'is_directory': False
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+
+
+class PendingAlgorithmFilesApiTest(BaseTestCase):
+
+ def test_get_file(self):
+ path = generate_algorithm_files()
+ with db.session_scope() as session:
+ pending_algo = PendingAlgorithm(name='test-algo', path=path)
+ session.add(pending_algo)
+ session.commit()
+ resp = self.get_helper(f'/api/v2/projects/0/pending_algorithms/{pending_algo.id}/files?path=leader/main.py')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'content': 'import tensorflow', 'path': 'leader/main.py'})
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/fetcher.py b/web_console_v2/api/fedlearner_webconsole/algorithm/fetcher.py
new file mode 100644
index 000000000..1ccd68d90
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/fetcher.py
@@ -0,0 +1,67 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from envs import Envs
+from fedlearner_webconsole.algorithm.models import Algorithm, Source
+from fedlearner_webconsole.algorithm.utils import algorithm_cache_path
+from fedlearner_webconsole.algorithm.utils import check_algorithm_file
+from fedlearner_webconsole.algorithm.transmit.receiver import AlgorithmReceiver
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmPb
+from fedlearner_webconsole.rpc.v2.resource_service_client import ResourceServiceClient
+from fedlearner_webconsole.utils.file_manager import file_manager
+from fedlearner_webconsole.exceptions import NotFoundException
+
+
+class AlgorithmFetcher:
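+    """Fetches algorithms by uuid within a project.
+
+    Looks up the local database first and falls back to fetching from each
+    participant over gRPC, caching the received files under STORAGE_ROOT.
+    Illustrative usage (uuid is a placeholder):
+        algorithm = AlgorithmFetcher(project_id=1).get_algorithm('algo-uuid')
+    """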
+
+ def __init__(self, project_id: int):
+ self._project_id = project_id
+
+ def get_algorithm_from_participant(self, algorithm_uuid: str, participant_id: int) -> AlgorithmPb:
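+        """Fetches the algorithm and its files from the given participant over gRPC.
+
+        The files are extracted into the local cache directory so that repeated
+        fetches of the same uuid can reuse the cached copy.
+        """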
+ with db.session_scope() as session:
+ project = session.query(Project).get(self._project_id)
+ participant = session.query(Participant).get(participant_id)
+ client = ResourceServiceClient.from_project_and_participant(participant.domain_name, project.name)
+ algorithm = client.get_algorithm(algorithm_uuid=algorithm_uuid)
+ algo_cache_path = algorithm_cache_path(Envs.STORAGE_ROOT, algorithm_uuid)
+ if not file_manager.exists(algo_cache_path):
+ data_iterator = client.get_algorithm_files(algorithm_uuid=algorithm_uuid)
+            # The first response carries the file hash, which is used to verify the files once they are received
+ resp = next(data_iterator)
+ with check_algorithm_file(algo_cache_path):
+ AlgorithmReceiver().write_data_and_extract(data_iterator, algo_cache_path, resp.hash)
+ algorithm.path = algo_cache_path
+ algorithm.source = Source.PARTICIPANT.name
+ algorithm.participant_id = participant_id
+ return algorithm
+
+ def get_algorithm(self, uuid: str) -> AlgorithmPb:
+ """Raise NotFoundException when the algorithm is not found"""
+
+ with db.session_scope() as session:
+ algorithm = session.query(Algorithm).filter_by(uuid=uuid).first()
+ participants = session.query(Project).get(self._project_id).participants
+ if algorithm:
+ return algorithm.to_proto()
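+            # Fall back to asking each participant in turn, skipping participants whose gRPC call fails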
+ for participant in participants:
+ try:
+ return self.get_algorithm_from_participant(algorithm_uuid=uuid, participant_id=participant.id)
+ except grpc.RpcError:
+ continue
+        raise NotFoundException(f'algorithm with uuid {uuid} is not found')
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/fetcher_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/fetcher_test.py
new file mode 100644
index 000000000..873c033ea
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/fetcher_test.py
@@ -0,0 +1,101 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+from envs import Envs
+from google.protobuf.json_format import ParseDict
+from testing.common import NoWebServerTestCase
+from unittest.mock import patch
+from fedlearner_webconsole.algorithm.models import Algorithm, Source
+from fedlearner_webconsole.algorithm.transmit.sender import AlgorithmSender
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.algorithm.utils import algorithm_cache_path
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmVariable, AlgorithmParameter
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.proto import remove_secrets
+
+_TEST_ALGORITHM_PATH = os.path.join(Envs.BASE_DIR, 'testing/test_data/algorithm/e2e_test')
+
+
+class AlgorithmFetcherTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='part', domain_name='test')
+ project.participants = [participant]
+ algo1 = Algorithm(id=1, project_id=1, name='test-algo-1', uuid='algo-1', path=_TEST_ALGORITHM_PATH)
+ parameter1 = ParseDict({'variables': [{'name': 'BATCH_SIZE', 'value': '128'}]}, AlgorithmParameter())
+ algo1.set_parameter(parameter1)
+ algo2 = Algorithm(id=2, project_id=1, name='test-algo-2', uuid='algo-2')
+ parameter2 = ParseDict({'variables': [{'name': 'MAX_DEPTH', 'value': '5'}]}, AlgorithmParameter())
+ algo2.set_parameter(parameter2)
+ session.add_all([participant, project, algo1, algo2])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm')
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm_files')
+ def test_get_algorithm_from_participant(self, mock_get_algorithm_files, mock_get_algorithm):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(1)
+ mock_get_algorithm.return_value = remove_secrets(algo.to_proto())
+ mock_get_algorithm_files.return_value = AlgorithmSender().make_algorithm_iterator(algo.path)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ algorithm_uuid = 'uuid'
+ algorithm = AlgorithmFetcher(project_id=1).get_algorithm_from_participant(algorithm_uuid=algorithm_uuid,
+ participant_id=1)
+ algo_cache_path = algorithm_cache_path(Envs.STORAGE_ROOT, algorithm_uuid)
+ self.assertTrue(os.path.exists(algo_cache_path))
+ self.assertEqual(algo_cache_path, algorithm.path)
+ self.assertEqual(algorithm.source, Source.PARTICIPANT.name)
+ self.assertEqual(algorithm.id, 0)
+ self.assertEqual(algorithm.algorithm_project_id, 0)
+ self.assertEqual(algorithm.parameter,
+ AlgorithmParameter(variables=[AlgorithmVariable(name='BATCH_SIZE', value='128')]))
+ self.assertEqual(sorted(os.listdir(algo_cache_path)), ['follower', 'leader'])
+ with open(os.path.join(algo_cache_path, 'leader', 'main.py'), encoding='utf-8') as f:
+ self.assertEqual(f.read(), 'import tensorflow\n')
+ with open(os.path.join(algo_cache_path, 'follower', 'main.py'), encoding='utf-8') as f:
+ self.assertEqual(f.read(), '')
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm')
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm_files')
+ def test_get_algorithm(self, mock_get_algorithm_files, mock_get_algorithm):
+ with db.session_scope() as session:
+ algo1 = session.query(Algorithm).get(1)
+ algo2 = session.query(Algorithm).get(2)
+ mock_get_algorithm.return_value = remove_secrets(algo2.to_proto())
+ mock_get_algorithm_files.return_value = AlgorithmSender().make_algorithm_iterator(algo1.path)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ fetcher = AlgorithmFetcher(project_id=1)
+ algorithm1 = fetcher.get_algorithm('algo-1')
+ self.assertEqual(algorithm1.path, algo1.path)
+ self.assertEqual(algorithm1.parameter,
+ AlgorithmParameter(variables=[AlgorithmVariable(name='BATCH_SIZE', value='128')]))
+ algorithm2 = fetcher.get_algorithm('algo-3')
+ self.assertEqual(algorithm2.path, algorithm_cache_path(Envs.STORAGE_ROOT, 'algo-3'))
+ self.assertEqual(algorithm2.parameter,
+ AlgorithmParameter(variables=[AlgorithmVariable(name='MAX_DEPTH', value='5')]))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/models.py b/web_console_v2/api/fedlearner_webconsole/algorithm/models.py
new file mode 100644
index 000000000..a764e8f47
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/models.py
@@ -0,0 +1,354 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import enum
+from typing import Optional
+from sqlalchemy.sql.schema import Index, UniqueConstraint
+from google.protobuf import text_format
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmParameter, PendingAlgorithmPb, AlgorithmPb, \
+ AlgorithmProjectPb
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+from fedlearner_webconsole.utils.base_model.softdelete_model import SoftDeleteModel
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus, ReviewTicketModel
+
+
+def normalize_path(path: str) -> str:
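+    """Normalizes a filesystem path: hdfs:// paths are returned unchanged, while
+    file:// paths and plain paths are normalized with os.path.normpath."""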
+ if path.startswith('hdfs://'):
+ return path
+ if path.startswith('file://'):
+ _, pure_path = path.split('://')
+ return f'file://{os.path.normpath(pure_path)}'
+ return os.path.normpath(path)
+
+
+class AlgorithmType(enum.Enum):
+ UNSPECIFIED = 0
+ NN_LOCAL = 1
+ NN_HORIZONTAL = 2
+ NN_VERTICAL = 3
+ TREE_VERTICAL = 4
+ TRUSTED_COMPUTING = 5
+
+
+class Source(enum.Enum):
+ UNSPECIFIED = 0
+ PRESET = 1
+ USER = 2
+ THIRD_PARTY = 3 # deprecated
+ PARTICIPANT = 4 # algorithm from participant
+
+
+class ReleaseStatus(enum.Enum):
+ UNPUBLISHED = 0 # deprecated
+ PUBLISHED = 1 # deprecated
+ UNRELEASED = 'UNRELEASED'
+ RELEASED = 'RELEASED'
+
+
+class PublishStatus(enum.Enum):
+ UNPUBLISHED = 'UNPUBLISHED'
+ PUBLISHED = 'PUBLISHED'
+
+
+class AlgorithmStatus(enum.Enum):
+ UNPUBLISHED = 'UNPUBLISHED'
+ PENDING_APPROVAL = 'PENDING_APPROVAL'
+ APPROVED = 'APPROVED'
+ DECLINED = 'DECLINED'
+ PUBLISHED = 'PUBLISHED'
+
+
+# TODO(hangweiqiang): read https://docs.sqlalchemy.org/en/14/orm/inheritance.html and try to refactor
+class AlgorithmProject(db.Model, SoftDeleteModel):
+ __tablename__ = 'algorithm_projects_v2'
+ __table_args__ = (UniqueConstraint('name', 'source', 'project_id', name='uniq_name_source_project_id'),
+ UniqueConstraint('uuid', name='uniq_uuid'), default_table_args('algorithm_projects'))
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ uuid = db.Column(db.String(64), comment='uuid')
+ name = db.Column(db.String(255), comment='name')
+ project_id = db.Column(db.Integer, comment='project id')
+ latest_version = db.Column(db.Integer, default=0, comment='latest version')
+ type = db.Column('algorithm_type',
+ db.Enum(AlgorithmType, native_enum=False, length=32, create_constraint=False),
+ default=AlgorithmType.UNSPECIFIED,
+ key='type',
+ comment='algorithm type')
+ source = db.Column(db.Enum(Source, native_enum=False, length=32, create_constraint=False),
+ default=Source.UNSPECIFIED,
+ comment='algorithm source')
+    # The algorithm project's "publish" concept has been renamed to "release": a project becomes unreleased
+    # whenever its files or parameters are edited. For backward compatibility the value is still stored in the
+    # publish_status column, and _release_status on the model layer converts the legacy values when data is used.
+ _release_status = db.Column('publish_status',
+ db.Enum(ReleaseStatus, native_enum=False, length=32, create_constraint=False),
+ default=ReleaseStatus.UNRELEASED,
+ comment='release status')
+ publish_status = db.Column('publish_status_v2',
+ db.Enum(PublishStatus, native_enum=False, length=32, create_constraint=False),
+ server_default=PublishStatus.UNPUBLISHED.name,
+ comment='publish status')
+ username = db.Column(db.String(255), comment='creator name')
+ participant_id = db.Column(db.Integer, comment='participant id')
+ path = db.Column('fspath', db.String(512), key='path', comment='algorithm project path')
+ parameter = db.Column(db.Text(), comment='parameter')
+ comment = db.Column('cmt', db.String(255), key='comment', comment='comment')
+ created_at = db.Column(db.DateTime(timezone=True), default=now, comment='created time')
+ updated_at = db.Column(db.DateTime(timezone=True), default=now, onupdate=now, comment='updated time')
+ deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted time')
+ project = db.relationship(Project.__name__, primaryjoin='foreign(AlgorithmProject.project_id) == Project.id')
+ user = db.relationship(User.__name__, primaryjoin='foreign(AlgorithmProject.username) == User.username')
+ participant = db.relationship(Participant.__name__,
+ primaryjoin='foreign(AlgorithmProject.participant_id) == Participant.id')
+ algorithms = db.relationship(
+ 'Algorithm',
+ order_by='desc(Algorithm.version)',
+ primaryjoin='foreign(Algorithm.algorithm_project_id) == AlgorithmProject.id',
+ # To disable the warning of back_populates
+ overlaps='algorithm_project')
+
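+ # Reads go through this property so rows persisted with the legacy UNPUBLISHED/PUBLISHED values
+ # are exposed to callers as UNRELEASED/RELEASED.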
+ @property
+ def release_status(self) -> ReleaseStatus:
+ if self._release_status == ReleaseStatus.UNPUBLISHED:
+ return ReleaseStatus.UNRELEASED
+ if self._release_status == ReleaseStatus.PUBLISHED:
+ return ReleaseStatus.RELEASED
+ return self._release_status
+
+ @release_status.setter
+ def release_status(self, release_status: ReleaseStatus):
+ self._release_status = release_status
+
+ def set_parameter(self, parameter: Optional[AlgorithmParameter] = None):
+ if parameter is None:
+ parameter = AlgorithmParameter()
+ self.parameter = text_format.MessageToString(parameter)
+
+ def get_parameter(self) -> Optional[AlgorithmParameter]:
+ if self.parameter is not None:
+ return text_format.Parse(self.parameter, AlgorithmParameter())
+ return None
+
+ def is_path_accessible(self, path: str):
+ if self.path is None:
+ return False
+ return normalize_path(path).startswith(self.path)
+
+ def get_participant_name(self):
+ if self.participant is not None:
+ return self.participant.name
+ return None
+
+ def to_proto(self) -> AlgorithmProjectPb:
+ return AlgorithmProjectPb(
+ id=self.id,
+ uuid=self.uuid,
+ name=self.name,
+ project_id=self.project_id,
+ latest_version=self.latest_version,
+ type=self.type.name,
+ source=self.source.name,
+ publish_status=self.publish_status.name,
+ release_status=self.release_status.name,
+ username=self.username,
+ participant_id=self.participant_id,
+ participant_name=self.get_participant_name(),
+ path=self.path,
+ parameter=self.get_parameter(),
+ comment=self.comment,
+ created_at=to_timestamp(self.created_at) if self.created_at else None,
+ updated_at=to_timestamp(self.updated_at) if self.updated_at else None,
+ deleted_at=to_timestamp(self.deleted_at) if self.deleted_at else None,
+ algorithms=[algo.to_proto() for algo in self.algorithms],
+ )
+
+
+class Algorithm(db.Model, SoftDeleteModel, ReviewTicketModel):
+ __tablename__ = 'algorithms_v2'
+ __table_args__ = (Index('idx_name',
+ 'name'), UniqueConstraint('source', 'name', 'version', name='uniq_source_name_version'),
+ UniqueConstraint('uuid', name='uniq_uuid'), default_table_args('algorithms'))
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ uuid = db.Column(db.String(64), comment='uuid')
+ name = db.Column(db.String(255), comment='name')
+ project_id = db.Column(db.Integer, comment='project id')
+ version = db.Column(db.Integer, comment='version')
+ type = db.Column('algorithm_type',
+ db.Enum(AlgorithmType, native_enum=False, length=32, create_constraint=False),
+ default=AlgorithmType.UNSPECIFIED,
+ key='type',
+ comment='algorithm type')
+ source = db.Column(db.Enum(Source, native_enum=False, length=32, create_constraint=False),
+ default=Source.UNSPECIFIED,
+ comment='source')
+ publish_status = db.Column(db.Enum(PublishStatus, native_enum=False, length=32, create_constraint=False),
+ default=PublishStatus.UNPUBLISHED,
+ comment='publish status')
+ algorithm_project_id = db.Column(db.Integer, comment='algorithm project id')
+ username = db.Column(db.String(255), comment='creator name')
+ participant_id = db.Column(db.Integer, comment='participant id')
+ path = db.Column('fspath', db.String(512), key='path', comment='algorithm path')
+ parameter = db.Column(db.Text(), comment='parameter')
+ favorite = db.Column(db.Boolean, default=False, comment='favorite')
+ comment = db.Column('cmt', db.String(255), key='comment', comment='comment')
+ created_at = db.Column(db.DateTime(timezone=True), default=now, comment='created time')
+ updated_at = db.Column(db.DateTime(timezone=True), default=now, onupdate=now, comment='updated time')
+ deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted time')
+ project = db.relationship(Project.__name__, primaryjoin='foreign(Algorithm.project_id) == Project.id')
+ user = db.relationship(User.__name__, primaryjoin='foreign(Algorithm.username) == User.username')
+ participant = db.relationship(Participant.__name__,
+ primaryjoin='foreign(Algorithm.participant_id) == Participant.id')
+ algorithm_project = db.relationship(AlgorithmProject.__name__,
+ primaryjoin='foreign(Algorithm.algorithm_project_id) == AlgorithmProject.id')
+
+ def set_parameter(self, parameter: Optional[AlgorithmParameter] = None):
+ if parameter is None:
+ parameter = AlgorithmParameter()
+ self.parameter = text_format.MessageToString(parameter)
+
+ def get_parameter(self) -> Optional[AlgorithmParameter]:
+ if self.parameter is not None:
+ return text_format.Parse(self.parameter, AlgorithmParameter())
+ return None
+
+ def is_path_accessible(self, path: str):
+ if self.path is None:
+ return False
+ return normalize_path(path).startswith(self.path)
+
+ def get_participant_name(self):
+ if self.participant is not None:
+ return self.participant.name
+ return None
+
+ def get_status(self) -> AlgorithmStatus:
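+ """Derives the user-facing status: a published algorithm is always PUBLISHED; otherwise the review ticket status decides, and the absence of a ticket means UNPUBLISHED."""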
+ if self.publish_status == PublishStatus.PUBLISHED:
+ return AlgorithmStatus.PUBLISHED
+ if self.ticket_uuid is not None:
+ if self.ticket_status == TicketStatus.PENDING:
+ return AlgorithmStatus.PENDING_APPROVAL
+ if self.ticket_status == TicketStatus.APPROVED:
+ return AlgorithmStatus.APPROVED
+ if self.ticket_status == TicketStatus.DECLINED:
+ return AlgorithmStatus.DECLINED
+ return AlgorithmStatus.UNPUBLISHED
+
+ def get_algorithm_project_uuid(self) -> Optional[str]:
+ if self.algorithm_project:
+ return self.algorithm_project.uuid
+ return None
+
+ def to_proto(self) -> AlgorithmPb:
+ return AlgorithmPb(
+ id=self.id,
+ uuid=self.uuid,
+ name=self.name,
+ project_id=self.project_id,
+ version=self.version,
+ type=self.type.name,
+ source=self.source.name,
+ status=self.get_status().name,
+ algorithm_project_id=self.algorithm_project_id,
+ algorithm_project_uuid=self.get_algorithm_project_uuid(),
+ username=self.username,
+ participant_id=self.participant_id,
+ participant_name=self.get_participant_name(),
+ # TODO(gezhengqiang): delete participant name
+ path=self.path,
+ parameter=self.get_parameter(),
+ favorite=self.favorite,
+ comment=self.comment,
+ created_at=to_timestamp(self.created_at) if self.created_at else None,
+ updated_at=to_timestamp(self.updated_at) if self.updated_at else None,
+ deleted_at=to_timestamp(self.deleted_at) if self.deleted_at else None,
+ )
+
+
+class PendingAlgorithm(db.Model, SoftDeleteModel):
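+ """Algorithm received from a peer participant and not yet accepted into a local algorithm project (inferred from the participant_id and algorithm_project_uuid linkage)."""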
+ __tablename__ = 'pending_algorithms_v2'
+ __table_args__ = (default_table_args('pending_algorithms'))
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ algorithm_uuid = db.Column(db.String(64), comment='algorithm uuid')
+ algorithm_project_uuid = db.Column(db.String(64), comment='algorithm project uuid')
+ name = db.Column(db.String(255), comment='name')
+ project_id = db.Column(db.Integer, comment='project id')
+ version = db.Column(db.Integer, comment='version')
+ type = db.Column('algorithm_type',
+ db.Enum(AlgorithmType, native_enum=False, length=32, create_constraint=False),
+ default=AlgorithmType.UNSPECIFIED,
+ key='type',
+ comment='algorithm type')
+ participant_id = db.Column(db.Integer, comment='participant id')
+ path = db.Column('fspath', db.String(512), key='path', comment='algorithm path')
+ parameter = db.Column(db.Text(), comment='parameter')
+ comment = db.Column('cmt', db.String(255), key='comment', comment='comment')
+ created_at = db.Column(db.DateTime(timezone=True), default=now, comment='created time')
+ updated_at = db.Column(db.DateTime(timezone=True), default=now, onupdate=now, comment='updated time')
+ deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted time')
+ project = db.relationship(Project.__name__, primaryjoin='foreign(PendingAlgorithm.project_id) == Project.id')
+ participant = db.relationship(Participant.__name__,
+ primaryjoin='foreign(PendingAlgorithm.participant_id) == Participant.id')
+ algorithm_project = db.relationship(
+ AlgorithmProject.__name__,
+ primaryjoin='foreign(PendingAlgorithm.algorithm_project_uuid) == AlgorithmProject.uuid')
+
+ def set_parameter(self, parameter: Optional[AlgorithmParameter] = None):
+ if parameter is None:
+ parameter = AlgorithmParameter()
+ self.parameter = text_format.MessageToString(parameter)
+
+ def get_parameter(self) -> Optional[AlgorithmParameter]:
+ if self.parameter is not None:
+ return text_format.Parse(self.parameter, AlgorithmParameter())
+ return None
+
+ def is_path_accessible(self, path: str):
+ if self.path is None:
+ return False
+ return normalize_path(path).startswith(self.path)
+
+ def get_participant_name(self):
+ if self.participant:
+ return self.participant.name
+ return None
+
+ def get_algorithm_project_id(self) -> Optional[int]:
+ if self.algorithm_project:
+ return self.algorithm_project.id
+ return None
+
+ def to_proto(self) -> PendingAlgorithmPb:
+ return PendingAlgorithmPb(
+ id=self.id,
+ algorithm_uuid=self.algorithm_uuid,
+ algorithm_project_uuid=self.algorithm_project_uuid,
+ name=self.name,
+ project_id=self.project_id,
+ version=self.version,
+ type=self.type.name,
+ participant_id=self.participant_id,
+ participant_name=self.get_participant_name(),
+ path=self.path,
+ parameter=self.get_parameter(),
+ comment=self.comment,
+ created_at=to_timestamp(self.created_at) if self.created_at else None,
+ updated_at=to_timestamp(self.updated_at) if self.updated_at else None,
+ deleted_at=to_timestamp(self.deleted_at) if self.deleted_at else None,
+ algorithm_project_id=self.get_algorithm_project_id(),
+ )
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/models_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/models_test.py
new file mode 100644
index 000000000..19dc63f80
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/models_test.py
@@ -0,0 +1,169 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, AlgorithmType, Source,\
+ PendingAlgorithm, PublishStatus, AlgorithmStatus, normalize_path
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmParameter, AlgorithmVariable, AlgorithmPb,\
+ AlgorithmProjectPb
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+
+
+class AlgorithmTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(id=1, name='test-algo-project', uuid='test-algo-project-uuid')
+ algo = Algorithm(id=1,
+ algorithm_project_id=1,
+ name='test-algo',
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ path='/data',
+ created_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ updated_at=datetime(2022, 2, 22, tzinfo=timezone.utc))
+ algo.set_parameter(AlgorithmParameter(variables=[AlgorithmVariable(name='MAX_ITERS', value='5')]))
+ session.add_all([algo_project, algo])
+ session.commit()
+
+ def test_parameter(self):
+ algo = Algorithm(name='test-algo')
+ parameters = AlgorithmParameter(variables=[AlgorithmVariable(name='MAX_ITERS', value='5')])
+ algo.set_parameter(parameters)
+ self.assertEqual(algo.get_parameter(), parameters)
+
+ def test_to_proto(self):
+ parameters = AlgorithmParameter(variables=[
+ AlgorithmVariable(
+ name='MAX_ITERS', value='5', required=False, display_name='', comment='', value_type='STRING')
+ ])
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).filter_by(name='test-algo').first()
+ self.assertEqual(
+ algo.to_proto(),
+ AlgorithmPb(id=1,
+ name='test-algo',
+ type='NN_VERTICAL',
+ source='USER',
+ algorithm_project_id=1,
+ path='/data',
+ parameter=parameters,
+ status='UNPUBLISHED',
+ algorithm_project_uuid='test-algo-project-uuid',
+ updated_at=to_timestamp(datetime(2022, 2, 22, tzinfo=timezone.utc)),
+ created_at=to_timestamp(datetime(2022, 2, 22, tzinfo=timezone.utc))))
+
+ def test_normalize_path(self):
+ path1 = 'hdfs:///user/./local'
+ self.assertEqual(normalize_path(path1), 'hdfs:///user/./local')
+ path2 = 'file:///app/./local/../tools'
+ self.assertEqual(normalize_path(path2), 'file:///app/tools')
+ path3 = '/app/./local/../tools'
+ self.assertEqual(normalize_path(path3), '/app/tools')
+
+ def test_get_status(self):
+ algo1 = Algorithm(publish_status=PublishStatus.PUBLISHED, ticket_uuid=1)
+ self.assertEqual(algo1.get_status(), AlgorithmStatus.PUBLISHED)
+ algo2 = Algorithm(publish_status=PublishStatus.PUBLISHED, ticket_uuid=None)
+ self.assertEqual(algo2.get_status(), AlgorithmStatus.PUBLISHED)
+ algo3 = Algorithm(publish_status=PublishStatus.UNPUBLISHED, ticket_uuid=None)
+ self.assertEqual(algo3.get_status(), AlgorithmStatus.UNPUBLISHED)
+ algo4 = Algorithm(publish_status=PublishStatus.UNPUBLISHED,
+ ticket_uuid=None,
+ ticket_status=TicketStatus.PENDING)
+ self.assertEqual(algo4.get_status(), AlgorithmStatus.UNPUBLISHED)
+ algo5 = Algorithm(publish_status=PublishStatus.UNPUBLISHED, ticket_uuid=1, ticket_status=TicketStatus.PENDING)
+ self.assertEqual(algo5.get_status(), AlgorithmStatus.PENDING_APPROVAL)
+ algo6 = Algorithm(publish_status=PublishStatus.UNPUBLISHED, ticket_uuid=1, ticket_status=TicketStatus.DECLINED)
+ self.assertEqual(algo6.get_status(), AlgorithmStatus.DECLINED)
+ algo7 = Algorithm(publish_status=PublishStatus.UNPUBLISHED, ticket_uuid=1, ticket_status=TicketStatus.APPROVED)
+ self.assertEqual(algo7.get_status(), AlgorithmStatus.APPROVED)
+
+
+class AlgorithmProjectTest(NoWebServerTestCase):
+
+ def test_algorithms_reference(self):
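+ """Algorithms are attached to a project through algorithm_project_id and are returned newest version first."""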
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(name='test-algo')
+ session.add(algo_project)
+ session.flush()
+ algo1 = Algorithm(name='test-algo', version=1, algorithm_project_id=algo_project.id)
+ algo2 = Algorithm(name='test-algo', version=2, algorithm_project_id=algo_project.id)
+ algo3 = Algorithm(name='test-algo')
+ session.add_all([algo1, algo2, algo3])
+ session.commit()
+ with db.session_scope() as session:
+ algo_project: AlgorithmProject = session.query(AlgorithmProject).get(algo_project.id)
+ algorithms = algo_project.algorithms
+ self.assertEqual(len(algorithms), 2)
+ self.assertEqual(algorithms[0].name, 'test-algo')
+ self.assertEqual(algorithms[0].version, 2)
+ self.assertEqual(algorithms[1].name, 'test-algo')
+ self.assertEqual(algorithms[1].version, 1)
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(id=1,
+ name='test-algo-project',
+ type=AlgorithmType.TREE_VERTICAL,
+ path='/data',
+ created_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ updated_at=datetime(2022, 2, 22, tzinfo=timezone.utc))
+ algo_project.set_parameter(AlgorithmParameter(variables=[AlgorithmVariable(name='MAX_DEPTH', value='5')]))
+ session.add(algo_project)
+ session.commit()
+ result = algo_project.to_proto()
+ parameters = AlgorithmParameter(variables=[
+ AlgorithmVariable(
+ name='MAX_DEPTH', value='5', required=False, display_name='', comment='', value_type='STRING')
+ ])
+ self.assertEqual(
+ result,
+ AlgorithmProjectPb(id=1,
+ name='test-algo-project',
+ type='TREE_VERTICAL',
+ source='UNSPECIFIED',
+ publish_status='UNPUBLISHED',
+ path='/data',
+ parameter=parameters,
+ updated_at=to_timestamp(datetime(2022, 2, 22, tzinfo=timezone.utc)),
+ created_at=to_timestamp(datetime(2022, 2, 22, tzinfo=timezone.utc)),
+ release_status='UNRELEASED'))
+
+
+class PendingAlgorithmTest(NoWebServerTestCase):
+
+ def test_to_proto(self):
+ pending_algo = PendingAlgorithm(name='test-algo', type=AlgorithmType.TREE_VERTICAL, path='/data')
+ pending_algo.set_parameter(AlgorithmParameter(variables=[AlgorithmVariable(name='MAX_DEPTH', value='5')]))
+ with db.session_scope() as session:
+ session.add(pending_algo)
+ session.commit()
+ result = pending_algo.to_proto()
+ self.assertEqual(result.type, 'TREE_VERTICAL')
+ parameters = AlgorithmParameter(variables=[
+ AlgorithmVariable(
+ name='MAX_DEPTH', value='5', required=False, display_name='', comment='', value_type='STRING')
+ ])
+ self.assertEqual(result.parameter, parameters)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/BUILD.bazel
new file mode 100644
index 000000000..6380eb244
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/BUILD.bazel
@@ -0,0 +1,38 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+filegroup(
+ name = "preset_algorithms",
+ srcs = glob(["**/*"]),
+)
+
+py_library(
+ name = "preset_algorithm_service_lib",
+ srcs = ["preset_algorithm_service.py"],
+ data = [":preset_algorithms"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ ],
+)
+
+py_test(
+ name = "preset_algorithm_service_lib_test",
+ srcs = [
+ "preset_algorithm_service_test.py",
+ ],
+ imports = ["../../.."],
+ main = "preset_algorithm_service_test.py",
+ deps = [
+ ":preset_algorithm_service_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/follower/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/follower/config.py
new file mode 100644
index 000000000..b18f227ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/follower/config.py
@@ -0,0 +1,18 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/follower/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/follower/main.py
new file mode 100644
index 000000000..50ac5f5a9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/follower/main.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'follower'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=100, help='Training batch size.')
+args = parser.parse_args()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ return features, dict(y=tf.constant(0))
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ "example_id": tf.FixedLenFeature([], tf.string),
+ "raw_id": tf.FixedLenFeature([], tf.string),
+ }
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in follower_feature_names]
+ x = tf.concat(x, axis=-1)
+
+ w1f = tf.get_variable('w1l',
+ shape=[len(follower_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+ b1f = tf.get_variable('b1l', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)
+
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=act1_f)
+
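+ # Split-learning exchange: in TRAIN mode the follower sends its local activation and receives the
+ # matching gradient back from the leader; in EVAL mode the activation is sent without requiring a gradient.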
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ gact1_f = model.send('act1_f', act1_f, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ model.send('act1_f', act1_f, require_grad=False)
+
+ #acc = model.recv('acc', tf.float32, require_grad=False)
+ auc = model.recv('auc', tf.float32, require_grad=False)
+ loss = model.recv('loss', tf.float32, require_grad=False)
+ logging_hook = tf.train.LoggingTensorHook({
+ 'auc': auc,
+ 'loss': loss,
+ }, every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)
+
+ global_step = tf.train.get_or_create_global_step()
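+ # The follower has no labels, so a placeholder loss (the mean of its activation) is reported; the real
+ # optimization signal is the gradient received for act1_f from the leader.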
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ train_op = model.minimize(optimizer, act1_f, grad_loss=gact1_f, global_step=global_step)
+ return model.make_spec(mode,
+ loss=tf.math.reduce_mean(act1_f),
+ train_op=train_op,
+ training_hooks=[logging_hook, step_metric_hook])
+ if mode == tf.estimator.ModeKeys.EVAL:
+ fake_loss = tf.reduce_mean(act1_f)
+ return model.make_spec(mode=mode, loss=fake_loss, evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/leader/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/leader/config.py
new file mode 100644
index 000000000..b18f227ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/leader/config.py
@@ -0,0 +1,18 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/leader/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/leader/main.py
new file mode 100644
index 000000000..7bddf34b3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v1/leader/main.py
@@ -0,0 +1,123 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'leader'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=100, help='Training batch size.')
+args = parser.parse_args()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ label_map = {}
+ for name in leader_label_name:
+ label_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ labels = tf.parse_example(example, features=label_map)
+ return features, labels
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ "example_id": tf.FixedLenFeature([], tf.string),
+ "raw_id": tf.FixedLenFeature([], tf.string),
+ }
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in leader_feature_names]
+ x = tf.concat(x, axis=-1)
+
+ w1l = tf.get_variable('w1l',
+ shape=[len(leader_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+ b1l = tf.get_variable('b1l', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)
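+ # The leader holds the labels: it adds the follower's activation (received over the bridge) to its own
+ # logits, computes loss and AUC, and shares those metrics back with the follower.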
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=False)
+ else:
+ act1_f = features['act1_f']
+ logits = act1_l + act1_f
+ pred = tf.math.sigmoid(logits)
+
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=pred)
+
+ y = [tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]
+ y = tf.concat(y, axis=-1)
+
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)
+ _, auc = tf.metrics.auc(labels=y, predictions=pred)
+ #correct = tf.nn.in_top_k(predictions=logits, targets=y, k=1)
+ #acc = tf.reduce_mean(input_tensor=tf.cast(correct, tf.float32))
+ logging_hook = tf.train.LoggingTensorHook(
+ {
+ # 'acc': acc,
+ 'auc': auc,
+ 'loss': loss,
+ },
+ every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)
+ #model.send('acc', acc, require_grad=False)
+ model.send('auc', auc, require_grad=False)
+ model.send('loss', loss, require_grad=False)
+
+ global_step = tf.train.get_or_create_global_step()
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(1e-4)
+ train_op = model.minimize(optimizer, loss, global_step=global_step)
+ return model.make_spec(mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook, step_metric_hook])
+
+ if mode == tf.estimator.ModeKeys.EVAL:
+ loss_pair = tf.metrics.mean(loss)
+ return model.make_spec(mode=mode,
+ loss=loss,
+ eval_metric_ops={'loss': loss_pair},
+ evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/follower/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/follower/config.py
new file mode 100644
index 000000000..b18f227ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/follower/config.py
@@ -0,0 +1,18 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/follower/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/follower/main.py
new file mode 100644
index 000000000..2f3d8773f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/follower/main.py
@@ -0,0 +1,136 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'follower'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=10, help='Training batch size.')
+args = parser.parse_args()
+
+
+class ResultWriter:
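+ """Accumulates raw_id -> prediction pairs (updated eagerly through tf.py_function during the run) and writes them out as a raw_id,pred CSV per worker."""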
+
+ def __init__(self):
+ self.result = {}
+
+ def update_result(self, raw_id, pred):
+ raw_id = raw_id.numpy()
+ pred = pred.numpy()
+ for i in range(len(raw_id)):
+ self.result[raw_id[i]] = pred[i]
+
+ def write_result(self, filename):
+ raw_id = np.array(list(self.result.keys())).reshape(-1)
+ pred = np.array([self.result[i] for i in raw_id]).reshape(-1)
+ with tf.gfile.Open(filename, 'w') as f:
+ np.savetxt(f, np.dstack((raw_id, pred))[0], '%s,%f', header='raw_id,pred')
+ logging.info(f'[write_result]output result to {filename}')
+
+
+result_writer = ResultWriter()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.int64)
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ return features, dict(y=tf.constant(0))
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ 'example_id': tf.FixedLenFeature([], tf.string),
+ 'raw_id': tf.FixedLenFeature([], tf.int64),
+ }
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in follower_feature_names]
+ x = tf.concat(x, axis=-1)
+
+ w1f = tf.get_variable('w1l',
+ shape=[len(follower_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+ b1f = tf.get_variable('b1l', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)
+
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=act1_f)
+
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ gact1_f = model.send('act1_f', act1_f, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ model.send('act1_f', act1_f, require_grad=False)
+
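+ # The follower receives the leader's predictions and records them against raw_id before pulling the shared
+ # metrics, so each worker can export its own prediction CSV after training.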
+ pred = model.recv('pred', tf.float32, require_grad=False)
+ raw_id = features['raw_id']
+ update_result_op = tf.py_function(result_writer.update_result, [raw_id, pred], [], 'update')
+ with tf.control_dependencies([update_result_op]):
+ auc = model.recv('auc', tf.float32, require_grad=False)
+ loss = model.recv('loss', tf.float32, require_grad=False)
+
+ logging_hook = tf.train.LoggingTensorHook({
+ 'auc': auc,
+ 'loss': loss,
+ }, every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc, every_n_iter=1)
+
+ global_step = tf.train.get_or_create_global_step()
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ train_op = model.minimize(optimizer, act1_f, grad_loss=gact1_f, global_step=global_step)
+ return model.make_spec(mode,
+ loss=tf.math.reduce_mean(act1_f),
+ train_op=train_op,
+ training_hooks=[logging_hook, step_metric_hook])
+ if mode == tf.estimator.ModeKeys.EVAL:
+ fake_loss = tf.reduce_mean(act1_f)
+ return model.make_spec(mode=mode, loss=fake_loss, evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ outputs_path = os.path.join(os.environ['OUTPUT_BASE_DIR'], 'outputs')
+ tf.gfile.MakeDirs(outputs_path)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
+ if args.worker:
+ result_writer.write_result(os.path.join(outputs_path, f'worker-{str(args.worker_rank)}.csv'))
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/leader/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/leader/config.py
new file mode 100644
index 000000000..b73e677f3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/leader/config.py
@@ -0,0 +1,3 @@
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/leader/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/leader/main.py
new file mode 100644
index 000000000..3c3f0b62c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v2/leader/main.py
@@ -0,0 +1,132 @@
+import os
+import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'leader'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=10, help='Training batch size.')
+args = parser.parse_args()
+
+
+class ResultWriter:
+
+ def __init__(self):
+ self.result = {}
+
+ def update_result(self, raw_id, pred):
+ raw_id = raw_id.numpy()
+ pred = pred.numpy()
+ for i in range(len(raw_id)):
+ self.result[raw_id[i]] = pred[i]
+
+ def write_result(self, filename):
+ raw_id = np.array(list(self.result.keys())).reshape(-1)
+ pred = np.array([self.result[i] for i in raw_id]).reshape(-1)
+ with tf.gfile.Open(filename, 'w') as f:
+ np.savetxt(f, np.dstack((raw_id, pred))[0], '%s,%f', header='raw_id,pred')
+ logging.info(f'[write_result]output result to {filename}')
+
+
+result_writer = ResultWriter()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.int64)
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ label_map = {}
+ for name in leader_label_name:
+ label_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ labels = tf.parse_example(example, features=label_map)
+ return features, labels
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ 'example_id': tf.FixedLenFeature([], tf.string),
+ 'raw_id': tf.FixedLenFeature([], tf.int64),
+ }
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in leader_feature_names]
+ x = tf.concat(x, axis=-1)
+
+ w1l = tf.get_variable('w1l',
+ shape=[len(leader_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+ b1l = tf.get_variable('b1l', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=False)
+ else:
+ act1_f = features['act1_f']
+ logits = act1_l + act1_f
+ pred = tf.math.sigmoid(logits)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=pred)
+
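+ # Record predictions locally and also send them to the follower so both sides can export a raw_id,pred CSV.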
+ raw_id = features['raw_id']
+ update_result_op = tf.py_function(result_writer.update_result, [raw_id, pred], [], 'update')
+ model.send('pred', pred, require_grad=False)
+ y = [tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]
+ y = tf.concat(y, axis=-1)
+
+ with tf.control_dependencies([update_result_op]):
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)
+ _, auc = tf.metrics.auc(labels=y, predictions=pred)
+ logging_hook = tf.train.LoggingTensorHook({
+ 'auc': auc,
+ 'loss': loss,
+ }, every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc, every_n_iter=1)
+ model.send('auc', auc, require_grad=False)
+ model.send('loss', loss, require_grad=False)
+
+ global_step = tf.train.get_or_create_global_step()
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(1e-4)
+ train_op = model.minimize(optimizer, loss, global_step=global_step)
+ return model.make_spec(mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook, step_metric_hook])
+
+ if mode == tf.estimator.ModeKeys.EVAL:
+ loss_pair = tf.metrics.mean(loss)
+ return model.make_spec(mode=mode,
+ loss=loss,
+ eval_metric_ops={'loss': loss_pair},
+ evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ outputs_path = os.path.join(os.environ['OUTPUT_BASE_DIR'], 'outputs')
+ tf.gfile.MakeDirs(outputs_path)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
+ if args.worker:
+ result_writer.write_result(os.path.join(outputs_path, f'worker-{str(args.worker_rank)}.csv'))
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/follower/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/follower/config.py
new file mode 100644
index 000000000..b18f227ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/follower/config.py
@@ -0,0 +1,18 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/follower/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/follower/main.py
new file mode 100644
index 000000000..dce750326
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/follower/main.py
@@ -0,0 +1,136 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'follower'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=10, help='Training batch size.')
+args = parser.parse_args()
+
+
+class ResultWriter:
+
+ def __init__(self):
+ self.result = {}
+
+ def update_result(self, raw_id, pred):
+ raw_id = raw_id.numpy()
+ pred = pred.numpy()
+ for i in range(len(raw_id)):
+ self.result[raw_id[i]] = pred[i]
+
+ def write_result(self, filename):
+ raw_id = np.array(list(self.result.keys())).reshape(-1)
+ pred = np.array([self.result[i] for i in raw_id]).reshape(-1)
+ with tf.gfile.Open(filename, 'w') as f:
+ np.savetxt(f, np.dstack((raw_id, pred))[0], '%s,%f', header='raw_id,pred')
+ logging.info(f'[write_result]output result to {filename}')
+
+
+result_writer = ResultWriter()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ return features, dict(y=tf.constant(0))
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ 'example_id': tf.FixedLenFeature([], tf.string),
+ 'raw_id': tf.FixedLenFeature([], tf.string),
+ }
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in follower_feature_names]
+ x = tf.concat(x, axis=-1)
+
+ w1f = tf.get_variable('w1l',
+ shape=[len(follower_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+ b1f = tf.get_variable('b1l', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)
+
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=act1_f)
+
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ gact1_f = model.send('act1_f', act1_f, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ model.send('act1_f', act1_f, require_grad=False)
+
+ pred = model.recv('pred', tf.float32, require_grad=False)
+ raw_id = features['raw_id']
+ update_result_op = tf.py_function(result_writer.update_result, [raw_id, pred], [], 'update')
+ with tf.control_dependencies([update_result_op]):
+ auc = model.recv('auc', tf.float32, require_grad=False)
+ loss = model.recv('loss', tf.float32, require_grad=False)
+
+ logging_hook = tf.train.LoggingTensorHook({
+ 'auc': auc,
+ 'loss': loss,
+ }, every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc, every_n_iter=1)
+
+ global_step = tf.train.get_or_create_global_step()
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ train_op = model.minimize(optimizer, act1_f, grad_loss=gact1_f, global_step=global_step)
+ return model.make_spec(mode,
+ loss=tf.math.reduce_mean(act1_f),
+ train_op=train_op,
+ training_hooks=[logging_hook, step_metric_hook])
+ if mode == tf.estimator.ModeKeys.EVAL:
+ fake_loss = tf.reduce_mean(act1_f)
+ return model.make_spec(mode=mode, loss=fake_loss, evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ outputs_path = os.path.join(os.environ['OUTPUT_BASE_DIR'], 'outputs')
+ tf.gfile.MakeDirs(outputs_path)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
+ if args.worker:
+ result_writer.write_result(os.path.join(outputs_path, f'worker-{str(args.worker_rank)}.csv'))
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/leader/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/leader/config.py
new file mode 100644
index 000000000..b73e677f3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/leader/config.py
@@ -0,0 +1,3 @@
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/leader/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/leader/main.py
new file mode 100644
index 000000000..22e4e65ba
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v3/leader/main.py
@@ -0,0 +1,132 @@
+import os
+import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'leader'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=10, help='Training batch size.')
+args = parser.parse_args()
+
+
+class ResultWriter:
+
+ def __init__(self):
+ self.result = {}
+
+ def update_result(self, raw_id, pred):
+ raw_id = raw_id.numpy()
+ pred = pred.numpy()
+ for i in range(len(raw_id)):
+ self.result[raw_id[i]] = pred[i]
+
+ def write_result(self, filename):
+ raw_id = np.array(list(self.result.keys())).reshape(-1)
+ pred = np.array([self.result[i] for i in raw_id]).reshape(-1)
+ with tf.gfile.Open(filename, 'w') as f:
+ np.savetxt(f, np.dstack((raw_id, pred))[0], '%s,%f', header='raw_id,pred')
+ logging.info(f'[write_result]output result to {filename}')
+
+
+result_writer = ResultWriter()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ label_map = {}
+ for name in leader_label_name:
+ label_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ labels = tf.parse_example(example, features=label_map)
+ return features, labels
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ 'example_id': tf.FixedLenFeature([], tf.string),
+ 'raw_id': tf.FixedLenFeature([], tf.string),
+ }
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in leader_feature_names]
+ x = tf.concat(x, axis=-1)
+
+ w1l = tf.get_variable('w1l',
+ shape=[len(leader_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+ b1l = tf.get_variable('b1l', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=False)
+ else:
+ act1_f = features['act1_f']
+ logits = act1_l + act1_f
+ pred = tf.math.sigmoid(logits)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=pred)
+
+ raw_id = features['raw_id']
+ update_result_op = tf.py_function(result_writer.update_result, [raw_id, pred], [], 'update')
+ model.send('pred', pred, require_grad=False)
+ y = [tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]
+ y = tf.concat(y, axis=-1)
+
+ with tf.control_dependencies([update_result_op]):
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)
+ _, auc = tf.metrics.auc(labels=y, predictions=pred)
+ logging_hook = tf.train.LoggingTensorHook({
+ 'auc': auc,
+ 'loss': loss,
+ }, every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc, every_n_iter=1)
+ model.send('auc', auc, require_grad=False)
+ model.send('loss', loss, require_grad=False)
+
+ global_step = tf.train.get_or_create_global_step()
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(1e-4)
+ train_op = model.minimize(optimizer, loss, global_step=global_step)
+ return model.make_spec(mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook, step_metric_hook])
+
+ if mode == tf.estimator.ModeKeys.EVAL:
+ loss_pair = tf.metrics.mean(loss)
+ return model.make_spec(mode=mode,
+ loss=loss,
+ eval_metric_ops={'loss': loss_pair},
+ evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ outputs_path = os.path.join(os.environ['OUTPUT_BASE_DIR'], 'outputs')
+ tf.gfile.MakeDirs(outputs_path)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
+ if args.worker:
+ result_writer.write_result(os.path.join(outputs_path, f'worker-{str(args.worker_rank)}.csv'))
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/follower/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/follower/config.py
new file mode 100644
index 000000000..b18f227ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/follower/config.py
@@ -0,0 +1,18 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/follower/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/follower/main.py
new file mode 100644
index 000000000..0affd60c1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/follower/main.py
@@ -0,0 +1,136 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'follower'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=10, help='Training batch size.')
+args = parser.parse_args()
+
+
+class ResultWriter:
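+    """Collects per-example predictions keyed by raw_id and writes them out as a 'raw_id,pred' CSV."""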
+
+ def __init__(self):
+ self.result = {}
+
+ def update_result(self, raw_id, pred):
+ raw_id = raw_id.numpy()
+ pred = pred.numpy()
+ for i in range(len(raw_id)):
+ self.result[raw_id[i].decode('utf-8')] = pred[i]
+
+ def write_result(self, filename):
+ raw_id = np.array(list(self.result.keys())).reshape(-1)
+ pred = np.array([self.result[i] for i in raw_id]).reshape(-1)
+ with tf.gfile.Open(filename, 'w') as f:
+ np.savetxt(f, np.dstack((raw_id, pred))[0], '%s,%s', header='raw_id,pred')
+        logging.info(f'[write_result] output result to {filename}')
+
+
+result_writer = ResultWriter()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ return features, dict(y=tf.constant(0))
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ 'example_id': tf.FixedLenFeature([], tf.string),
+ 'raw_id': tf.FixedLenFeature([], tf.string),
+ }
+ for name in follower_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in follower_feature_names]
+ x = tf.concat(x, axis=-1)
+
+    w1f = tf.get_variable('w1f',
+ shape=[len(follower_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+    b1f = tf.get_variable('b1f', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)
+
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=act1_f)
+
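+    # Descriptive note: in TRAIN mode, model.send('act1_f', ...) with require_grad=True returns
+    # gact1_f, the gradient of the leader's loss w.r.t. act1_f, which is fed to model.minimize()
+    # below as grad_loss so the follower can update its local layer without access to the labels.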
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ gact1_f = model.send('act1_f', act1_f, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ model.send('act1_f', act1_f, require_grad=False)
+
+ pred = model.recv('pred', tf.float32, require_grad=False)
+ raw_id = features['raw_id']
+ update_result_op = tf.py_function(result_writer.update_result, [raw_id, pred], [], 'update')
+ with tf.control_dependencies([update_result_op]):
+ auc = model.recv('auc', tf.float32, require_grad=False)
+ loss = model.recv('loss', tf.float32, require_grad=False)
+
+ logging_hook = tf.train.LoggingTensorHook({
+ 'auc': auc,
+ 'loss': loss,
+ }, every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc, every_n_iter=1)
+
+ global_step = tf.train.get_or_create_global_step()
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ train_op = model.minimize(optimizer, act1_f, grad_loss=gact1_f, global_step=global_step)
+ return model.make_spec(mode,
+ loss=tf.math.reduce_mean(act1_f),
+ train_op=train_op,
+ training_hooks=[logging_hook, step_metric_hook])
+ if mode == tf.estimator.ModeKeys.EVAL:
+ fake_loss = tf.reduce_mean(act1_f)
+ return model.make_spec(mode=mode, loss=fake_loss, evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ outputs_path = os.path.join(os.environ['OUTPUT_BASE_DIR'], 'outputs')
+ tf.gfile.MakeDirs(outputs_path)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
+ if args.worker:
+ result_writer.write_result(os.path.join(outputs_path, f'worker-{str(args.worker_rank)}.csv'))
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/leader/config.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/leader/config.py
new file mode 100644
index 000000000..b73e677f3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/leader/config.py
@@ -0,0 +1,3 @@
+leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']
+leader_label_name = ['label']
+follower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/leader/main.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/leader/main.py
new file mode 100644
index 000000000..832f27b65
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/e2e_test_v4/leader/main.py
@@ -0,0 +1,132 @@
+import os
+import logging
+import numpy as np
+import tensorflow.compat.v1 as tf
+import fedlearner.trainer as flt
+from config import *
+from fedlearner.trainer.trainer_worker import StepLossAucMetricsHook
+
+ROLE = 'leader'
+
+parser = flt.trainer_worker.create_argument_parser()
+parser.add_argument('--batch-size', type=int, default=10, help='Training batch size.')
+args = parser.parse_args()
+
+
+class ResultWriter:
+
+ def __init__(self):
+ self.result = {}
+
+ def update_result(self, raw_id, pred):
+ raw_id = raw_id.numpy()
+ pred = pred.numpy()
+ for i in range(len(raw_id)):
+ self.result[raw_id[i].decode('utf-8')] = pred[i]
+
+ def write_result(self, filename):
+ raw_id = np.array(list(self.result.keys())).reshape(-1)
+ pred = np.array([self.result[i] for i in raw_id]).reshape(-1)
+ with tf.gfile.Open(filename, 'w') as f:
+ np.savetxt(f, np.dstack((raw_id, pred))[0], '%s,%s', header='raw_id,pred')
+        logging.info(f'[write_result] output result to {filename}')
+
+
+result_writer = ResultWriter()
+
+
+def input_fn(bridge, trainer_master):
+ dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge, trainer_master).make_dataset()
+
+ def parse_fn(example):
+ feature_map = dict()
+ feature_map['example_id'] = tf.FixedLenFeature([], tf.string)
+ feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ label_map = {}
+ for name in leader_label_name:
+ label_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ features = tf.parse_example(example, features=feature_map)
+ labels = tf.parse_example(example, features=label_map)
+ return features, labels
+
+ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ 'example_id': tf.FixedLenFeature([], tf.string),
+ 'raw_id': tf.FixedLenFeature([], tf.string),
+ }
+ for name in leader_feature_names:
+ feature_map[name] = tf.FixedLenFeature([], tf.float32, default_value=0.0)
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.parse_example(record_batch, features=feature_map)
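+    # Descriptive note: at serving time the follower's activation is not recomputed here; it is
+    # supplied directly through the 'act1_f' placeholder alongside the serialized examples.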
+ features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')
+ receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}
+ return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
+
+
+def model_fn(model, features, labels, mode):
+ logging.info('model_fn: mode %s', mode)
+ x = [tf.expand_dims(features[name], axis=-1) for name in leader_feature_names]
+ x = tf.concat(x, axis=-1)
+
+ w1l = tf.get_variable('w1l',
+ shape=[len(leader_feature_names), len(leader_label_name)],
+ dtype=tf.float32,
+ initializer=tf.random_uniform_initializer(-0.01, 0.01))
+ b1l = tf.get_variable('b1l', shape=[len(leader_label_name)], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+ act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)
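+    # Descriptive note: recv('act1_f', ..., require_grad=True) makes the gradient w.r.t. the
+    # follower's activation flow back to the follower when model.minimize() runs in TRAIN mode.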
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=True)
+ elif mode == tf.estimator.ModeKeys.EVAL:
+ act1_f = model.recv('act1_f', tf.float32, require_grad=False)
+ else:
+ act1_f = features['act1_f']
+ logits = act1_l + act1_f
+ pred = tf.math.sigmoid(logits)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return model.make_spec(mode=mode, predictions=pred)
+
+ raw_id = features['raw_id']
+ update_result_op = tf.py_function(result_writer.update_result, [raw_id, pred], [], 'update')
+ model.send('pred', pred, require_grad=False)
+ y = [tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]
+ y = tf.concat(y, axis=-1)
+
+ with tf.control_dependencies([update_result_op]):
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)
+ _, auc = tf.metrics.auc(labels=y, predictions=pred)
+ logging_hook = tf.train.LoggingTensorHook({
+ 'auc': auc,
+ 'loss': loss,
+ }, every_n_iter=10)
+ step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc, every_n_iter=1)
+ model.send('auc', auc, require_grad=False)
+ model.send('loss', loss, require_grad=False)
+
+ global_step = tf.train.get_or_create_global_step()
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(1e-4)
+ train_op = model.minimize(optimizer, loss, global_step=global_step)
+ return model.make_spec(mode=mode, loss=loss, train_op=train_op, training_hooks=[logging_hook, step_metric_hook])
+
+ if mode == tf.estimator.ModeKeys.EVAL:
+ loss_pair = tf.metrics.mean(loss)
+ return model.make_spec(mode=mode,
+ loss=loss,
+ eval_metric_ops={'loss': loss_pair},
+ evaluation_hooks=[logging_hook, step_metric_hook])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.INFO)
+ outputs_path = os.path.join(os.environ['OUTPUT_BASE_DIR'], 'outputs')
+ tf.gfile.MakeDirs(outputs_path)
+ flt.trainer_worker.train(ROLE, args, input_fn, model_fn, serving_input_receiver_fn)
+ if args.worker:
+ result_writer.write_result(os.path.join(outputs_path, f'worker-{str(args.worker_rank)}.csv'))
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/follower.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/follower.py
new file mode 100644
index 000000000..6497fb759
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/follower.py
@@ -0,0 +1,41 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import tensorflow as tf
+from model import run, get_dataset
+from fedlearner.fedavg import train_from_keras_model
+
+fl_name = 'follower'
+
+mode = os.getenv('MODE', 'train')
+epoch_num = int(os.getenv('EPOCH_NUM', 1))
+data_path = os.getenv('DATA_PATH')
+output_base_dir = os.getenv('OUTPUT_BASE_DIR')
+steps_per_sync = int(os.getenv('FL_STEPS_PER_SYNC', 10))
+LOAD_MODEL_FROM = os.getenv('LOAD_MODEL_FROM')
+
+if __name__ == '__main__':
+ print('-------------------------------')
+ print('mode : ', mode)
+ print('data_path : ', data_path)
+ print('output_base_dir : ', output_base_dir)
+ print('load model from : ', LOAD_MODEL_FROM)
+ print('-------------------------------')
+ logging.basicConfig(level=logging.INFO)
+ logging.info('mode: %s', mode)
+ ds = get_dataset(data_path)
+ run(fl_name, mode, ds, epoch_num, steps_per_sync)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/leader.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/leader.py
new file mode 100644
index 000000000..650a6bf23
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/leader.py
@@ -0,0 +1,41 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import tensorflow as tf
+from model import run, get_dataset
+from fedlearner.fedavg import train_from_keras_model
+
+fl_name = 'leader'
+
+mode = os.getenv('MODE', 'train')
+epoch_num = int(os.getenv('EPOCH_NUM', 1))
+data_path = os.getenv('DATA_PATH')
+output_base_dir = os.getenv('OUTPUT_BASE_DIR')
+steps_per_sync = int(os.getenv('FL_STEPS_PER_SYNC', 10))
+LOAD_MODEL_FROM = os.getenv('LOAD_MODEL_FROM')
+
+if __name__ == '__main__':
+ print('-------------------------------')
+ print('mode : ', mode)
+ print('data_path : ', data_path)
+ print('output_base_dir : ', output_base_dir)
+ print('load model from : ', LOAD_MODEL_FROM)
+ print('-------------------------------')
+ logging.basicConfig(level=logging.INFO)
+ logging.info('mode: %s', mode)
+ ds = get_dataset(data_path)
+ run(fl_name, mode, ds, epoch_num, steps_per_sync)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/metrics.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/metrics.py
new file mode 100644
index 000000000..e7ca4441b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/metrics.py
@@ -0,0 +1,62 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tensorflow as tf
+from fedlearner.common import metrics
+from fedlearner.fedavg.master import LeaderMaster, FollowerMaster
+from fedlearner.fedavg.cluster.cluster_spec import FLClusterSpec
+from fedlearner.fedavg._global_context import global_context as _gtx
+
+
+class MetricsKerasCallback(tf.keras.callbacks.Callback):
+
+ def __init__(self):
+ super().__init__()
+ self._global_step = None
+ self._metrics = {}
+
+ def on_train_end(self, logs=None):
+ self.emit_metrics()
+
+ def on_train_batch_end(self, batch, logs=None):
+ self.update_metrics(logs)
+
+ def on_test_end(self, logs=None):
+ self.emit_metrics()
+
+ def on_test_batch_end(self, batch, logs=None):
+ self.update_metrics(logs)
+
+ def update_metrics(self, logs: dict):
+ if 'batch' not in logs:
+ return
+
+ self._global_step = logs['batch']
+ self._metrics = logs
+ if self._global_step % 10 == 0:
+ self.emit_metrics()
+
+ def emit_metrics(self):
+ if self._global_step is None:
+ return
+ stats_pipe = _gtx.stats_client.pipeline()
+ stats_pipe.gauge('trainer.metric_global_step', self._global_step)
+ for key, value in self._metrics.items():
+ if key in ('size', 'batch'):
+ continue
+ stats_pipe.gauge('trainer.metric_value', value, tags={'metric': key})
+ metrics.emit_store(name=key, value=value)
+ stats_pipe.send()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/model.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/model.py
new file mode 100644
index 000000000..e7f8e28f4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/horizontal_e2e_test_v1/model.py
@@ -0,0 +1,150 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import time
+import tensorflow as tf
+import numpy as np
+import logging
+from datetime import datetime
+from fedlearner.fedavg import train_from_keras_model
+from metrics import MetricsKerasCallback
+
+LOAD_MODEL_FROM = os.getenv('LOAD_MODEL_FROM') # load from {STORAGE_ROOT}/job_outputs/{job.name}/checkpoints
+OUTPUT_BASE_DIR = os.getenv('OUTPUT_BASE_DIR') # save output under {OUTPUT_BASE_DIR}/output
+EXPORT_PATH = os.path.join(OUTPUT_BASE_DIR, 'exported_models') # save estimator to {EXPORT_PATH}
+CHECKPOINT_PATH = os.path.join(OUTPUT_BASE_DIR, 'checkpoints') # save keras model to {CHECKPOINT_PATH}
+
+fl_leader_address = os.getenv('FL_LEADER_ADDRESS', '0.0.0.0:6870')
+FL_CLUSTER = {'leader': {'name': 'leader', 'address': fl_leader_address}, 'followers': [{'name': 'follower'}]}
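+# Example (hypothetical address): FL_LEADER_ADDRESS=fl-leader.default.svc:6870 would point the
+# cluster at the leader's FedAvg endpoint; the 0.0.0.0:6870 default is only a bind-all placeholder.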
+
+
+def _label_to_int(label: str):
+ pred_fn_pairs = [(tf.equal(label, 'deer'), lambda: 0), (tf.equal(label, 'frog'), lambda: 1),
+ (tf.equal(label, 'horse'), lambda: 2), (tf.equal(label, 'dog'), lambda: 3),
+ (tf.equal(label, 'automobile'), lambda: 4), (tf.equal(label, 'airplane'), lambda: 5),
+ (tf.equal(label, 'ship'), lambda: 6), (tf.equal(label, 'cat'), lambda: 7),
+ (tf.equal(label, 'truck'), lambda: 8), (tf.equal(label, 'bird'), lambda: 9)]
+ return tf.case(pred_fn_pairs)
+
+
+def decode_and_resize(args):
+ x, h, w, c = args
+ x = tf.io.decode_raw(x, tf.uint8)
+ x = tf.reshape(x, [h, w, c])
+ x = tf.image.resize(x, (128, 128))
+ x = tf.cast(x, tf.float32)
+ x = tf.image.per_image_standardization(x)
+ x.set_shape([128, 128, 3])
+ return x
+
+
+def serving_input_receiver_fn():
+ feature_map = {
+ 'width': tf.io.FixedLenFeature([], tf.int64),
+ 'height': tf.io.FixedLenFeature([], tf.int64),
+ 'nChannels': tf.io.FixedLenFeature([], tf.int64),
+ 'data': tf.io.FixedLenFeature([], tf.string)
+ }
+ record_batch = tf.placeholder(dtype=tf.string, name='examples')
+ features = tf.io.parse_example(record_batch, features=feature_map)
+ features['data'] = tf.map_fn(decode_and_resize,
+ (features['data'], features['height'], features['width'], features['nChannels']),
+ dtype=tf.float32)
+ receiver_tensors = {'examples': record_batch}
+ return tf.estimator.export.ServingInputReceiver({'data': features['data']}, receiver_tensors)
+
+
+def parse_fn(record: bytes):
+ features = tf.io.parse_single_example(
+ record, {
+ 'width': tf.io.FixedLenFeature([], tf.int64),
+ 'height': tf.io.FixedLenFeature([], tf.int64),
+ 'nChannels': tf.io.FixedLenFeature([], tf.int64),
+ 'label': tf.io.FixedLenFeature([], tf.string),
+ 'data': tf.io.FixedLenFeature([], tf.string)
+ })
+ label = _label_to_int(features['label'])
+ img = tf.decode_raw(features['data'], out_type=tf.uint8)
+ img = tf.reshape(img, [features['height'], features['width'], features['nChannels']])
+ img = tf.image.resize(img, size=[128, 128])
+ img = tf.cast(img, tf.float32)
+ return img, label
+
+
+def create_model():
+ model = tf.keras.Sequential([
+ tf.keras.Input(shape=(128, 128, 3), name='data'),
+ tf.keras.layers.Conv2D(16, kernel_size=(3, 3), activation='relu'),
+ tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
+ tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
+ tf.keras.layers.GlobalMaxPool2D(),
+ tf.keras.layers.BatchNormalization(),
+ tf.keras.layers.Dense(64, activation='relu'),
+ tf.keras.layers.Dense(16, activation='relu'),
+ tf.keras.layers.Dense(10, activation='softmax', name='label'),
+ ])
+ model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(),
+ metrics=['acc'])
+ return model
+
+
+def get_dataset(data_path: str):
+ files = []
+ for dirname, subdirs, filenames in tf.io.gfile.walk(data_path):
+ for filename in filenames:
+ if filename.startswith('part'):
+ files.append(os.path.join(dirname, filename))
+ print('list filenames: ', files)
+ ds = tf.data.TFRecordDataset(files) \
+ .map(map_func=parse_fn) \
+ .shuffle(30000) \
+ .batch(30) \
+ .prefetch(10)
+ return ds
+
+
+def run(fl_name, mode, ds, epoch_num, steps_per_sync):
+ if mode == 'train':
+ model = create_model()
+ model.build([None, 128, 128, 3])
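+        # train_from_keras_model drives FedAvg training: local weights are averaged across the
+        # FL_CLUSTER members every `steps_per_sync` steps, for `epochs` passes over the dataset.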
+ train_from_keras_model(model,
+ x=ds,
+ y=None,
+ epochs=epoch_num,
+ fl_name=fl_name,
+ fl_cluster=FL_CLUSTER,
+ steps_per_sync=steps_per_sync)
+ estimator = tf.keras.estimator.model_to_estimator(model)
+        # fedlearner saves the Keras model itself, so sleep to give the model importer time to pick up the latest model
+ time.sleep(60)
+ export_path = estimator.export_saved_model(EXPORT_PATH, serving_input_receiver_fn=serving_input_receiver_fn)
+ logging.info(f'\nexport estimator to {export_path}\n')
+ checkpoint_path = os.path.join(CHECKPOINT_PATH, str(int(datetime.now().timestamp())))
+ model.save(checkpoint_path, save_format='tf')
+ logging.info(f'\nexport model to {CHECKPOINT_PATH}\n')
+ else:
+ latest_path = os.path.join(LOAD_MODEL_FROM, sorted(tf.io.gfile.listdir(LOAD_MODEL_FROM))[-1])
+ logging.info('load model from %s', latest_path)
+ model = tf.keras.models.load_model(latest_path)
+ if mode == 'eval':
+ model.evaluate(ds, callbacks=[MetricsKerasCallback()])
+ output = model.predict(ds)
+ output_path = os.path.join(OUTPUT_BASE_DIR, 'outputs')
+ tf.io.gfile.makedirs(output_path)
+ logging.info('write output to %s', output_path)
+ with tf.io.gfile.GFile(os.path.join(output_path, 'output.csv'), 'w') as fp:
+ np.savetxt(fp, output)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/preset_algorithm_service.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/preset_algorithm_service.py
new file mode 100644
index 000000000..b9bee62a5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/preset_algorithm_service.py
@@ -0,0 +1,132 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import copy
+from envs import Envs
+from pathlib import Path
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.algorithm.utils import algorithm_path, check_algorithm_file
+from fedlearner_webconsole.algorithm.models import Algorithm, Source, AlgorithmType, AlgorithmProject
+
+_ALGORITHMS_PATH = Path(__file__, '..').resolve()
+
+# To add a preset algorithm, insert the algorithm project into PRESET_ALGORITHM_PROJECT_LIST and the
+# algorithm into PRESET_ALGORITHM_LIST. The algorithm project and the algorithm must share the same
+# name. If the algorithm project already exists, update its latest_version to match the newest
+# algorithm. A purely illustrative example of such a pair is sketched below.
+
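+# Purely illustrative sketch (hypothetical name, uuid and path) of what a new preset entry pair
+# would look like:
+#     AlgorithmProject(name='my_preset', type=AlgorithmType.NN_VERTICAL, uuid='<new-20-char-uuid>',
+#                      source=Source.PRESET, comment='...', latest_version=1)
+#     Algorithm(name='my_preset', version=1, uuid='<new-20-char-uuid>', source=Source.PRESET,
+#               type=AlgorithmType.NN_VERTICAL, path=os.path.join(_ALGORITHMS_PATH, 'my_preset_v1'),
+#               comment='...')
+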
+PRESET_ALGORITHM_PROJECT_LIST = [
+ AlgorithmProject(name='e2e_test',
+ type=AlgorithmType.NN_VERTICAL,
+ uuid='u1b9eea3753e24fd9b91',
+ source=Source.PRESET,
+ comment='algorithm for end to end test',
+ latest_version=4),
+ AlgorithmProject(name='horizontal_e2e_test',
+ type=AlgorithmType.NN_HORIZONTAL,
+ uuid='u76630127d63c4ddb871',
+ source=Source.PRESET,
+ comment='algorithm for end to end test',
+ latest_version=1),
+ AlgorithmProject(name='secure_boost',
+ type=AlgorithmType.TREE_VERTICAL,
+ uuid='u7607b76db2c843fb9cd',
+ source=Source.PRESET,
+ comment='algorithm for secure boost',
+ latest_version=1)
+]
+
+PRESET_ALGORITHM_LIST = [
+ Algorithm(name='e2e_test',
+ version=1,
+ uuid='u5c4f510aab2f4a288c8',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'e2e_test_v1'),
+ comment='algorithm for end to end test'),
+ Algorithm(name='e2e_test',
+ version=2,
+ uuid='uc74ce6731906480c804',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'e2e_test_v2'),
+ comment='algorithm for end to end test'),
+ Algorithm(name='e2e_test',
+ version=3,
+ uuid='u322cd66836f04a13b94',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'e2e_test_v3'),
+ comment='algorithm for end to end test'),
+ Algorithm(name='e2e_test',
+ version=4,
+ uuid='uff7a19e8a1834d5e991',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'e2e_test_v4'),
+ comment='support save result when predict'),
+ Algorithm(name='horizontal_e2e_test',
+ version=1,
+ uuid='ub7b45bf127fc4aebad4',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_HORIZONTAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'horizontal_e2e_test_v1'),
+ comment='algorithm for horizontal nn end to end test'),
+ Algorithm(name='secure_boost',
+ version=1,
+ uuid='u936cb7254e4444caaf9',
+ source=Source.PRESET,
+ type=AlgorithmType.TREE_VERTICAL,
+ comment='algorithm for secure boost')
+]
+
+
+def create_algorithm_if_not_exists():
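+    # Idempotent bootstrap: preset projects and algorithms missing from the database are inserted,
+    # and preset algorithm files are copied under Envs.STORAGE_ROOT so that removed or broken paths
+    # are repaired on the next call.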
+ file_operator = FileOperator()
+ file_manager = FileManager()
+
+ for algo_project in PRESET_ALGORITHM_PROJECT_LIST:
+ with db.session_scope() as session:
+            algorithm_project = session.query(AlgorithmProject).filter_by(name=algo_project.name,
+                                                                          source=Source.PRESET).first()
+ if algorithm_project is None:
+ session.add(algo_project)
+ session.commit()
+
+ for preset_algo in PRESET_ALGORITHM_LIST:
+ algo = copy.deepcopy(preset_algo)
+ dest_algo_path = None
+ if preset_algo.path:
+ dest_algo_path = algorithm_path(Envs.STORAGE_ROOT, algo.name, algo.version)
+ file_manager.mkdir(dest_algo_path)
+ with check_algorithm_file(dest_algo_path):
+ file_operator.copy_to(preset_algo.path, dest_algo_path)
+ algo.path = dest_algo_path
+ with db.session_scope() as session:
+ algorithm = session.query(Algorithm).filter_by(name=algo.name, version=algo.version,
+ source=Source.PRESET).first()
+ # Only need to update the path when the algo has been added to the database
+ if preset_algo.path and algorithm:
+ if algorithm.path and file_manager.exists(algorithm.path):
+ file_manager.remove(algorithm.path)
+ algorithm.path = dest_algo_path
+ if algorithm is None:
+ algo_project = session.query(AlgorithmProject).filter_by(name=algo.name, source=Source.PRESET).first()
+ assert algo_project is not None, 'preset algorithm project is not found'
+ algo.algorithm_project_id = algo_project.id
+ session.add(algo)
+ session.commit()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/preset_algorithm_service_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/preset_algorithm_service_test.py
new file mode 100644
index 000000000..e6b83befc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/preset_algorithms/preset_algorithm_service_test.py
@@ -0,0 +1,114 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+import tempfile
+from envs import Envs
+from pathlib import Path
+from unittest.mock import patch, MagicMock
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.algorithm.preset_algorithms.preset_algorithm_service import create_algorithm_if_not_exists
+from fedlearner_webconsole.algorithm.models import Algorithm, Source, AlgorithmType, AlgorithmProject
+from fedlearner_webconsole.utils.file_manager import FileManager
+
+_ALGORITHMS_PATH = Path(__file__, '..').resolve()
+
+
+class PresetAlgorithmServiceTest(NoWebServerTestCase):
+
+ def test_create_all(self):
+ Envs.STORAGE_ROOT = tempfile.gettempdir()
+ create_algorithm_if_not_exists()
+ with db.session_scope() as session:
+ algorithm = session.query(Algorithm).filter_by(name='e2e_test').first()
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'leader', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'follower', 'main.py')))
+ algorithm = session.query(Algorithm).filter_by(name='horizontal_e2e_test').first()
+ self.assertEqual(sorted(os.listdir(algorithm.path)), ['follower.py', 'leader.py', 'metrics.py', 'model.py'])
+ algorithm = session.query(Algorithm).filter_by(name='secure_boost').first()
+ self.assertIsNone(algorithm.path)
+ algo_ids = session.query(Algorithm.id).filter_by(source=Source.PRESET).all()
+ self.assertEqual(len(algo_ids), 6)
+
+ @patch('fedlearner_webconsole.algorithm.preset_algorithms.preset_algorithm_service.PRESET_ALGORITHM_PROJECT_LIST',
+ new_callable=list)
+ @patch('fedlearner_webconsole.algorithm.preset_algorithms.preset_algorithm_service.PRESET_ALGORITHM_LIST',
+ new_callable=list)
+ def test_update_preset_algorithm(self, mock_preset_algorithm_list: MagicMock,
+ mock_preset_algorithm_project_list: MagicMock):
+ mock_preset_algorithm_project_list.extend([
+ AlgorithmProject(name='e2e_test',
+ type=AlgorithmType.NN_VERTICAL,
+ uuid='u1b9eea3753e24fd9b91',
+ source=Source.PRESET,
+ comment='algorithm for end to end test',
+ latest_version=4)
+ ])
+ mock_preset_algorithm_list.extend([
+ Algorithm(name='e2e_test',
+ version=1,
+ uuid='u5c4f510aab2f4a288c8',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'e2e_test_v1'),
+ comment='algorithm for end to end test'),
+ Algorithm(name='e2e_test',
+ version=2,
+ uuid='uc74ce6731906480c804',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'e2e_test_v2'),
+ comment='algorithm for end to end test')
+ ])
+ file_manager = FileManager()
+ Envs.STORAGE_ROOT = tempfile.gettempdir()
+ create_algorithm_if_not_exists()
+ algo = Algorithm(name='e2e_test',
+ version=3,
+ uuid='e2e_test_version_3',
+ source=Source.PRESET,
+ type=AlgorithmType.NN_VERTICAL,
+ path=os.path.join(_ALGORITHMS_PATH, 'e2e_test_v3'),
+ comment='algorithm for end to end test')
+ mock_preset_algorithm_list.append(algo)
+ create_algorithm_if_not_exists()
+ with db.session_scope() as session:
+ algorithm = session.query(Algorithm).filter_by(name='e2e_test', source=Source.PRESET, version=3).first()
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'leader', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'follower', 'main.py')))
+ # when algorithm path does not exist
+ file_manager.remove(algorithm.path)
+ self.assertFalse(file_manager.exists(algorithm.path))
+ session.commit()
+ create_algorithm_if_not_exists()
+ with db.session_scope() as session:
+ algorithm = session.query(Algorithm).filter_by(name='e2e_test', source=Source.PRESET, version=3).first()
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'leader', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'follower', 'main.py')))
+ # when algorithm path is empty
+ file_manager.remove(algorithm.path)
+ file_manager.mkdir(algorithm.path)
+ self.assertEqual(len(file_manager.ls(algorithm.path)), 0)
+ create_algorithm_if_not_exists()
+ with db.session_scope() as session:
+ algorithm = session.query(Algorithm).filter_by(name='e2e_test', source=Source.PRESET, version=3).first()
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'leader', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(algorithm.path, 'follower', 'main.py')))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/service.py b/web_console_v2/api/fedlearner_webconsole/algorithm/service.py
new file mode 100644
index 000000000..c808a2a8c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/service.py
@@ -0,0 +1,233 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tarfile
+import tempfile
+from io import FileIO
+from datetime import datetime
+from typing import Optional, List
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.filtering_pb2 import FilterOp, FilterExpression
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, Source, ReleaseStatus, PublishStatus, \
+ PendingAlgorithm, AlgorithmType
+
+# TODO(wangzeju): use singleton of file_manager or file_operator
+file_manager = FileManager()
+file_operator = FileOperator()
+
+
+class AlgorithmProjectService:
+
+ FILTER_FIELDS = {
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'type': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: None}),
+ }
+
+ def __init__(self, session: Session):
+ self._session = session
+ self._filter_builder = FilterBuilder(model_class=AlgorithmProject, supported_fields=self.FILTER_FIELDS)
+
+ @staticmethod
+ def _extract_to(file, path):
+ with tempfile.TemporaryDirectory() as directory:
+ with tarfile.open(fileobj=file) as tar:
+ tar.extractall(directory)
+ for root, _, files in os.walk(directory):
+ for name in files:
+                    # macOS uploads may include resource-fork files starting with '._'; drop them
+                    # together with compiled '.pyc' files before copying.
+ if name.startswith('._') or name.endswith('.pyc'):
+ os.remove(os.path.join(root, name))
+ file_operator.copy_to(directory, path)
+
+ def create_algorithm_project(self,
+ name: str,
+ project_id: int,
+ algorithm_type: AlgorithmType,
+ username: str,
+ parameter,
+ path: str,
+ comment: Optional[str] = None,
+ file: Optional[FileIO] = None) -> AlgorithmProject:
+ if file is not None:
+ self._extract_to(file, path)
+ algo_project = AlgorithmProject(name=name,
+ uuid=resource_uuid(),
+ project_id=project_id,
+ type=algorithm_type,
+ source=Source.USER,
+ username=username,
+ path=path,
+ comment=comment)
+ algo_project.set_parameter(parameter)
+ self._session.add(algo_project)
+ self._session.flush()
+ return algo_project
+
+ def release_algorithm(self,
+ algorithm_project: AlgorithmProject,
+ username: str,
+ path: str,
+ participant_id: Optional[int] = None,
+ comment: Optional[str] = None):
+        # Apply an exclusive lock on the algorithm project to avoid a race condition on the algorithm version
+ algo_project: AlgorithmProject = self._session.query(
+ AlgorithmProject).populate_existing().with_for_update().get(algorithm_project.id)
+ file_operator.copy_to(algorithm_project.path, path, create_dir=True)
+ algo = Algorithm(name=algorithm_project.name,
+ type=algorithm_project.type,
+ parameter=algorithm_project.parameter,
+ path=path,
+ source=Source.USER,
+ username=username,
+ participant_id=participant_id,
+ project_id=algorithm_project.project_id,
+ algorithm_project_id=algorithm_project.id,
+ comment=comment)
+ algo.uuid = resource_uuid()
+ algo.version = algo_project.latest_version + 1
+ algo_project.latest_version += 1
+ algo_project.release_status = ReleaseStatus.RELEASED
+ self._session.add(algo)
+ return algo
+
+ # TODO(linfan): implement delete file from file system
+ def delete(self, algorithm_project: AlgorithmProject):
+ algorithm_service = AlgorithmService(self._session)
+ for algo in algorithm_project.algorithms:
+ algorithm_service.delete(algo)
+ self._session.delete(algorithm_project)
+
+ def get_published_algorithm_projects(self, project_id: int,
+ filter_exp: Optional[FilterExpression]) -> List[AlgorithmProject]:
+ query = self._session.query(AlgorithmProject).filter_by(project_id=project_id,
+ publish_status=PublishStatus.PUBLISHED)
+ try:
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ return query.all()
+
+ def get_published_algorithms_latest_update_time(self, algorithm_project_id: int) -> datetime:
+ algo = self._session.query(Algorithm).filter_by(algorithm_project_id=algorithm_project_id,
+ publish_status=PublishStatus.PUBLISHED).order_by(
+ Algorithm.updated_at.desc()).limit(1).first()
+ return algo.updated_at
+
+
+class PendingAlgorithmService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def create_algorithm_project(self,
+ pending_algorithm: PendingAlgorithm,
+ username: str,
+ name: str,
+ comment: Optional[str] = None) -> AlgorithmProject:
+ algo_project = self._session.query(AlgorithmProject).filter(
+ AlgorithmProject.name == name, AlgorithmProject.source == Source.THIRD_PARTY).first()
+ if algo_project is not None:
+            raise ValueError(f'an algorithm project named {name} from a third party already exists')
+ algorithm_project = AlgorithmProject(name=name,
+ project_id=pending_algorithm.project_id,
+ latest_version=pending_algorithm.version,
+ type=pending_algorithm.type,
+ source=Source.THIRD_PARTY,
+ username=username,
+ participant_id=pending_algorithm.participant_id,
+ comment=comment,
+ uuid=pending_algorithm.algorithm_project_uuid,
+ release_status=ReleaseStatus.RELEASED)
+ algorithm_project.set_parameter(pending_algorithm.get_parameter())
+ self._session.add(algorithm_project)
+ return algorithm_project
+
+ def create_algorithm(self,
+ pending_algorithm: PendingAlgorithm,
+ algorithm_project_id: int,
+ username: str,
+ path: str,
+ comment: Optional[str] = None) -> Algorithm:
+ file_operator.copy_to(pending_algorithm.path, path, create_dir=True)
+ algo = Algorithm(name=pending_algorithm.name,
+ type=pending_algorithm.type,
+ parameter=pending_algorithm.parameter,
+ path=path,
+ source=Source.THIRD_PARTY,
+ username=username,
+ participant_id=pending_algorithm.participant_id,
+ project_id=pending_algorithm.project_id,
+ algorithm_project_id=algorithm_project_id,
+ uuid=pending_algorithm.algorithm_uuid,
+ version=pending_algorithm.version,
+ comment=comment)
+ self._session.add(algo)
+ return algo
+
+
+class AlgorithmService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def _update_algorithm_project_publish_status(self, algorithm_project_id: int):
+ algorithms = self._session.query(Algorithm).filter_by(algorithm_project_id=algorithm_project_id,
+ publish_status=PublishStatus.PUBLISHED).all()
+        # There may be a race condition here. Suppose only one Algorithm under the AlgorithmProject
+        # is currently published; if that algorithm is unpublished or deleted while another one is
+        # being published at the same time, the AlgorithmProject may end up UNPUBLISHED even though
+        # a published Algorithm still exists. Publishing or unpublishing an algorithm once more
+        # restores a consistent state.
+ if len(algorithms) == 0:
+ algo_project = self._session.query(AlgorithmProject).get(algorithm_project_id)
+ algo_project.publish_status = PublishStatus.UNPUBLISHED
+
+ def _update_algorithm_project_release_status(self, algorithm_project_id: int):
+ algorithms = self._session.query(Algorithm).filter_by(algorithm_project_id=algorithm_project_id).all()
+ # There may be a race condition here too.
+ if len(algorithms) == 0:
+ algo_project = self._session.query(AlgorithmProject).get(algorithm_project_id)
+ algo_project.release_status = ReleaseStatus.UNRELEASED
+
+ def delete(self, algorithm: Algorithm):
+ self._session.delete(algorithm)
+ algo_project = self._session.query(AlgorithmProject).get(algorithm.algorithm_project_id)
+ if algo_project.latest_version == algorithm.version:
+ algo_project.release_status = ReleaseStatus.UNRELEASED
+ self._update_algorithm_project_publish_status(algorithm_project_id=algorithm.algorithm_project_id)
+ self._update_algorithm_project_release_status(algorithm_project_id=algorithm.algorithm_project_id)
+
+ def publish_algorithm(self, algorithm_id: int, project_id: int) -> Algorithm:
+ algorithm = self._session.query(Algorithm).filter_by(id=algorithm_id, project_id=project_id).first()
+ algorithm.publish_status = PublishStatus.PUBLISHED
+ algo_project = self._session.query(AlgorithmProject).get(algorithm.algorithm_project_id)
+ algo_project.publish_status = PublishStatus.PUBLISHED
+ return algorithm
+
+ def unpublish_algorithm(self, algorithm_id: int, project_id: int) -> Algorithm:
+ algorithm = self._session.query(Algorithm).filter_by(id=algorithm_id, project_id=project_id).first()
+ algorithm.publish_status = PublishStatus.UNPUBLISHED
+ self._update_algorithm_project_publish_status(algorithm_project_id=algorithm.algorithm_project_id)
+ return algorithm
+
+ def get_published_algorithms(self, project_id: int, algorithm_project_id: int) -> List[Algorithm]:
+ return self._session.query(Algorithm).filter_by(project_id=project_id,
+ algorithm_project_id=algorithm_project_id,
+ publish_status=PublishStatus.PUBLISHED).all()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/service_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/service_test.py
new file mode 100644
index 000000000..39dec7a58
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/service_test.py
@@ -0,0 +1,128 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tarfile
+import tempfile
+import unittest
+
+from io import BytesIO
+from pathlib import Path
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, PublishStatus, ReleaseStatus
+from fedlearner_webconsole.algorithm.service import AlgorithmService, AlgorithmProjectService
+from fedlearner_webconsole.db import db
+
+
+class AlgorithmServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ algo_project1 = AlgorithmProject(id=1,
+ name='test-algo-project-1',
+ publish_status=PublishStatus.PUBLISHED,
+ release_status=ReleaseStatus.RELEASED)
+ algo_project2 = AlgorithmProject(id=2,
+ name='test-algo-project-2',
+ latest_version=3,
+ publish_status=PublishStatus.PUBLISHED,
+ release_status=ReleaseStatus.RELEASED)
+ algo1 = Algorithm(id=1, algorithm_project_id=1, name='test-algo-1', publish_status=PublishStatus.PUBLISHED)
+ algo2 = Algorithm(id=2,
+ algorithm_project_id=2,
+ name='test-algo-2',
+ version=1,
+ publish_status=PublishStatus.PUBLISHED)
+ algo3 = Algorithm(id=3,
+ algorithm_project_id=2,
+ name='test-algo-3',
+ version=2,
+ publish_status=PublishStatus.PUBLISHED)
+ algo4 = Algorithm(id=4,
+ algorithm_project_id=2,
+ name='test-algo-4',
+ version=3,
+ publish_status=PublishStatus.PUBLISHED)
+ session.add_all([algo_project1, algo_project2, algo1, algo2, algo3, algo4])
+ session.commit()
+
+ def test_delete_algorithm(self):
+ with db.session_scope() as session:
+ algo1 = session.query(Algorithm).filter_by(name='test-algo-1').first()
+ AlgorithmService(session).delete(algo1)
+ algo1 = session.query(Algorithm).filter_by(name='test-algo-1').execution_options(
+ include_deleted=True).first()
+ self.assertIsNone(algo1)
+ algo_project1 = session.query(AlgorithmProject).get(1)
+ self.assertEqual(algo_project1.release_status, ReleaseStatus.UNRELEASED)
+
+ def test_algorithm_project_status_when_delete_algorithms(self):
+ with db.session_scope() as session:
+ algo2 = session.query(Algorithm).filter_by(name='test-algo-2').first()
+ algo3 = session.query(Algorithm).filter_by(name='test-algo-3').first()
+ algo4 = session.query(Algorithm).filter_by(name='test-algo-4').first()
+ AlgorithmService(session).delete(algo4)
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project-2').first()
+ self.assertEqual(algo_project.publish_status, PublishStatus.PUBLISHED)
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+ algo_project.release_status = ReleaseStatus.RELEASED
+ AlgorithmService(session).delete(algo2)
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project-2').first()
+ self.assertEqual(algo_project.publish_status, PublishStatus.PUBLISHED)
+ self.assertEqual(algo_project.release_status, ReleaseStatus.RELEASED)
+ AlgorithmService(session).delete(algo3)
+ algo_project = session.query(AlgorithmProject).filter_by(name='test-algo-project-2').first()
+ self.assertEqual(algo_project.publish_status, PublishStatus.UNPUBLISHED)
+ self.assertEqual(algo_project.release_status, ReleaseStatus.UNRELEASED)
+
+
+class AlgorithmProjectServiceTest(NoWebServerTestCase):
+
+ def test_extract_files(self):
+ path = tempfile.mkdtemp()
+ path = Path(path, 'e2e_test').resolve()
+ path.mkdir()
+ path.joinpath('follower').mkdir()
+ path.joinpath('follower').joinpath('main.py').touch()
+ path.joinpath('follower').joinpath('._main.py').touch()
+ path.joinpath('follower').joinpath('main.pyc').touch()
+ path.joinpath('leader').mkdir()
+ path.joinpath('leader').joinpath('___main.py').touch()
+ file_path = path.joinpath('leader').joinpath('main.py')
+ file_path.touch()
+ file_path.write_text('import tensorflow', encoding='utf-8')
+ tar_path = os.path.join(tempfile.mkdtemp(), 'test.tar.gz')
+ with tarfile.open(tar_path, 'w:gz') as tar:
+ tar.add(os.path.join(path, 'leader'), arcname='leader')
+ tar.add(os.path.join(path, 'follower'), arcname='follower')
+ tar = tarfile.open(tar_path, 'r') # pylint: disable=consider-using-with
+ with tempfile.TemporaryDirectory() as directory:
+ with db.session_scope() as session:
+ # pylint: disable=protected-access
+ AlgorithmProjectService(session)._extract_to(BytesIO(tar.fileobj.read()), directory)
+ self.assertTrue(os.path.exists(os.path.join(directory, 'leader', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(directory, 'follower', 'main.py')))
+ self.assertTrue(os.path.exists(os.path.join(directory, 'leader', '___main.py')))
+ self.assertFalse(os.path.exists(os.path.join(directory, 'follower', '._main.py')))
+ self.assertFalse(os.path.exists(os.path.join(directory, 'follower', 'main.pyc')))
+ with open(os.path.join(directory, 'leader', 'main.py'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), 'import tensorflow')
+ with open(os.path.join(directory, 'follower', 'main.py'), encoding='utf-8') as fin:
+ self.assertEqual(fin.read(), '')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/BUILD.bazel
new file mode 100644
index 000000000..f40922097
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/BUILD.bazel
@@ -0,0 +1,69 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "transmit",
+ srcs = [
+ "hash.py",
+ "receiver.py",
+ "sender.py",
+ ],
+ imports = ["../../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ ],
+)
+
+py_test(
+ name = "hash_test",
+ srcs = [
+ "hash_test.py",
+ ],
+ imports = ["../../.."],
+ deps = [
+ ":transmit",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ ],
+)
+
+py_test(
+ name = "sender_test",
+ srcs = [
+ "sender_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data/algorithm",
+ ],
+ imports = ["../../.."],
+ deps = [
+ ":transmit",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "receiver_test",
+ srcs = [
+ "receiver_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data/algorithm",
+ ],
+ imports = ["../../.."],
+ deps = [
+ ":transmit",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/hash.py b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/hash.py
new file mode 100644
index 000000000..e93828474
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/hash.py
@@ -0,0 +1,24 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import hashlib
+
+from fedlearner_webconsole.utils.file_manager import FileManager
+
+
+def get_file_md5(file_manager: FileManager, file_name: str) -> str:
+ # TODO(gezhengqiang): solve memory overflow problem
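+    # A possible direction (sketch only, assumes a streaming read were added to FileManager):
+    #     md5 = hashlib.md5()
+    #     for chunk in file_manager.read_chunks(file_name, chunk_size=8 * 1024 * 1024):  # hypothetical API
+    #         md5.update(chunk)
+    #     return md5.hexdigest()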
+ data = file_manager.read(file_name).encode()
+ return hashlib.md5(data).hexdigest()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/hash_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/hash_test.py
new file mode 100644
index 000000000..8b4ba6636
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/hash_test.py
@@ -0,0 +1,38 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+
+from pathlib import Path
+
+from fedlearner_webconsole.algorithm.transmit.hash import get_file_md5
+from fedlearner_webconsole.utils.file_manager import FileManager
+
+
+class HashTest(unittest.TestCase):
+
+ def test_get_file_md5(self):
+ with tempfile.NamedTemporaryFile() as f:
+ Path(f.name).write_text('hello world', encoding='utf-8')
+ self.assertEqual(get_file_md5(FileManager(), f.name), '5eb63bbbe01eeed093cb22bb8f5acdc3')
+
+ def test_get_file_md5_empty_file(self):
+ with tempfile.NamedTemporaryFile() as f:
+ self.assertEqual(get_file_md5(FileManager(), f.name), 'd41d8cd98f00b204e9800998ecf8427e')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/receiver.py b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/receiver.py
new file mode 100644
index 000000000..4b8e91c8c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/receiver.py
@@ -0,0 +1,49 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+from typing import Iterator, Optional
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.algorithm.transmit.hash import get_file_md5
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmData
+
+file_operator = FileOperator()
+
+
+class AlgorithmReceiver(object):
+
+ def __init__(self):
+ self._file_manager = FileManager()
+
+ def write_data_and_extract(self,
+ data_iterator: Iterator[AlgorithmData],
+ dest: str,
+ expected_file_hash: Optional[str] = None):
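+        """Streams the received chunks into a temporary tar, optionally verifies its MD5 against
+        expected_file_hash, then extracts it into a staging directory and renames it to dest."""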
+ temp_dir = f'{dest}_temp'
+ self._file_manager.mkdir(temp_dir)
+ with tempfile.NamedTemporaryFile(suffix='.tar') as temp_file:
+ # TODO: limit the size of the received file
+ _written = False
+ for data in data_iterator:
+ self._file_manager.write(temp_file.name, data.chunk, mode='a')
+ _written = True
+ if _written:
+ if expected_file_hash is not None:
+ file_hash = get_file_md5(self._file_manager, temp_file.name)
+ if file_hash != expected_file_hash:
+                        raise ValueError('The received file is incomplete or corrupted')
+ file_operator.extract_to(temp_file.name, temp_dir, create_dir=True)
+ self._file_manager.rename(temp_dir, dest)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/receiver_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/receiver_test.py
new file mode 100644
index 000000000..32d13465e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/receiver_test.py
@@ -0,0 +1,91 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+from google.protobuf.json_format import ParseDict
+from envs import Envs
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmType, Source
+from fedlearner_webconsole.algorithm.transmit.receiver import AlgorithmReceiver
+from fedlearner_webconsole.algorithm.transmit.sender import AlgorithmSender
+from fedlearner_webconsole.algorithm.utils import algorithm_cache_path
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmParameter
+
+_TEST_ALGORITHM_PATH = os.path.join(Envs.BASE_DIR, 'testing/test_data/algorithm/e2e_test')
+
+
+class AlgorithmReceiverTest(NoWebServerTestCase):
+
+ def test_recv_algorithm_files(self):
+ parameter = ParseDict({'variables': [{'name': 'BATCH_SIZE', 'value': '128'}]}, AlgorithmParameter())
+ with db.session_scope() as session:
+ algo1 = Algorithm(id=1,
+ name='algo-1',
+ uuid='algo-uuid-1',
+ path=_TEST_ALGORITHM_PATH,
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ comment='comment',
+ version=1)
+ algo1.set_parameter(parameter)
+ session.commit()
+ data_iterator = AlgorithmSender().make_algorithm_iterator(algo1.path)
+ receiver = AlgorithmReceiver()
+ with tempfile.TemporaryDirectory() as temp_dir:
+ resp = next(data_iterator)
+ algo_cache_path = algorithm_cache_path(temp_dir, 'algo-uuid-2')
+ receiver.write_data_and_extract(data_iterator, algo_cache_path, resp.hash)
+ self.assertTrue(os.path.exists(algo_cache_path))
+ self.assertEqual(sorted(os.listdir(algo_cache_path)), ['follower', 'leader'])
+ with open(os.path.join(algo_cache_path, 'leader', 'main.py'), encoding='utf-8') as f:
+ self.assertEqual(f.read(), 'import tensorflow\n')
+ with open(os.path.join(algo_cache_path, 'follower', 'main.py'), encoding='utf-8') as f:
+ self.assertEqual(f.read(), '')
+
+    def test_write_data_and_extract_when_no_files(self):
+ with db.session_scope() as session:
+ algo = Algorithm(id=1,
+ name='algo-1',
+                             uuid='algo-uuid-1',
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ comment='comment',
+ version=1)
+ session.commit()
+ data_iterator = AlgorithmSender().make_algorithm_iterator(algo.path)
+ next(data_iterator)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ path = os.path.join(temp_dir, 'test')
+ AlgorithmReceiver().write_data_and_extract(data_iterator, path)
+ self.assertTrue(os.path.exists(path))
+ self.assertEqual(os.listdir(path), [])
+
+ def test_write_data_iterator_with_wrong_hash(self):
+ with db.session_scope() as session:
+ sender = AlgorithmSender()
+ data_iterator = sender.make_algorithm_iterator(_TEST_ALGORITHM_PATH)
+ # Consumes hash code response first
+ next(data_iterator)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ with self.assertRaises(ValueError):
+ AlgorithmReceiver().write_data_and_extract(data_iterator, temp_dir, 'wrong_hash')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/sender.py b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/sender.py
new file mode 100644
index 000000000..3fe762fab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/sender.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+from io import BytesIO
+from typing import Generator, Union
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.algorithm.transmit.hash import get_file_md5
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import GetAlgorithmFilesResponse
+
+_DEFAULT_CHUNK_SIZE = 1024 * 1024
+_FILE_MANAGER = FileManager()
+_FILE_OPERATOR = FileOperator()
+
+
+class AlgorithmSender(object):
+
+ def __init__(self, chunk_size: int = _DEFAULT_CHUNK_SIZE):
+ self.chunk_size = chunk_size
+
+ def _file_content_generator(self, file: BytesIO) -> Generator[bytes, None, None]:
+ while True:
+ chunk = file.read(self.chunk_size)
+ if len(chunk) == 0:
+ return
+ yield chunk
+
+ def _archive_algorithm_files_into(self, algo_path: Union[str, None], dest_tar: str):
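+        """Packs every entry under algo_path into dest_tar; no-op when the path is unset or the directory is empty."""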
+ if algo_path is None:
+ return
+ sources = [file.path for file in _FILE_MANAGER.ls(algo_path, include_directory=True)]
+ if len(sources) > 0:
+ _FILE_OPERATOR.archive_to(sources, dest_tar)
+
+ def make_algorithm_iterator(self, algo_path: str) -> Generator[GetAlgorithmFilesResponse, None, None]:
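+        """Archives the algorithm directory into a temporary tar, yields its MD5 hash first,
+        then yields the tar content in chunks of self.chunk_size bytes."""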
+ with tempfile.NamedTemporaryFile(suffix='.tar') as temp_file:
+ self._archive_algorithm_files_into(algo_path, temp_file.name)
+ file_hash = get_file_md5(_FILE_MANAGER, temp_file.name)
+ chunk_generator = self._file_content_generator(temp_file.file)
+ yield GetAlgorithmFilesResponse(hash=file_hash)
+ for chunk in chunk_generator:
+ yield GetAlgorithmFilesResponse(chunk=chunk)
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/sender_test.py b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/sender_test.py
new file mode 100644
index 000000000..678be8385
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/transmit/sender_test.py
@@ -0,0 +1,44 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from envs import Envs
+from fedlearner_webconsole.algorithm.transmit.sender import AlgorithmSender
+from testing.common import NoWebServerTestCase
+
+_TEST_ALGORITHM_PATH = os.path.join(Envs.BASE_DIR, 'testing/test_data/algorithm/e2e_test')
+
+
+class AlgorithmSenderTest(NoWebServerTestCase):
+
+ def test_make_algorithm_iterator(self):
+ sender = AlgorithmSender(chunk_size=1024)
+ data_iterator = sender.make_algorithm_iterator(_TEST_ALGORITHM_PATH)
+
+ hash_resp = next(data_iterator)
+        # The tar's hash changes on every run, so only verify the digest length
+ self.assertEqual(len(hash_resp.hash), 32)
+ # Tar archives have a minimum size of 10240 bytes by default
+ chunk_count = 0
+ for data in data_iterator:
+ chunk_count += 1
+ self.assertEqual(len(data.chunk), 1024)
+ self.assertEqual(chunk_count, 10)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/algorithm/utils.py b/web_console_v2/api/fedlearner_webconsole/algorithm/utils.py
new file mode 100644
index 000000000..19074628d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/algorithm/utils.py
@@ -0,0 +1,64 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from contextlib import contextmanager
+from uuid import uuid4
+from slugify import slugify
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.utils.file_manager import file_manager
+
+
+# TODO(hangweiqiang): move Envs.STORAGE_ROOT into this function
+def algorithm_path(root_path: str, name: str, version: int) -> str:
+ suffix = now().strftime('%Y%m%d_%H%M%S')
+ return os.path.join(root_path, 'algorithms', f'{slugify(name)}-v{version}-{suffix}-{uuid4().hex[:5]}')
+
+
+def algorithm_cache_path(root_path: str, algorithm_uuid: str) -> str:
+ return os.path.join(root_path, 'algorithm_cache', algorithm_uuid)
+
+
+def algorithm_project_path(root_path: str, name: str) -> str:
+ suffix = now().strftime('%Y%m%d_%H%M%S')
+ return os.path.join(root_path, 'algorithm_projects', f'{slugify(name)}-{suffix}-{uuid4().hex[:5]}')
+
+
+def pending_algorithm_path(root_path: str, name: str, version: int) -> str:
+ suffix = now().strftime('%Y%m%d_%H%M%S')
+ return os.path.join(root_path, 'pending_algorithms', f'{slugify(name)}-v{version}-{suffix}-{uuid4().hex[:5]}')
+
+
+def deleted_name(name: str) -> str:
+ timestamp = now().strftime('%Y%m%d_%H%M%S')
+ return f'deleted_at_{timestamp}_{name}'
+
+
+@contextmanager
+def check_algorithm_file(path: str):
+    """Clears the created algorithm files when an exception is raised.
+
+    Example:
+    path = (the path of the algorithm files)
+    with check_algorithm_file(path):
+        ...
+
+    """
+ try:
+ yield
+ except Exception as e:
+ if os.path.exists(path):
+ file_manager.remove(path)
+ raise e
diff --git a/web_console_v2/api/fedlearner_webconsole/app.py b/web_console_v2/api/fedlearner_webconsole/app.py
index 618b13b30..247227d6c 100644
--- a/web_console_v2/api/fedlearner_webconsole/app.py
+++ b/web_console_v2/api/fedlearner_webconsole/app.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,40 +16,66 @@
# pylint: disable=wrong-import-position, global-statement
import logging
import logging.config
-import os
-import traceback
from http import HTTPStatus
+from json import load
+from pathlib import Path
+
+from apispec.ext.marshmallow import MarshmallowPlugin
+from apispec_webframeworks.flask import FlaskPlugin
+from flasgger import APISpec, Swagger
from flask import Flask, jsonify
from flask_restful import Api
-from flask_jwt_extended import JWTManager
-from envs import Envs
-from fedlearner_webconsole.utils import metrics
-
-jwt = JWTManager()
+from marshmallow import ValidationError
+from sqlalchemy import inspect
+from sqlalchemy.orm import Session
+from webargs.flaskparser import parser
+from envs import Envs
+from fedlearner_webconsole.utils.hooks import pre_start_hook
+from fedlearner_webconsole.composer.apis import initialize_composer_apis
+from fedlearner_webconsole.cleanup.apis import initialize_cleanup_apis
+from fedlearner_webconsole.audit.apis import initialize_audit_apis
+from fedlearner_webconsole.auth.services import UserService
+from fedlearner_webconsole.e2e.apis import initialize_e2e_apis
+from fedlearner_webconsole.flag.apis import initialize_flags_apis
+from fedlearner_webconsole.iam.apis import initialize_iams_apis
+from fedlearner_webconsole.iam.client import create_iams_for_user
+from fedlearner_webconsole.middleware.middlewares import flask_middlewares
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.utils import metrics, const
from fedlearner_webconsole.auth.apis import initialize_auth_apis
from fedlearner_webconsole.project.apis import initialize_project_apis
-from fedlearner_webconsole.workflow_template.apis \
- import initialize_workflow_template_apis
+from fedlearner_webconsole.participant.apis import initialize_participant_apis
+from fedlearner_webconsole.utils.decorators.pp_flask import parser as custom_parser
+from fedlearner_webconsole.utils.swagger import normalize_schema
+from fedlearner_webconsole.workflow_template.apis import initialize_workflow_template_apis
from fedlearner_webconsole.workflow.apis import initialize_workflow_apis
from fedlearner_webconsole.dataset.apis import initialize_dataset_apis
from fedlearner_webconsole.job.apis import initialize_job_apis
from fedlearner_webconsole.setting.apis import initialize_setting_apis
-from fedlearner_webconsole.mmgr.apis import initialize_mmgr_apis
+from fedlearner_webconsole.mmgr.model_apis import initialize_mmgr_model_apis
+from fedlearner_webconsole.mmgr.model_job_apis import initialize_mmgr_model_job_apis
+from fedlearner_webconsole.mmgr.model_job_group_apis import initialize_mmgr_model_job_group_apis
+from fedlearner_webconsole.algorithm.apis import initialize_algorithm_apis
from fedlearner_webconsole.debug.apis import initialize_debug_apis
+from fedlearner_webconsole.serving.apis import initialize_serving_services_apis
from fedlearner_webconsole.sparkapp.apis import initialize_sparkapps_apis
-from fedlearner_webconsole.rpc.server import rpc_server
+from fedlearner_webconsole.file.apis import initialize_files_apis
+from fedlearner_webconsole.tee.apis import initialize_tee_apis
from fedlearner_webconsole.db import db
-from fedlearner_webconsole.exceptions import (make_response,
- WebConsoleApiException,
- InvalidArgumentException,
- NotFoundException)
+from fedlearner_webconsole.exceptions import make_response, WebConsoleApiException, InvalidArgumentException
from fedlearner_webconsole.scheduler.scheduler import scheduler
-from fedlearner_webconsole.utils.k8s_watcher import k8s_watcher
-from fedlearner_webconsole.auth.models import User, Session
-from fedlearner_webconsole.composer.composer import composer
-from logging_config import LOGGING_CONFIG
+from fedlearner_webconsole.k8s.k8s_watcher import k8s_watcher
+from logging_config import get_logging_config
+from werkzeug.exceptions import HTTPException
+
+
+@custom_parser.error_handler
+@parser.error_handler
+def handle_request_parsing_error(validation_error: ValidationError, *args, **kwargs):
+ raise InvalidArgumentException(details=validation_error.messages)
def _handle_bad_request(error):
@@ -63,17 +89,19 @@ def _handle_bad_request(error):
return error
-def _handle_not_found(error):
- """Handles the not found exception raised by framework"""
- if not isinstance(error, WebConsoleApiException):
- return make_response(NotFoundException())
- return error
+def _handle_wsgi_exception(error: HTTPException):
+ logging.exception('Wsgi exception: %s', str(error))
+ response = jsonify(
+ code=error.code,
+ msg=str(error),
+ )
+ response.status_code = error.code
+ return response
def _handle_uncaught_exception(error):
"""A fallback catcher for all exceptions."""
- logging.error('Uncaught exception %s, stack trace:\n %s', str(error),
- traceback.format_exc())
+ logging.exception('Uncaught exception %s', str(error))
response = jsonify(
code=500,
msg='Unknown error',
@@ -82,71 +110,82 @@ def _handle_uncaught_exception(error):
return response
-@jwt.unauthorized_loader
-def _handle_unauthorized_request(reason):
- response = jsonify(code=HTTPStatus.UNAUTHORIZED, msg=reason)
- return response, HTTPStatus.UNAUTHORIZED
-
-
-@jwt.invalid_token_loader
-def _handle_invalid_jwt_request(reason):
- response = jsonify(code=HTTPStatus.UNPROCESSABLE_ENTITY, msg=reason)
- return response, HTTPStatus.UNPROCESSABLE_ENTITY
-
-
-@jwt.expired_token_loader
-def _handle_token_expired_request(expired_token):
- response = jsonify(code=HTTPStatus.UNAUTHORIZED, msg='Token has expired')
- return response, HTTPStatus.UNAUTHORIZED
-
-
-@jwt.user_lookup_loader
-def user_lookup_callback(jwt_header, jwt_data):
- del jwt_header # Unused by user load.
-
- identity = jwt_data['sub']
- return User.query.filter_by(username=identity).one_or_none()
-
-
-@jwt.token_in_blocklist_loader
-def check_if_token_invalid(jwt_header, jwt_data):
- del jwt_header # unused by check_if_token_invalid
-
- jti = jwt_data['jti']
- session = Session.query.filter_by(jti=jti).first()
- return session is None
+def _initial_iams_for_users(session: Session):
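+    """Creates IAM policies for all existing users; failures are tolerated until the db migration completes."""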
+ inspector = inspect(db.engine)
+ if inspector.has_table('users_v2'):
+ try:
+ users = UserService(session).get_all_users()
+ for u in users:
+ create_iams_for_user(u)
+        except Exception as e:  # pylint: disable=broad-except
+            logging.warning(f'Initializing IAM for users failed: {e}, will be OK after db migration.')
+
+
+def _init_swagger(app: Flask):
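+    """Builds the OpenAPI spec from Flask views, marshmallow schemas and proto JSON schemas, then mounts the Swagger UI."""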
+ openapi_version = '3.0.3'
+ spec = APISpec(title='FedLearner WebConsole API Documentation',
+ version=SettingService.get_application_version().version.version,
+ openapi_version=openapi_version,
+ plugins=[FlaskPlugin(), MarshmallowPlugin()])
+ schemas = schema_manager.get_schemas()
+ template = spec.to_flasgger(app, definitions=schemas, paths=[*app.view_functions.values()])
+ app.config['SWAGGER'] = {'title': 'FedLearner WebConsole API Documentation', 'uiversion': 3}
+ for path in (Path(__file__).parent / 'proto' / 'jsonschemas').glob('**/*.json'):
+ with open(path, mode='r', encoding='utf-8') as file:
+ definitions = load(file)['definitions']
+ definitions = normalize_schema(definitions, Path(path))
+ template['components']['schemas'] = {**template['components']['schemas'], **definitions}
+ template['definitions'] = template['components']['schemas']
+ Swagger(app,
+ template=template,
+ config={
+ 'url_prefix': Envs.SWAGGER_URL_PREFIX,
+ 'openapi': openapi_version
+ },
+ merge=True)
def create_app(config):
+ pre_start_hook()
# format logging
- logging.config.dictConfig(LOGGING_CONFIG)
+ logging.config.dictConfig(get_logging_config())
- app = Flask('fedlearner_webconsole')
+ app = Flask('fedlearner_webconsole', root_path=Envs.BASE_DIR)
app.config.from_object(config)
- jwt.init_app(app)
-
# Error handlers
app.register_error_handler(400, _handle_bad_request)
- app.register_error_handler(404, _handle_not_found)
app.register_error_handler(WebConsoleApiException, make_response)
+ app.register_error_handler(HTTPException, _handle_wsgi_exception)
app.register_error_handler(Exception, _handle_uncaught_exception)
-
- # TODO(wangsen.0914): This will be removed sooner!
- db.init_app(app)
-
- api = Api(prefix='/api/v2')
+    # TODO(xiangyuxuan.prs): initialize IAM for all existing users; remove once in-memory IAM is no longer used
+ with db.session_scope() as session:
+ _initial_iams_for_users(session)
+ api = Api(prefix=const.API_VERSION)
+ initialize_composer_apis(api)
+ initialize_cleanup_apis(api)
initialize_auth_apis(api)
initialize_project_apis(api)
+ initialize_participant_apis(api)
initialize_workflow_template_apis(api)
initialize_workflow_apis(api)
initialize_job_apis(api)
initialize_dataset_apis(api)
initialize_setting_apis(api)
- initialize_mmgr_apis(api)
+ initialize_mmgr_model_apis(api)
+ initialize_mmgr_model_job_apis(api)
+ initialize_mmgr_model_job_group_apis(api)
+ initialize_algorithm_apis(api)
initialize_sparkapps_apis(api)
- if os.environ.get('FLASK_ENV') != 'production' or Envs.DEBUG:
+ initialize_files_apis(api)
+ initialize_flags_apis(api)
+ initialize_serving_services_apis(api)
+ initialize_iams_apis(api)
+ initialize_e2e_apis(api)
+ initialize_tee_apis(api)
+ if Envs.FLASK_ENV != 'production' or Envs.DEBUG:
initialize_debug_apis(api)
+ initialize_audit_apis(api)
# A hack that use our customized error handlers
# Ref: https://github.com/flask-restful/flask-restful/issues/280
handle_exception = app.handle_exception
@@ -154,21 +193,16 @@ def create_app(config):
api.init_app(app)
app.handle_exception = handle_exception
app.handle_user_exception = handle_user_exception
-
+ if Envs.FLASK_ENV != 'production' or Envs.DEBUG:
+ _init_swagger(app)
# Inits k8s related stuff first since something in composer
# may depend on it
- if Envs.FLASK_ENV == 'production' or Envs.K8S_CONFIG_PATH is not None:
+ if app.config.get('START_K8S_WATCHER', True):
k8s_watcher.start()
-
- if app.config.get('START_GRPC_SERVER', True):
- rpc_server.stop()
- rpc_server.start(app)
if app.config.get('START_SCHEDULER', True):
scheduler.stop()
- scheduler.start(app)
- if app.config.get('START_COMPOSER', True):
- with app.app_context():
- composer.run(db_engine=db.get_engine())
+ scheduler.start()
- metrics.emit_counter('create_app', 1)
+ metrics.emit_store('create_app', 1)
+ app = flask_middlewares.init_app(app)
return app
diff --git a/web_console_v2/api/fedlearner_webconsole/app_test.py b/web_console_v2/api/fedlearner_webconsole/app_test.py
new file mode 100644
index 000000000..5643c6239
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/app_test.py
@@ -0,0 +1,56 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from http import HTTPStatus
+
+from marshmallow import fields
+from webargs.flaskparser import use_args
+
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from testing.common import BaseTestCase
+
+
+class ExceptionHandlersTest(BaseTestCase):
+
+ def test_404(self):
+ self.assert404(self.get_helper('/api/v2/not_found', use_auth=False))
+
+ def test_405(self):
+ self.assert405(self.post_helper('/api/v2/versions', use_auth=False))
+
+ def test_uncaught_exception(self):
+
+ @self.app.route('/test_uncaught')
+ def test_route():
+ raise RuntimeError('Uncaught')
+
+ response = self.get_helper('/test_uncaught', use_auth=False)
+ self.assert500(response)
+
+ def test_marshmallow_validation_error(self):
+
+ @self.app.route('/test_validation')
+ @use_args({'must': fields.Bool(required=True)})
+ def test_route(params):
+ return make_flask_response({'succeeded': params['must']})
+
+ resp = self.get_helper('/test_validation', use_auth=False)
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(resp.get_json()['details'], {'json': {'must': ['Missing data for required field.']}})
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/audit/BUILD.bazel
new file mode 100644
index 000000000..88bc2a1de
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/BUILD.bazel
@@ -0,0 +1,159 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:mixins_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "storage_lib",
+ srcs = ["storage.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = ["services.py"],
+ imports = ["../.."],
+ deps = [
+ ":storage_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "services_lib_test",
+ size = "small",
+ srcs = [
+ "services_test.py",
+ ],
+ imports = ["../.."],
+ main = "services_test.py",
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "decorators_lib",
+ srcs = ["decorators.py"],
+ imports = ["../.."],
+ deps = [
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_flask//:pkg",
+ ],
+)
+
+py_test(
+ name = "decorators_lib_test",
+ testonly = True,
+ srcs = [
+ "decorators_test.py",
+ ],
+ imports = ["../.."],
+ main = "decorators_test.py",
+ deps = [
+ ":decorators_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_base64_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_python_dateutil//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@common_python_dateutil//:pkg",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/__init__.py b/web_console_v2/api/fedlearner_webconsole/audit/__init__.py
new file mode 100644
index 000000000..fc6e7fa2c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/apis.py b/web_console_v2/api/fedlearner_webconsole/audit/apis.py
new file mode 100644
index 000000000..d7c3597d6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/apis.py
@@ -0,0 +1,113 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+from http import HTTPStatus
+from typing import Optional
+
+from flask_restful import Api, Resource
+from marshmallow import fields, validate
+from webargs.flaskparser import use_kwargs
+from dateutil.relativedelta import relativedelta
+from fedlearner_webconsole.audit.models import EventType
+from fedlearner_webconsole.audit.services import EventService
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp, now
+from fedlearner_webconsole.utils.filtering import parse_expression
+
+
+class EventsApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_kwargs(
+ {
+ 'filter_exp': fields.String(validate=validate.Length(min=1), data_key='filter', load_default=None),
+ 'page': fields.Integer(load_default=1),
+ 'page_size': fields.Integer(load_default=10)
+ },
+ location='query')
+ def get(self, filter_exp: Optional[str], page: int, page_size: int):
+ """Get audit events
+ ---
+ tags:
+ - audit
+ description: get audit events
+ parameters:
+ - name: filter
+ in: query
+ schema:
+ type: string
+ - name: page
+ in: query
+ schema:
+ type: integer
+ - name: page_size
+ in: query
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: Events are returned
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Event'
+ """
+ with db.session_scope() as session:
+ if filter_exp is not None:
+ filter_exp = parse_expression(filter_exp)
+ query = EventService(session).get_events(filter_exp)
+ pagination = paginate(query, page, page_size)
+ data = [model.to_proto() for model in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+ @credentials_required
+ @admin_required
+ @use_kwargs({'event_type': fields.String(required=True, validate=validate.OneOf([a.name for a in EventType]))},
+ location='query')
+ def delete(self, event_type: str):
+ """Delete audit events that are older than 6 months
+ ---
+ tags:
+ - audit
+ parameters:
+ - name: event_type
+ in: query
+ schema:
+ type: string
+ responses:
+ 204:
+ description: Events are deleted successfully
+ """
+ end_time = to_timestamp(now() - relativedelta(months=6))
+ if EventType[event_type] == EventType.RPC:
+ filter_exp = parse_expression(f'(and(start_time>0)(end_time<{end_time})(source:["RPC"]))')
+ elif EventType[event_type] == EventType.USER_ENDPOINT: # delete API/UI events
+ filter_exp = parse_expression(
+ f'(and(start_time>0)(end_time<{end_time})(source:["UNKNOWN_SOURCE","UI","API"]))')
+ with db.session_scope() as session:
+ EventService(session).delete_events(filter_exp)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+def initialize_audit_apis(api: Api):
+ api.add_resource(EventsApi, '/events')
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/apis_test.py b/web_console_v2/api/fedlearner_webconsole/audit/apis_test.py
new file mode 100644
index 000000000..7445ac6e5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/apis_test.py
@@ -0,0 +1,158 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+import unittest
+from http import HTTPStatus
+from typing import Tuple
+from datetime import timedelta, timezone
+
+from dateutil.relativedelta import relativedelta
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.const import API_VERSION
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp, now
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.audit.models import EventModel, to_model
+from testing.common import BaseTestCase
+
+PATH_PREFIX = f'{API_VERSION}/events'
+
+
+def get_times() -> Tuple[int, int]:
+ ts = to_timestamp(now())
+ return ts - 60 * 2, ts + 60 * 2
+
+
+def generate_event() -> EventModel:
+ return to_model(
+ Event(name='some_event',
+ user_id=1,
+ resource_type=Event.ResourceType.IAM,
+ resource_name='some_resource',
+ op_type=Event.OperationType.CREATE,
+ result=Event.Result.SUCCESS,
+ result_code='OK',
+ source=Event.Source.UI))
+
+
+def generate_rpc_event() -> EventModel:
+ return to_model(
+ Event(name='some_rpc_event',
+ user_id=1,
+ resource_type=Event.ResourceType.WORKFLOW,
+ resource_name='workflow_uuid',
+ op_type=Event.OperationType.CREATE,
+ result=Event.Result.SUCCESS,
+ result_code='OK',
+ source=Event.Source.RPC))
+
+
+class EventApisTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ events = [generate_event() for _ in range(5)]
+ with db.session_scope() as session:
+ session.bulk_save_objects(events)
+ session.commit()
+
+ def test_get_events(self):
+ start_time, end_time = get_times()
+ self.signin_as_admin()
+
+ response = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertStatus(response, HTTPStatus.OK)
+ self.assertEqual(5, len(self.get_response_data(response)))
+ self.assertEqual('CREATE', self.get_response_data(response)[0].get('op_type'))
+
+ response = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="admin")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertEqual(0, len(self.get_response_data(response)))
+
+ start_time = to_timestamp(now(timezone.utc) + timedelta(hours=2))
+ end_time = to_timestamp(now(timezone.utc) + timedelta(hours=3))
+
+ response = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertEqual(0, len(self.get_response_data(response)))
+
+ response = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{end_time})(end_time<{start_time}))')
+ self.assertEqual(0, len(self.get_response_data(response)))
+
+ def test_delete_events(self):
+ rpc_events = [generate_rpc_event() for _ in range(3)]
+ created_at = now(timezone.utc) - relativedelta(months=8)
+ with db.session_scope() as session:
+ session.bulk_save_objects(rpc_events)
+ session.query(EventModel).update({'created_at': created_at})
+ session.commit()
+ start_time, end_time = get_times()
+
+ self.signin_as_admin()
+ response = self.delete_helper(f'{PATH_PREFIX}?event_type=USER_ENDPOINT')
+ self.assertStatus(response, HTTPStatus.NO_CONTENT)
+
+ start_time = to_timestamp(now(timezone.utc) - relativedelta(months=9))
+ end_time = to_timestamp(now(timezone.utc) - relativedelta(months=7))
+
+ response = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertEqual(3, len(self.get_response_data(response)))
+
+ def test_delete_rpc_events(self):
+ rpc_events = [generate_rpc_event() for _ in range(3)]
+ created_at = now(timezone.utc) - relativedelta(months=8)
+ with db.session_scope() as session:
+ session.bulk_save_objects(rpc_events)
+ session.query(EventModel).update({'created_at': created_at})
+ session.commit()
+
+ start_time, end_time = get_times()
+
+ self.signin_as_admin()
+ response = self.delete_helper(f'{PATH_PREFIX}?event_type=RPC')
+ self.assertStatus(response, HTTPStatus.NO_CONTENT)
+
+ start_time = to_timestamp(now(timezone.utc) - relativedelta(months=9))
+ end_time = to_timestamp(now(timezone.utc) - relativedelta(months=7))
+
+ response = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertEqual(5, len(self.get_response_data(response)))
+
+ def test_get_with_op_type(self):
+ start_time, end_time = get_times()
+
+ self.signin_as_admin()
+ resp = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{start_time})(end_time<{end_time})(op_type="CREATE"))' # pylint: disable=line-too-long
+ )
+ self.assertEqual(5, len(self.get_response_data(resp)))
+
+ resp = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{end_time})(end_time<{start_time})(op_type="UPDATE"))' # pylint: disable=line-too-long
+ )
+ self.assertEqual(0, len(self.get_response_data(resp)))
+
+ resp = self.get_helper(
+ f'{PATH_PREFIX}?filter=(and(username="ada")(start_time>{end_time})(end_time<{start_time})(op_type="poop"))')
+ self.assert200(resp)
+ self.assertEqual(0, len(self.get_response_data(resp)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/decorators.py b/web_console_v2/api/fedlearner_webconsole/audit/decorators.py
new file mode 100644
index 000000000..19698d6c2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/decorators.py
@@ -0,0 +1,263 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+import grpc
+import json
+import logging
+from functools import wraps
+from typing import Optional, Dict, Tuple, Callable
+from envs import Envs
+from flask import request
+from google.protobuf.message import Message
+from google.protobuf.empty_pb2 import Empty
+from fedlearner_webconsole.audit.services import EventService
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.const import API_VERSION
+from fedlearner_webconsole.utils.flask_utils import get_current_user
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.common_pb2 import StatusCode
+from fedlearner_webconsole.proto.service_pb2 import TwoPcRequest
+from fedlearner_webconsole.exceptions import UnauthorizedException, InvalidArgumentException
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.rpc.auth import get_common_name, PROJECT_NAME_HEADER, SSL_CLIENT_SUBJECT_DN_HEADER
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType
+# TODO(wangsen.0914): IAM and SYSTEM WIP
+
+RESOURCE_TYPE_MAPPING = {
+ 'projects': Event.ResourceType.WORKSPACE,
+ 'workflow_templates': Event.ResourceType.TEMPLATE,
+ 'workflows': Event.ResourceType.WORKFLOW,
+ 'datasets': Event.ResourceType.DATASET,
+ 'models': Event.ResourceType.MODEL,
+ 'auth': Event.ResourceType.USER,
+ 'participants': Event.ResourceType.PARTICIPANT,
+ 'serving_services': Event.ResourceType.SERVING_SERVICE,
+ 'algorithm_projects': Event.ResourceType.ALGORITHM_PROJECT,
+ 'preset_algorithms': Event.ResourceType.PRESET_ALGORITHM
+}
+
+OP_TYPE_MAPPING = {
+ 'post': Event.OperationType.CREATE,
+ 'patch': Event.OperationType.UPDATE,
+ 'put': Event.OperationType.UPDATE,
+ 'delete': Event.OperationType.DELETE
+}
+
+STATUS_TYPE_MAPPING = {
+ StatusCode.STATUS_SUCCESS: grpc.StatusCode.OK.name,
+ StatusCode.STATUS_UNKNOWN_ERROR: grpc.StatusCode.UNKNOWN.name,
+ StatusCode.STATUS_UNAUTHORIZED: grpc.StatusCode.UNAUTHENTICATED.name,
+ StatusCode.STATUS_NOT_FOUND: grpc.StatusCode.NOT_FOUND.name,
+ StatusCode.STATUS_INVALID_ARGUMENT: grpc.StatusCode.INVALID_ARGUMENT.name
+}
+
+RESULT_TYPE_MAPPING = {
+ Event.Result.UNKNOWN_RESULT: grpc.StatusCode.UNKNOWN.name,
+ Event.Result.SUCCESS: grpc.StatusCode.OK.name,
+ Event.Result.FAILURE: grpc.StatusCode.ABORTED.name
+}
+
+
+def emits_event(resource_type: Event.ResourceType = Event.ResourceType.UNKNOWN_RESOURCE_TYPE,
+ op_type: Event.OperationType = Event.OperationType.UNKNOWN_OPERATION_TYPE,
+ audit_fields: Optional[list] = None):
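+    """Decorator that records an audit event for the wrapped REST handler; it skips logging when
+    no user is signed in and marks the event as FAILURE when the handler raises."""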
+
+ def wrapper_func(func):
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ user = get_current_user()
+ if user is None:
+ return func(*args, **kwargs)
+ fields, result = _infer_event_fields(resource_type, op_type, audit_fields), Event.Result.SUCCESS
+ try:
+ data, *_ = func(*args, **kwargs)
+ if fields['op_type'] == Event.OperationType.CREATE:
+ fields['resource_name'] += f'/{data.get("data").get("id")}'
+ return (data, *_)
+ except Exception as e:
+ result = Event.Result.FAILURE
+ raise e
+ finally:
+                # TODO(yeqiuhan): deprecate result after two release cuts
+ _emit_event(user_id=user.id, result=result, result_code=RESULT_TYPE_MAPPING[result], fields=fields)
+
+ return wrapper
+
+ return wrapper_func
+
+
+# TODO(yeqiuhan): Call local server for operation
+
+
+def emits_rpc_event(
+ resource_name_fn: Callable[[Message], str],
+ resource_type: Event.ResourceType = Event.ResourceType.UNKNOWN_RESOURCE_TYPE,
+ op_type: Event.OperationType = Event.OperationType.UNKNOWN_OPERATION_TYPE,
+):
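+    """Decorator that records an audit event for the wrapped gRPC servicer method, deriving the
+    result code from the servicer context state or the response status."""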
+
+ def wrapper_func(func):
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ func_params = _get_func_params(func, *args, **kwargs)
+ fields = _infer_rpc_event_fields(func_params, resource_type, resource_name_fn, op_type)
+            response = None
+            try:
+ response = func(*args, **kwargs)
+ return response
+ except Exception as e:
+ raise e
+ finally:
+                # Accesses the private _state until grpc is upgraded from 1.32 to 1.38+, which exposes a public API for the status code
+ if func_params['context']._state.code is not None: # pylint: disable=protected-access
+ result_code = func_params['context']._state.code.name # pylint: disable=protected-access
+                elif response is None:
+                    result_code = 'UNKNOWN'
+                else:
+                    if isinstance(response, Empty) or response.DESCRIPTOR.fields_by_name.get('status') is None:
+ result_code = 'OK'
+ elif isinstance(response.status.code, int):
+ result_code = STATUS_TYPE_MAPPING[response.status.code]
+ else:
+ result_code = 'UNKNOWN'
+ _emit_event(user_id=None,
+ result=Event.Result.SUCCESS if result_code == 'OK' else Event.Result.FAILURE,
+ result_code=result_code,
+ fields=fields)
+
+ return wrapper
+
+ return wrapper_func
+
+
+def _infer_event_fields(resource_type: Event.ResourceType = Event.ResourceType.UNKNOWN_RESOURCE_TYPE,
+ op_type: Event.OperationType = Event.OperationType.UNKNOWN_OPERATION_TYPE,
+ audit_fields: Optional[list] = None) -> Dict[str, any]:
+ # path: API_PATH_PREFIX/resource_type/...
+ if resource_type == Event.ResourceType.UNKNOWN_RESOURCE_TYPE:
+ resource_type = RESOURCE_TYPE_MAPPING.get(request.path.partition(API_VERSION)[-1].split('/')[1].lower())
+ if op_type == Event.OperationType.UNKNOWN_OPERATION_TYPE:
+ op_type = OP_TYPE_MAPPING.get(request.method.lower())
+ body = request.get_json(force=True, silent=True)
+ resource_name = request.path.rpartition(API_VERSION)[-1]
+ extra = {k: body.get(k) for k in audit_fields} if audit_fields else {}
+ coordinator_pure_domain_name = SettingService.get_system_info().pure_domain_name
+ return {
+ 'name': Event.OperationType.Name(op_type).lower() + Event.ResourceType.Name(resource_type).capitalize(),
+ 'resource_type': resource_type,
+ 'resource_name': resource_name,
+ 'op_type': op_type,
+ # TODO(wangsen.0914): source depends on credentials
+ 'source': Event.Source.UI,
+ 'extra': json.dumps(extra),
+ 'coordinator_pure_domain_name': coordinator_pure_domain_name
+ }
+
+
+def _get_func_params(func, *args, **kwargs):
+ dict_param = {}
+ for arg in list(kwargs.values()) + list(args):
+ if isinstance(arg, grpc.ServicerContext):
+ dict_param['context'] = arg
+ if isinstance(arg, Message):
+ dict_param['request'] = arg
+ return dict_param
+
+
+def _infer_auth_info(rpc_request, context) -> Tuple[Optional[str], Optional[int]]:
+ if Envs.FLASK_ENV == 'production':
+ metadata = dict(context.invocation_metadata())
+ if not metadata:
+ raise UnauthorizedException('No client subject dn found')
+ cn = get_common_name(metadata.get(SSL_CLIENT_SUBJECT_DN_HEADER))
+ if not cn:
+ raise UnauthorizedException('Failed to get domain name from certs')
+ pure_domain_name = get_pure_domain_name(cn)
+ with db.session_scope() as session:
+ if 'auth_info' in rpc_request.keys(): # v1
+ project_name = rpc_request['auth_info']['project_name']
+ else: # v2
+ project_name = metadata.get(PROJECT_NAME_HEADER)
+ project = session.query(Project).filter_by(name=project_name).first()
+ project_id = project.id if project is not None else None
+ return pure_domain_name, project_id
+ return (None, None)
+
+
+def _infer_rpc_event_fields(func_params: Dict[str, any], resource_type: Event.ResourceType,
+ resource_name_fn: Callable[[Message], str], op_type: Event.OperationType) -> Dict[str, any]:
+ request_type = type(func_params['request'])
+ if resource_name_fn is None:
+ raise InvalidArgumentException('Callable resource_name_fn required')
+ resource_uuid = resource_name_fn(func_params['request'])
+ rpc_request = to_dict(func_params['request'])
+ context = func_params['context']
+ if request_type is TwoPcRequest:
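+        # Derives op/resource from the TwoPcType name, e.g. CREATE_MODEL_JOB -> CREATE + MODEL_JOB;
+        # *_STATE types fold the trailing STATE into the operation type.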
+ type_list = rpc_request['type'].split('_')
+ if type_list[-1] == 'STATE':
+ op_type = Event.OperationType.Value(type_list[0] + '_' + type_list[-1])
+ resource_type = Event.ResourceType.Value('_'.join(type_list[1:-1]))
+ else:
+ op_type = Event.OperationType.Value(type_list[0])
+ resource_type = Event.ResourceType.Value('_'.join(type_list[1:]))
+ # get domain_name and project_name
+ pure_domain_name, project_id = _infer_auth_info(rpc_request, context)
+ return {
+ 'name': str(request_type)[str(request_type).rfind('.') + 1:str(request_type).rfind('Request')],
+ 'op_type': op_type,
+ 'resource_type': resource_type,
+ 'resource_name': resource_uuid,
+ 'coordinator_pure_domain_name': pure_domain_name,
+ 'project_id': project_id,
+ 'source': Event.Source.RPC
+ }
+
+
+def _emit_event(user_id: Optional[int], result: Event.Result, result_code: str, fields: dict) -> None:
+ event = Event(user_id=user_id, result=result, result_code=result_code, **fields)
+ try:
+ with db.session_scope() as session:
+ EventService(session).emit_event(event)
+ session.commit()
+ except ValueError as e:
+ logging.error(f'[audit.decorator] invalid argument passed: {e}')
+ emit_store('audit_invalid_arguments', 1)
+
+
+def get_two_pc_request_uuid(rpc_request: TwoPcRequest) -> Optional[str]:
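+    """Extracts the resource uuid carried by the transaction data of each supported TwoPcType."""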
+ if rpc_request.type == TwoPcType.CREATE_MODEL_JOB:
+ return rpc_request.data.create_model_job_data.model_job_uuid
+ if rpc_request.type == TwoPcType.CONTROL_WORKFLOW_STATE:
+ return rpc_request.data.transit_workflow_state_data.workflow_uuid
+ if rpc_request.type == TwoPcType.CREATE_MODEL_JOB_GROUP:
+ return rpc_request.data.create_model_job_group_data.model_job_group_uuid
+ if rpc_request.type == TwoPcType.LAUNCH_DATASET_JOB:
+ return rpc_request.data.launch_dataset_job_data.dataset_job_uuid
+ if rpc_request.type == TwoPcType.STOP_DATASET_JOB:
+ return rpc_request.data.stop_dataset_job_data.dataset_job_uuid
+ if rpc_request.type == TwoPcType.CREATE_TRUSTED_JOB_GROUP:
+ return rpc_request.data.create_trusted_job_group_data.algorithm_uuid
+ if rpc_request.type == TwoPcType.LAUNCH_TRUSTED_JOB:
+ return rpc_request.data.launch_trusted_job_data.uuid
+ if rpc_request.type == TwoPcType.STOP_TRUSTED_JOB:
+ return rpc_request.data.stop_trusted_job_data.uuid
+ if rpc_request.type == TwoPcType.LAUNCH_DATASET_JOB_STAGE:
+ return rpc_request.data.launch_dataset_job_stage_data.dataset_job_stage_uuid
+ if rpc_request.type == TwoPcType.STOP_DATASET_JOB_STAGE:
+ return rpc_request.data.stop_dataset_job_stage_data.dataset_job_stage_uuid
+ logging.warning('[TwoPc] Unsupported TwoPcType!')
+ return None
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/decorators_test.py b/web_console_v2/api/fedlearner_webconsole/audit/decorators_test.py
new file mode 100644
index 000000000..d7cc8d5a3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/decorators_test.py
@@ -0,0 +1,238 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+from http import HTTPStatus
+import unittest
+from unittest.mock import MagicMock, patch
+from typing import Tuple
+from grpc import ServicerContext, StatusCode
+from grpc._server import _Context, _RPCState
+from google.protobuf import empty_pb2
+from google.protobuf.message import Message
+from testing.common import BaseTestCase, NoWebServerTestCase
+
+from fedlearner_webconsole.utils.pp_base64 import base64encode
+from fedlearner_webconsole.utils.const import API_VERSION
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.audit.decorators import _get_func_params, _infer_rpc_event_fields,\
+ emits_rpc_event, get_two_pc_request_uuid
+from fedlearner_webconsole.proto.service_pb2 import TwoPcRequest, UpdateWorkflowResponse
+from fedlearner_webconsole.proto import common_pb2, service_pb2
+from fedlearner_webconsole.audit.models import EventModel
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.two_pc_pb2 import CreateModelJobData, TransactionData, TwoPcAction, TwoPcType
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition, WorkflowDefinition
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+
+
+def get_uuid(proto_message: Message):
+ return 'test_uuid'
+
+
+def get_times() -> Tuple[int, int]:
+ ts = to_timestamp(now())
+ return ts - 60 * 2, ts + 60 * 2
+
+
+@emits_rpc_event(resource_type=Event.ResourceType.WORKFLOW,
+ op_type=Event.OperationType.UPDATE,
+ resource_name_fn=get_uuid)
+def fake_rpc_method(request, context=None):
+ return UpdateWorkflowResponse(status=common_pb2.Status(code=common_pb2.STATUS_UNAUTHORIZED, msg='done'))
+
+
+@emits_rpc_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP,
+ op_type=Event.OperationType.UPDATE,
+ resource_name_fn=get_uuid)
+def fake_rpc_method_without_status_code(request, context=None):
+ return service_pb2.UpdateModelJobGroupResponse(uuid='test',
+ config=WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[Variable(name='mode', value='train')])
+ ]))
+
+
+@emits_rpc_event(resource_type=Event.ResourceType.TRUSTED_JOB_GROUP,
+ op_type=Event.OperationType.INFORM,
+ resource_name_fn=get_uuid)
+def fake_rpc_method_with_status_code_in_context(request, context: ServicerContext = None):
+ context.abort(StatusCode.INVALID_ARGUMENT, 'just test')
+ return empty_pb2.Empty()
+
+
+class DecoratorsTest(BaseTestCase):
+
+ def test_emits_event(self):
+ self.signin_as_admin()
+
+ start_time, end_time = get_times()
+ response = self.post_helper(
+ f'{API_VERSION}/auth/users', {
+ 'username': 'test123',
+ 'password': base64encode('123456.@abc'),
+ 'role': 'USER',
+ 'name': 'test123',
+ 'email': 'test@byd.org'
+ })
+ user_id = self.get_response_data(response).get('id')
+
+ response = self.get_helper(
+ f'{API_VERSION}/events?filter=(and(username="admin")(start_time>{start_time})(end_time<{end_time}))')
+ data = self.get_response_data(response)[0]
+ self.assertEqual(Event.OperationType.CREATE, Event.OperationType.Value(data.get('op_type')))
+ self.assertEqual(Event.ResourceType.USER, Event.ResourceType.Value(data.get('resource_type')))
+ self.assertEqual('4', data.get('resource_name').split('/')[-1])
+
+ # send a wrong request and see if the event is logged correctly
+ response = self.patch_helper(f'{API_VERSION}/auth/users/999', {})
+ self.assertStatus(response, HTTPStatus.NOT_FOUND)
+ response = self.get_helper(
+ f'{API_VERSION}/events?filter=(and(username="admin")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertEqual(Event.OperationType.UPDATE,
+ Event.OperationType.Value(self.get_response_data(response)[0].get('op_type')))
+ self.assertEqual(Event.Result.FAILURE, Event.Result.Value(self.get_response_data(response)[0].get('result')))
+
+ self.patch_helper(f'{API_VERSION}/auth/users/{user_id}', {})
+ response = self.get_helper(
+ f'{API_VERSION}/events?filter=(and(username="admin")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertEqual(Event.Result.SUCCESS, Event.Result.Value(self.get_response_data(response)[0].get('result')))
+
+ self.delete_helper(f'{API_VERSION}/auth/users/{user_id}')
+ response = self.get_helper(
+ f'{API_VERSION}/events?filter=(and(username="admin")(start_time>{start_time})(end_time<{end_time}))')
+ self.assertEqual(Event.OperationType.DELETE,
+ Event.OperationType.Value(self.get_response_data(response)[0].get('op_type')))
+
+
+class RpcDecoratorsTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.default_auth_info = service_pb2.ProjAuthInfo(project_name='test', target_domain='test_domain')
+ self.request = service_pb2.UpdateWorkflowRequest(auth_info=self.default_auth_info)
+ self.context = _Context('11', _RPCState(), '22')
+
+ @patch('fedlearner_webconsole.audit.decorators._infer_auth_info')
+ def test_decorator(self, mock_infer_auth_info: MagicMock):
+ mock_infer_auth_info.return_value = 'bytedance', 1
+ fake_rpc_method(self.request, context=self.context)
+ with db.session_scope() as session:
+ workflow_event = session.query(EventModel).first()
+ self.assertEqual(workflow_event.op_type, Event.OperationType.Name(Event.OperationType.UPDATE))
+ self.assertEqual(workflow_event.resource_type, Event.ResourceType.Name(Event.ResourceType.WORKFLOW))
+ self.assertEqual(workflow_event.resource_name, 'test_uuid')
+ self.assertEqual(workflow_event.coordinator_pure_domain_name, 'bytedance')
+ self.assertEqual(workflow_event.project_id, 1)
+ self.assertEqual(workflow_event.result_code, 'UNAUTHENTICATED')
+ self.assertEqual(workflow_event.result, 'FAILURE')
+ self.assertEqual(workflow_event.name, 'UpdateWorkflow')
+
+ def test_get_func_params(self):
+ func_params = _get_func_params(fake_rpc_method, request=self.request, context=self.context)
+ self.assertEqual(len(func_params), 2)
+ self.assertEqual(func_params['request'], self.request)
+ self.assertEqual(func_params['context'], self.context)
+
+ def test_infer_rpc_event_fields_with_two_pc(self):
+ transaction_data = TransactionData(
+ create_model_job_data=CreateModelJobData(model_job_name='test model name', model_job_uuid='test uuid'))
+ request = TwoPcRequest(auth_info=self.default_auth_info,
+ transaction_uuid='test-id',
+ type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data)
+ func_params = _get_func_params(fake_rpc_method, request=request, context=self.context)
+ fields = _infer_rpc_event_fields(func_params, Event.ResourceType.UNKNOWN_RESOURCE_TYPE, get_uuid,
+ Event.OperationType.UNKNOWN_OPERATION_TYPE)
+ uuid = get_two_pc_request_uuid(request)
+ self.assertEqual(uuid, request.data.create_model_job_data.model_job_uuid)
+ self.assertEqual(fields['op_type'], Event.OperationType.CREATE)
+ self.assertEqual(fields['resource_type'], Event.ResourceType.MODEL_JOB)
+ self.assertEqual(fields['name'], 'TwoPc')
+
+ request = TwoPcRequest(auth_info=self.default_auth_info,
+ transaction_uuid='test-id',
+ type=TwoPcType.CONTROL_WORKFLOW_STATE,
+ action=TwoPcAction.PREPARE)
+ func_params = _get_func_params(fake_rpc_method, request=request, context=self.context)
+ fields = _infer_rpc_event_fields(func_params, Event.ResourceType.UNKNOWN_RESOURCE_TYPE, get_uuid,
+ Event.OperationType.UNKNOWN_OPERATION_TYPE)
+ self.assertEqual(fields['op_type'], Event.OperationType.CONTROL_STATE)
+ self.assertEqual(fields['resource_type'], Event.ResourceType.WORKFLOW)
+
+ request = TwoPcRequest(auth_info=self.default_auth_info,
+ transaction_uuid='test-id',
+ type=TwoPcType.CREATE_MODEL_JOB_GROUP,
+ action=TwoPcAction.PREPARE)
+ func_params = _get_func_params(fake_rpc_method, request=request, context=self.context)
+ fields = _infer_rpc_event_fields(func_params, Event.ResourceType.UNKNOWN_RESOURCE_TYPE, get_uuid,
+ Event.OperationType.UNKNOWN_OPERATION_TYPE)
+ self.assertEqual(fields['op_type'], Event.OperationType.CREATE)
+ self.assertEqual(fields['resource_type'], Event.ResourceType.MODEL_JOB_GROUP)
+
+ request = TwoPcRequest(auth_info=self.default_auth_info,
+ transaction_uuid='test-id',
+ type=TwoPcType.LAUNCH_DATASET_JOB,
+ action=TwoPcAction.PREPARE)
+ func_params = _get_func_params(fake_rpc_method, request=request, context=self.context)
+ fields = _infer_rpc_event_fields(func_params, Event.ResourceType.UNKNOWN_RESOURCE_TYPE, get_uuid,
+ Event.OperationType.UNKNOWN_OPERATION_TYPE)
+ self.assertEqual(fields['op_type'], Event.OperationType.LAUNCH)
+ self.assertEqual(fields['resource_type'], Event.ResourceType.DATASET_JOB)
+
+ request = TwoPcRequest(auth_info=self.default_auth_info,
+ transaction_uuid='test-id',
+ type=TwoPcType.LAUNCH_MODEL_JOB,
+ action=TwoPcAction.PREPARE)
+ func_params = _get_func_params(fake_rpc_method, request=request, context=self.context)
+ fields = _infer_rpc_event_fields(func_params, Event.ResourceType.UNKNOWN_RESOURCE_TYPE, get_uuid,
+ Event.OperationType.UNKNOWN_OPERATION_TYPE)
+ self.assertEqual(fields['op_type'], Event.OperationType.LAUNCH)
+ self.assertEqual(fields['resource_type'], Event.ResourceType.MODEL_JOB)
+
+ request = TwoPcRequest(auth_info=self.default_auth_info,
+ transaction_uuid='test-id',
+ type=TwoPcType.STOP_DATASET_JOB,
+ action=TwoPcAction.PREPARE)
+ func_params = _get_func_params(fake_rpc_method, request=request, context=self.context)
+ fields = _infer_rpc_event_fields(func_params, Event.ResourceType.UNKNOWN_RESOURCE_TYPE, get_uuid,
+ Event.OperationType.UNKNOWN_OPERATION_TYPE)
+ self.assertEqual(fields['op_type'], Event.OperationType.STOP)
+ self.assertEqual(fields['resource_type'], Event.ResourceType.DATASET_JOB)
+
+ @patch('fedlearner_webconsole.audit.decorators._infer_auth_info')
+ def test_response_with_no_status(self, mock_infer_auth_info: MagicMock):
+ mock_infer_auth_info.return_value = 'bytedance', 1
+ fake_rpc_method_without_status_code(self.request, context=self.context)
+ with db.session_scope() as session:
+ event = session.query(EventModel).first()
+ self.assertEqual(event.result_code, 'OK')
+ self.assertEqual(event.result, 'SUCCESS')
+
+ @patch('fedlearner_webconsole.audit.decorators._infer_auth_info')
+ def test_response_with_status_code_in_context(self, mock_infer_auth_info: MagicMock):
+ mock_infer_auth_info.return_value = 'bytedance', 1
+ with self.assertRaises(Exception):
+ fake_rpc_method_with_status_code_in_context(self.request, context=self.context)
+ with db.session_scope() as session:
+ event = session.query(EventModel).first()
+ self.assertEqual(event.result_code, 'INVALID_ARGUMENT')
+ self.assertEqual(event.result, 'FAILURE')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/models.py b/web_console_v2/api/fedlearner_webconsole/audit/models.py
new file mode 100644
index 000000000..8a73578b3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/models.py
@@ -0,0 +1,122 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+from uuid import uuid4
+import enum
+
+from sqlalchemy import UniqueConstraint, func
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.utils.mixins import to_dict_mixin
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.auth_pb2 import User
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+
+
+class EventType(enum.Enum):
+ # USER_ENDPOINT maps to events with source of 'API/UI'
+ USER_ENDPOINT = 0
+ # RPC maps to events with source of 'RPC'
+ RPC = 1
+
+
+@to_dict_mixin(ignores=['updated_at', 'deleted_at', 'user_id'],
+ extras={
+ 'user': lambda e: {
+ 'id': e.user.id,
+ 'username': e.user.username,
+ 'role': e.user.role.name,
+ },
+ })
+class EventModel(db.Model):
+ __tablename__ = 'events_v2'
+ __table_args__ = (UniqueConstraint('uuid', name='uniq_uuid'), default_table_args('webconsole audit events'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='auto-incremented id')
+ uuid = db.Column(db.String(255), nullable=False, comment='UUID of the event', default=lambda _: str(uuid4()))
+ name = db.Column(db.String(255), nullable=False, comment='the name of the event')
+ user_id = db.Column(db.Integer, comment='the ID of the user who triggered the event')
+ resource_type = db.Column(db.Enum(*Event.ResourceType.keys(),
+ native_enum=False,
+ create_constraint=False,
+ length=32,
+ name='resource_type'),
+ nullable=False,
+ comment='the type of the resource')
+ resource_name = db.Column(db.String(512), nullable=False, comment='the name of the resource')
+ op_type = db.Column(db.Enum(*Event.OperationType.keys(),
+ native_enum=False,
+ create_constraint=False,
+ length=32,
+ name='op_type'),
+ nullable=False,
+ comment='the type of the operation of the event')
+    # For backward compatibility, the audit API double-writes the result and result_code fields
+ # TODO(yeqiuhan): remove result field
+ result = db.Column(db.Enum(*Event.Result.keys(),
+ native_enum=False,
+ create_constraint=False,
+ length=32,
+ name='result'),
+ nullable=False,
+ comment='the result of the operation')
+ result_code = db.Column(db.String(255), comment='the result code of the operation')
+ source = db.Column(db.Enum(*Event.Source.keys(),
+ native_enum=False,
+ create_constraint=False,
+ length=32,
+ name='source'),
+ nullable=False,
+ comment='the source that triggered the event')
+ coordinator_pure_domain_name = db.Column(db.String(255), comment='name of the coordinator')
+    project_id = db.Column(db.Integer, comment='project id corresponding to the participant name')
+ extra = db.Column(db.Text, comment='extra info in JSON')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created at')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ onupdate=func.now(),
+ server_default=func.now(),
+ comment='updated at')
+ deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')
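+    # No DB-level foreign key is declared on user_id; foreign() in the primaryjoin tells
+    # SQLAlchemy how to join EventModel to User for this relationship.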
+ user = db.relationship('User', primaryjoin='foreign(EventModel.user_id) == User.id')
+
+ def to_proto(self) -> Event:
+ return Event(user_id=self.user_id,
+ resource_type=Event.ResourceType.Value(self.resource_type),
+ resource_name=self.resource_name,
+ op_type=Event.OperationType.Value(self.op_type),
+ result=Event.Result.Value(self.result),
+ result_code=self.result_code,
+ source=Event.Source.Value(self.source),
+ name=self.name,
+ coordinator_pure_domain_name=self.coordinator_pure_domain_name,
+ project_id=self.project_id,
+ extra=self.extra,
+ user=User(id=self.user.id, username=self.user.username, role=self.user.role.value)
+ if self.user is not None else None,
+ event_id=self.id,
+ uuid=self.uuid,
+ created_at=to_timestamp(self.created_at))
+
+
+def to_model(proto: Event) -> EventModel:
+ return EventModel(name=proto.name,
+ user_id=proto.user_id,
+ resource_type=Event.ResourceType.Name(proto.resource_type),
+ resource_name=proto.resource_name,
+ op_type=Event.OperationType.Name(proto.op_type),
+ coordinator_pure_domain_name=proto.coordinator_pure_domain_name,
+ result=Event.Result.Name(proto.result),
+ result_code=proto.result_code,
+ source=Event.Source.Name(proto.source),
+ extra=proto.extra,
+ project_id=proto.project_id)
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/models_test.py b/web_console_v2/api/fedlearner_webconsole/audit/models_test.py
new file mode 100644
index 000000000..304adff35
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/models_test.py
@@ -0,0 +1,146 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+
+import unittest
+from sqlalchemy.exc import IntegrityError
+from datetime import datetime, timezone
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.auth.services import UserService
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.audit.models import EventModel, to_model
+from fedlearner_webconsole.proto import audit_pb2
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.proto.auth_pb2 import User
+
+
+def generate_event() -> EventModel:
+ return EventModel(name='some_event',
+ user_id=1,
+ resource_type='IAM',
+ resource_name='some_resource',
+ op_type='CREATE',
+ result='SUCCESS',
+ result_code='OK',
+ coordinator_pure_domain_name='bytedance',
+ project_id=1,
+ source='RPC')
+
+
+class EventModelsTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ events = [generate_event() for _ in range(3)]
+ events[1].user_id = None
+ events[2].user_id = 0
+ with db.session_scope() as session:
+ UserService(session).create_user_if_not_exists(username='ada', email='ada@ada.com', password='ada')
+ session.add_all(events)
+ session.commit()
+ self.default_event = audit_pb2.Event(event_id=1,
+ name='some_event',
+ user_id=1,
+ resource_type='IAM',
+ resource_name='some_resource',
+ op_type='CREATE',
+ result='SUCCESS',
+ result_code='OK',
+ coordinator_pure_domain_name='bytedance',
+ project_id=1,
+ user=User(id=1, username='ada', role='USER'),
+ created_at=to_timestamp(created_at),
+ source='RPC')
+ self.default_event_2 = audit_pb2.Event(event_id=2,
+ name='some_event',
+ user_id=None,
+ resource_type='IAM',
+ resource_name='some_resource',
+ op_type='CREATE',
+ result='SUCCESS',
+ result_code='OK',
+ coordinator_pure_domain_name='bytedance',
+ project_id=1,
+ user=None,
+ created_at=to_timestamp(created_at),
+ source='RPC')
+ self.default_event_3 = audit_pb2.Event(event_id=3,
+ name='some_event',
+ user_id=0,
+ resource_type='IAM',
+ resource_name='some_resource',
+ op_type='CREATE',
+ result='SUCCESS',
+ result_code='OK',
+ coordinator_pure_domain_name='bytedance',
+ project_id=1,
+ user=None,
+ created_at=to_timestamp(created_at),
+ source='RPC')
+
+ def test_uuids(self):
+ with db.session_scope() as session:
+ events = session.query(EventModel).all()
+ self.assertNotEqual(events[0].uuid, events[1].uuid)
+
+ def test_invalid_instances(self):
+ event = EventModel()
+ with db.session_scope() as session:
+ session.add(event)
+ self.assertRaises(IntegrityError, session.commit)
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ result = session.query(EventModel).first()
+ self.assertDictPartiallyEqual(to_dict(result.to_proto()), to_dict(self.default_event),
+ ['created_at', 'uuid'])
+
+ def test_save_proto(self):
+ with db.session_scope() as session:
+ session.add(to_model(self.default_event))
+ session.commit()
+
+ with db.session_scope() as session:
+ result = session.query(EventModel).get(1)
+ self.assertDictPartiallyEqual(to_dict(result.to_proto()), to_dict(self.default_event),
+ ['created_at', 'uuid'])
+
+ def test_without_user_id_to_proto(self):
+ with db.session_scope() as session:
+ result = session.query(EventModel).filter_by(user_id=None).first()
+ self.assertDictPartiallyEqual(to_dict(result.to_proto()), to_dict(self.default_event_2),
+ ['created_at', 'uuid'])
+
+ def test_without_user_id_save_proto(self):
+ with db.session_scope() as session:
+ session.add(to_model(self.default_event_2))
+ session.commit()
+
+ with db.session_scope() as session:
+ result = session.query(EventModel).get(2)
+ self.assertDictPartiallyEqual(to_dict(result.to_proto()), to_dict(self.default_event_2),
+ ['created_at', 'uuid'])
+
+ def test_user_id_zero_to_proto(self):
+ with db.session_scope() as session:
+ result = session.query(EventModel).filter_by(user_id=0).first()
+ self.assertDictPartiallyEqual(to_dict(result.to_proto()), to_dict(self.default_event_3),
+ ['created_at', 'uuid'])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/services.py b/web_console_v2/api/fedlearner_webconsole/audit/services.py
new file mode 100644
index 000000000..113304d32
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/services.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+from typing import Optional
+
+from sqlalchemy.orm import Session, Query
+
+from fedlearner_webconsole.audit.storage import get_storage
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression
+
+
+class EventService:
+
+ def __init__(self, session: Session):
+ """Construct a EventService.
+
+ Args:
+ session (Session): SQLAlchemy session.
+ """
+ self._session = session
+ self.storage = get_storage(self._session)
+
+ def emit_event(self, event: Event) -> None:
+ """Pass a Event instance to storage.
+
+ Args:
+ event (Event): Records to store.
+
+ Raises:
+ ValueError: Fields are invalid.
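+
+        Example (illustrative only; other required Event fields omitted):
+            EventService(session).emit_event(Event(name='some_event', ...))
+            session.commit()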
+ """
+ self.storage.save_event(event)
+
+ def get_events(self, filter_exp: Optional[FilterExpression] = None) -> Query:
+ """Get events by time and additional conditions in {event}.
+
+ Args:
+ filter_exp (FilterExpression): Filtering expression defined in utils/filtering.py
+ Returns:
+ A SQLAlchemy Query object contains selected records.
+ """
+ return self.storage.get_events(filter_exp)
+
+ def delete_events(self, filter_exp: FilterExpression):
+ """Delete events by time.
+
+ Args:
+ filter_exp (FilterExpression): Filtering expression defined in utils/filtering.py
+ """
+ self.storage.delete_events(filter_exp)
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/services_test.py b/web_console_v2/api/fedlearner_webconsole/audit/services_test.py
new file mode 100644
index 000000000..2cab952bd
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/services_test.py
@@ -0,0 +1,168 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+
+import unittest
+from typing import Tuple
+
+from fedlearner_webconsole.audit.services import EventService
+from fedlearner_webconsole.audit.models import EventModel
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.utils.filtering import parse_expression
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterExpressionKind, SimpleExpression, FilterOp
+from fedlearner_webconsole.auth.services import UserService
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+def get_times() -> Tuple[int, int]:
+ ts = to_timestamp(now())
+ return ts - 60 * 2, ts + 60 * 2
+
+
+def generate_event() -> EventModel:
+ return EventModel(name='some_event',
+ user_id=1,
+ resource_type='IAM',
+ resource_name='some_resource',
+ op_type='CREATE',
+ result='SUCCESS',
+ result_code='OK',
+ coordinator_pure_domain_name='bytedance',
+ project_id=1,
+ source='RPC')
+
+
+class EventServiceTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ with db.session_scope() as session:
+ UserService(session).create_user_if_not_exists(username='ada', email='ada@ada.com', password='ada')
+ event_1 = generate_event()
+ event_1.coordinator_pure_domain_name = 'mihoyo'
+ session.add(event_1)
+ event_2 = generate_event()
+ event_2.resource_type = 'WORKFLOW'
+ event_2.op_type = 'UPDATE'
+ session.add(event_2)
+ event_3 = generate_event()
+ event_3.resource_type = 'DATASET'
+ event_3.op_type = 'DELETE'
+ event_3.result_code = 'CANCELLED'
+ session.add(event_3)
+ session.commit()
+
+ def test_emit_event(self):
+ with db.session_scope() as session:
+ service = EventService(session)
+ event = generate_event()
+ event_param = Event(name=event.name,
+ user_id=event.user_id,
+ resource_type=event.resource_type,
+ resource_name=event.resource_name,
+ result_code=event.result_code,
+ coordinator_pure_domain_name=event.coordinator_pure_domain_name,
+ project_id=event.project_id,
+ op_type=event.op_type,
+ result=event.result,
+ source=event.source)
+ service.emit_event(event_param)
+ session.commit()
+ events = service.get_events()
+ self.assertEqual(4, len(events.all()))
+ self.assertEqual(now().hour, events.first().created_at.hour)
+
+ def test_get_events(self):
+ with db.session_scope() as session:
+ service = EventService(session)
+ events = service.get_events()
+
+ self.assertEqual(3, len(events.all()))
+
+ def test_get_rpc_events_with_filter(self):
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='coordinator_pure_domain_name',
+ op=FilterOp.CONTAIN,
+ string_value='mihoyo',
+ ))
+ with db.session_scope() as session:
+ service = EventService(session)
+ events = service.get_events(filter_exp)
+ self.assertEqual(1, len(events.all()))
+ self.assertEqual(events[0].coordinator_pure_domain_name, 'mihoyo')
+
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='op_type',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(string_list=['UPDATE', 'DELETE'])))
+ with db.session_scope() as session:
+ service = EventService(session)
+ events = service.get_events(filter_exp)
+ self.assertEqual(2, len(events.all()))
+ self.assertEqual(events[0].op_type, 'DELETE')
+ self.assertEqual(events[1].op_type, 'UPDATE')
+
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='op_type',
+ op=FilterOp.EQUAL,
+ string_value='UPDATE',
+ ))
+ with db.session_scope() as session:
+ service = EventService(session)
+ events = service.get_events(filter_exp)
+ self.assertEqual(1, len(events.all()))
+ self.assertEqual(events[0].op_type, 'UPDATE')
+
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='result_code',
+ op=FilterOp.EQUAL,
+ string_value='CANCELLED',
+ ))
+ with db.session_scope() as session:
+ service = EventService(session)
+ events = service.get_events(filter_exp)
+ self.assertEqual(1, len(events.all()))
+ self.assertEqual(events[0].result_code, 'CANCELLED')
+
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='resource_type',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(string_list=['WORKFLOW', 'DATASET'])))
+ with db.session_scope() as session:
+ service = EventService(session)
+ events = service.get_events(filter_exp)
+ self.assertEqual(2, len(events.all()))
+ self.assertEqual(events[0].resource_type, 'DATASET')
+ self.assertEqual(events[1].resource_type, 'WORKFLOW')
+
+ def test_delete_events(self):
+ with db.session_scope() as session:
+ service = EventService(session)
+ start_time, end_time = get_times()
+ filter_exp = parse_expression(f'(and(start_time>{start_time})(end_time<{end_time}))')
+ self.assertIsNone(service.delete_events(filter_exp))
+ self.assertEqual(0, service.get_events().count())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/audit/storage.py b/web_console_v2/api/fedlearner_webconsole/audit/storage.py
new file mode 100644
index 000000000..2c77c1148
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/audit/storage.py
@@ -0,0 +1,158 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# coding: utf-8
+from abc import ABCMeta, abstractmethod
+from datetime import datetime, timezone
+from typing import Optional
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session, Query
+from sqlalchemy.sql.schema import Column
+
+from envs import Envs
+from fedlearner_webconsole.audit.models import EventModel, to_model
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterOp, SimpleExpression
+from fedlearner_webconsole.utils.filtering import FieldType, FilterBuilder, SupportedField
+from fedlearner_webconsole.auth.services import UserService
+from fedlearner_webconsole.db import db
+
+
+def _contains_case_insensitive(exp: SimpleExpression):
+ c: Column = getattr(EventModel, exp.field)
+ return c.ilike(f'%{exp.string_value}%')
+
+
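+# Resolves the username to a user id in a separate session and filters events on user_id;
+# if no such user exists, the expression is simply False, so no events should match.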
+def _equals_username(exp: SimpleExpression):
+ username = exp.string_value
+ c: Column = EventModel.user_id
+ with db.session_scope() as session:
+ user = UserService(session).get_user_by_username(username)
+ if user is None:
+ return False
+ return c == user.id
+
+
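+# start_time/end_time filter values are epoch seconds; they are converted to UTC
+# datetimes and compared against EventModel.created_at.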
+def _later(exp: SimpleExpression):
+ c: Column = EventModel.created_at
+ dt = datetime.fromtimestamp(exp.number_value, tz=timezone.utc)
+ return c > dt
+
+
+def _earlier(exp: SimpleExpression):
+ c: Column = EventModel.created_at
+ dt = datetime.fromtimestamp(exp.number_value, tz=timezone.utc)
+ return c < dt
+
+
+class IStorage(metaclass=ABCMeta):
+
+ @abstractmethod
+ def save_event(self, event: Event) -> None:
+ """Save the event instance into corresponding storage.
+
+ Args:
+ event (Event): The event instance waited to be stored.
+ """
+
+ @abstractmethod
+ def get_events(self, filter_exp: Optional[FilterExpression] = None) -> Query:
+ """Get event records from corresponding storage.
+
+ Args:
+ filter_exp (FilterExpression): Filtering expression defined in utils/filtering.py
+ Returns:
+ A Query object contains selected events.
+ """
+
+ @abstractmethod
+ def delete_events(self, filter_exp: FilterExpression) -> None:
+ """Delete event records for a period of time.
+
+ Args:
+ filter_exp (FilterExpression): Filtering expression defined in utils/filtering.py
+ """
+
+
+class MySqlStorage(IStorage):
+
+ FILTER_FIELDS = {
+ 'name':
+ SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: _contains_case_insensitive}),
+ 'username':
+ SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: _equals_username}),
+ 'resource_type':
+ SupportedField(type=FieldType.STRING, ops={
+ FilterOp.CONTAIN: _contains_case_insensitive,
+ FilterOp.IN: None
+ }),
+ 'resource_name':
+ SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: _contains_case_insensitive}),
+ 'op_type':
+ SupportedField(type=FieldType.STRING, ops={
+ FilterOp.EQUAL: None,
+ FilterOp.IN: None
+ }),
+ 'coordinator_pure_domain_name':
+ SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: _contains_case_insensitive}),
+ 'project_id':
+ SupportedField(type=FieldType.NUMBER, ops={FilterOp.EQUAL: None}),
+ 'result':
+ SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'result_code':
+ SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'source':
+ SupportedField(type=FieldType.STRING, ops={
+ FilterOp.EQUAL: None,
+ FilterOp.IN: None
+ }),
+ 'start_time':
+ SupportedField(type=FieldType.NUMBER, ops={FilterOp.GREATER_THAN: _later}),
+ 'end_time':
+ SupportedField(type=FieldType.NUMBER, ops={FilterOp.LESS_THAN: _earlier}),
+ }
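+    # Ops mapped to None presumably rely on FilterBuilder's default handling for that
+    # operator (plain equality / IN), while the callables above customize the criterion.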
+
+ def __init__(self, session: Session):
+ self._session = session
+ self._filter_builder = FilterBuilder(model_class=EventModel, supported_fields=self.FILTER_FIELDS)
+
+ def save_event(self, event: Event) -> None:
+ self._session.add(to_model(event))
+
+ def get_events(self, filter_exp: Optional[FilterExpression] = None) -> Query:
+ events = self._session.query(EventModel).filter(EventModel.deleted_at.is_(None))
+ if filter_exp is not None:
+ events = self._filter_builder.build_query(events, filter_exp)
+ return events.order_by(EventModel.id.desc())
+
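+    # Soft delete: matching rows are stamped with deleted_at rather than removed, and
+    # get_events() excludes them via the deleted_at.is_(None) filter.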
+ def delete_events(self, filter_exp: FilterExpression) -> None:
+ events = self._session.query(EventModel).filter(EventModel.deleted_at.is_(None))
+ events = self._filter_builder.build_query(events, filter_exp)
+ events.update({'deleted_at': func.now()}, synchronize_session='fetch')
+
+
+def get_storage(session: Session) -> Optional[IStorage]:
+ """Get a storage object accordingly.
+
+ Args:
+ session (Session): Session used to query records.
+
+ Returns:
+ A IStorage object that can save, get and delete events.
+ """
+ if Envs.AUDIT_STORAGE == 'db':
+ return MySqlStorage(session)
+ # TODO(wangsen.0914): add CloudStorage or other types of storages later
+ return None
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/auth/BUILD.bazel
new file mode 100644
index 000000000..4b88c98f9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/auth/BUILD.bazel
@@ -0,0 +1,135 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:mixins_lib",
+ "@common_passlib//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = [
+ "services.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "services_lib_test",
+ size = "small",
+ srcs = [
+ "services_test.py",
+ ],
+ imports = ["../.."],
+ main = "services_test.py",
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "third_party_sso_lib",
+ srcs = [
+ "third_party_sso.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api:config_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_base64_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:helpers_lib",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask//:pkg",
+ "@common_pyjwt//:pkg",
+ "@common_requests//:pkg",
+ "@common_xmltodict//:pkg",
+ ],
+)
+
+py_test(
+ name = "third_party_sso_lib_test",
+ size = "medium",
+ srcs = [
+ "third_party_sso_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "third_party_sso_test.py",
+ deps = [
+ ":third_party_sso_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_base64_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:helpers_lib",
+ "@common_flask//:pkg",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/apis.py b/web_console_v2/api/fedlearner_webconsole/auth/apis.py
index 7206e3497..d288b5fdf 100644
--- a/web_console_v2/api/fedlearner_webconsole/auth/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/auth/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,27 +13,32 @@
# limitations under the License.
# coding: utf-8
-# pylint: disable=cyclic-import
+import logging
import re
-import datetime
from http import HTTPStatus
-from flask import request
-from flask_restful import Resource, reqparse
-from flask_jwt_extended.utils import get_current_user
-from flask_jwt_extended import create_access_token, decode_token, get_jwt
-
-from fedlearner_webconsole.utils.base64 import base64decode
-from fedlearner_webconsole.utils.decorators import jwt_required
-from fedlearner_webconsole.utils.decorators import admin_required
+from flask import request
+from flask_restful import Resource
+from marshmallow import Schema, post_load, fields, validate, EXCLUDE
+from marshmallow.decorators import validates_schema
+from webargs.flaskparser import use_args
+
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.auth.services import UserService, SessionService
+from fedlearner_webconsole.iam.client import create_iams_for_user
+from fedlearner_webconsole.proto import auth_pb2
+from fedlearner_webconsole.swagger.models import schema_manager
+
+from fedlearner_webconsole.utils.pp_base64 import base64decode
+from fedlearner_webconsole.auth.third_party_sso import credentials_required, SsoHandlerFactory
+from fedlearner_webconsole.utils.flask_utils import get_current_user, make_flask_response
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required
from fedlearner_webconsole.db import db
-from fedlearner_webconsole.auth.models import (State, User, Role,
- MUTABLE_ATTRS_MAPPER, Session)
-from fedlearner_webconsole.exceptions import (NotFoundException,
- InvalidArgumentException,
- ResourceConflictException,
- UnauthorizedException,
- NoAccessException)
+from fedlearner_webconsole.auth.models import (Role, MUTABLE_ATTRS_MAPPER)
+from fedlearner_webconsole.exceptions import (NotFoundException, InvalidArgumentException, ResourceConflictException,
+ NoAccessException, UnauthorizedException)
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.auth.third_party_sso import sso_info_manager
# rule: password must have a letter, a num and a special character
PASSWORD_FORMAT_L = re.compile(r'.*[A-Za-z]')
@@ -41,10 +46,9 @@
PASSWORD_FORMAT_S = re.compile(r'.*[`!@#$%^&*()\-_=+|{}\[\];:\'\",<.>/?~]')
-def check_password_format(password: str):
+def _check_password_format(password: str):
if not 8 <= len(password) <= 20:
- raise InvalidArgumentException(
- 'Password is not legal: 8 <= length <= 20')
+ raise InvalidArgumentException('Password is not legal: 8 <= length <= 20')
required_chars = []
if PASSWORD_FORMAT_L.match(password) is None:
required_chars.append('a letter')
@@ -54,111 +58,159 @@ def check_password_format(password: str):
required_chars.append('a special character')
if required_chars:
tip = ', '.join(required_chars)
- raise InvalidArgumentException(
- f'Password is not legal: must have {tip}.')
+ raise InvalidArgumentException(f'Password is not legal: must have {tip}.')
-class SigninApi(Resource):
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('username',
- required=True,
- help='username is empty')
- parser.add_argument('password',
- required=True,
- help='password is empty')
- data = parser.parse_args()
- username = data['username']
- password = base64decode(data['password'])
- user = User.query.filter_by(username=username).filter_by(
- state=State.ACTIVE).first()
- if user is None:
- raise NotFoundException(f'Failed to find user: {username}')
- if not user.verify_password(password):
- raise UnauthorizedException('Invalid password')
- token = create_access_token(identity=username)
- decoded_token = decode_token(token)
-
- session = Session(jti=decoded_token.get('jti'),
- expired_at=datetime.datetime.fromtimestamp(
- decoded_token.get('exp')))
- db.session.add(session)
- db.session.commit()
-
- return {
- 'data': {
- 'user': user.to_dict(),
- 'access_token': token
- }
- }, HTTPStatus.OK
-
- @jwt_required()
- def delete(self):
- decoded_token = get_jwt()
+class UserParameter(Schema):
+ username = fields.Str(required=True)
+ # Base64 encoded password
+ password = fields.Str(required=True, validate=lambda x: _check_password_format(base64decode(x)))
+ role = fields.Str(required=True, validate=validate.OneOf([x.name for x in Role]))
+ name = fields.Str(required=True, validate=validate.Length(min=1))
+ email = fields.Str(required=True, validate=validate.Email())
+
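+    # post_load turns the validated dict into an auth_pb2.User message, so the resource
+    # handler receives a typed proto rather than a raw dict.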
+ @post_load
+ def make_user(self, data, **kwargs):
+ return auth_pb2.User(**data)
+
+
+class SigninParameter(Schema):
+ username = fields.String()
+ password = fields.String()
+ code = fields.String()
+ ticket = fields.String()
+
+ @validates_schema
+ def validate_schema(self, data, **kwargs):
+ del kwargs
+ if data.get('username') is None and data.get('code') is None and data.get('ticket') is None:
+ raise InvalidArgumentException('no credential detected')
+
+ @post_load
+ def make_proto(self, data, **kwargs):
+ del kwargs
+ return auth_pb2.SigninParameter(**data)
- jti = decoded_token.get('jti')
- Session.query.filter_by(jti=jti).delete()
- db.session.commit()
- return {}, HTTPStatus.OK
+class SigninApi(Resource):
+
+ @use_args(SigninParameter(unknown=EXCLUDE), location='json_or_form')
+ def post(self, signin_parameter: auth_pb2.SigninParameter):
+ """Sign in to the system
+ ---
+ tags:
+ - auth
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/SigninParameter'
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ access_token:
+ type: string
+ user:
+ type: object
+ properties:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.User'
+ """
+ sso_name = request.args.get('sso_name')
+ return make_flask_response(SsoHandlerFactory.get_handler(sso_name).signin(signin_parameter))
+
+ @credentials_required
+ def delete(self):
+ """Sign out from the system
+ ---
+ tags:
+ - auth
+ parameters:
+ - in: header
+ name: Authorization
+ schema:
+ type: string
+ description: token used for current session
+ responses:
+ 200:
+ description: Signed out successfully
+ """
+ user = get_current_user()
+ SsoHandlerFactory.get_handler(user.sso_name).signout()
+ return make_flask_response()
class UsersApi(Resource):
- @jwt_required()
+
+ @credentials_required
@admin_required
def get(self):
- return {
- 'data': [
- row.to_dict()
- for row in User.query.filter_by(state=State.ACTIVE).all()
- ]
- }
-
- @jwt_required()
+ """Get a list of all users
+ ---
+ tags:
+ - auth
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.User'
+ """
+ with db.session_scope() as session:
+ return make_flask_response(
+ [row.to_dict() for row in UserService(session).get_all_users(filter_deleted=True)])
+
+ @credentials_required
@admin_required
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('username',
- required=True,
- help='username is empty')
- parser.add_argument('password',
- required=True,
- help='password is empty')
- parser.add_argument('role', required=True, help='role is empty')
- parser.add_argument('name', required=True, help='name is empty')
- parser.add_argument('email', required=True, help='email is empty')
-
- data = parser.parse_args()
- username = data['username']
- password = base64decode(data['password'])
- role = data['role']
- name = data['name']
- email = data['email']
-
- check_password_format(password)
-
- if User.query.filter_by(username=username).first() is not None:
- raise ResourceConflictException(
- 'user {} already exists'.format(username))
- user = User(username=username,
- role=role,
- name=name,
- email=email,
- state=State.ACTIVE)
- user.set_password(password)
- db.session.add(user)
- db.session.commit()
-
- return {'data': user.to_dict()}, HTTPStatus.CREATED
+    # If use_kwargs is used with explicit parameters, the YAML document has to be written manually.
+ # Param: https://swagger.io/docs/specification/2-0/describing-parameters/
+ # Body: https://swagger.io/docs/specification/2-0/describing-request-body/
+ # Resp: https://swagger.io/docs/specification/2-0/describing-responses/
+ @use_args(UserParameter(unknown=EXCLUDE))
+ @emits_event()
+ def post(self, params: auth_pb2.User):
+ """Create a user
+ ---
+ tags:
+ - auth
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/UserParameter'
+ responses:
+ 201:
+ description: The user is created
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/UserParameter'
+ 409:
+ description: A user with the same username exists
+ """
+ # Swagger will detect APIs automatically, but params/req body/resp have to be defined manually
+ with db.session_scope() as session:
+ user = UserService(session).get_user_by_username(params.username)
+ if user is not None:
+ raise ResourceConflictException(f'user {user.username} already exists')
+ user = UserService(session).create_user_if_not_exists(username=params.username,
+ role=Role(params.role),
+ name=params.name,
+ email=params.email,
+ password=base64decode(params.password))
+ session.commit()
+ return make_flask_response(user.to_dict(), status=HTTPStatus.CREATED)
class UserApi(Resource):
- def _find_user(self, user_id) -> User:
- user = User.query.filter_by(id=user_id).first()
- if user is None or user.state == State.DELETED:
- raise NotFoundException(
- f'Failed to find user_id: {user_id}')
- return user
def _check_current_user(self, user_id, msg):
current_user = get_current_user()
@@ -166,50 +218,188 @@ def _check_current_user(self, user_id, msg):
and not user_id == current_user.id:
raise NoAccessException(msg)
- @jwt_required()
+ @credentials_required
def get(self, user_id):
- self._check_current_user(user_id,
- 'user cannot get other user\'s information')
- user = self._find_user(user_id)
- return {'data': user.to_dict()}, HTTPStatus.OK
-
- @jwt_required()
+ """Get a user by id
+ ---
+ tags:
+ - auth
+ parameters:
+ - in: path
+ name: user_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: The user is returned
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.User'
+ 404:
+ description: The user with specified ID is not found
+ """
+ self._check_current_user(user_id, 'user cannot get other user\'s information')
+
+ with db.session_scope() as session:
+ user = UserService(session).get_user_by_id(user_id, filter_deleted=True)
+ if user is None:
+ raise NotFoundException(f'Failed to find user_id: {user_id}')
+ return make_flask_response(user.to_dict())
+
+ @credentials_required
+ @emits_event()
+ # Example of manually defining an API
def patch(self, user_id):
- self._check_current_user(user_id,
- 'user cannot modify other user\'s information')
- user = self._find_user(user_id)
-
- mutable_attrs = MUTABLE_ATTRS_MAPPER.get(get_current_user().role)
-
- data = request.get_json()
- for k, v in data.items():
- if k not in mutable_attrs:
- raise InvalidArgumentException(f'cannot edit {k} attribute!')
- if k == 'password':
- password = base64decode(v)
- check_password_format(password)
- user.set_password(password)
- else:
- setattr(user, k, v)
-
- db.session.commit()
- return {'data': user.to_dict()}, HTTPStatus.OK
-
- @jwt_required()
+ """Patch a user
+ ---
+ tags:
+ - auth
+ parameters:
+ - in: path
+ name: user_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the user
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.User'
+ responses:
+ 200:
+ description: The user is updated
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.User'
+ 404:
+ description: The user is not found
+ 400:
+ description: Attributes selected are uneditable
+ """
+ self._check_current_user(user_id, 'user cannot modify other user\'s information')
+
+ with db.session_scope() as session:
+ user = UserService(session).get_user_by_id(user_id, filter_deleted=True)
+ if user is None:
+ raise NotFoundException(f'Failed to find user_id: {user_id}')
+
+ mutable_attrs = MUTABLE_ATTRS_MAPPER.get(get_current_user().role)
+
+ data = request.get_json()
+ for k, v in data.items():
+ if k not in mutable_attrs:
+ raise InvalidArgumentException(f'cannot edit {k} attribute!')
+ if k == 'password':
+ password = base64decode(v)
+ _check_password_format(password)
+ user.set_password(password)
+ SessionService(session).delete_session_by_user_id(user_id)
+ elif k == 'role':
+ user.role = Role(v)
+ else:
+ setattr(user, k, v)
+ create_iams_for_user(user)
+ session.commit()
+ return make_flask_response(user.to_dict())
+
+ @credentials_required
@admin_required
+ @emits_event()
def delete(self, user_id):
- user = self._find_user(user_id)
+ """Delete the user with specified ID
+ ---
+ tags:
+ - auth
+ parameters:
+ - in: path
+ name: user_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: The user with specified ID is deleted
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.User'
+ 400:
+ description: Cannot delete the user logged in within current session
+ 404:
+ description: The user with specified ID is not found
+ """
+ with db.session_scope() as session:
+ user_service = UserService(session)
+ user = user_service.get_user_by_id(user_id, filter_deleted=True)
+
+ if user is None:
+ raise NotFoundException(f'Failed to find user_id: {user_id}')
+
+ current_user = get_current_user()
+ if current_user.id == user_id:
+ raise InvalidArgumentException('cannot delete yourself')
+
+ user = UserService(session).delete_user(user)
+ session.commit()
+ return make_flask_response(user.to_dict())
+
+
+class SsoInfosApi(Resource):
- current_user = get_current_user()
- if current_user.id == user_id:
- raise InvalidArgumentException('cannot delete yourself')
-
- user.state = State.DELETED
- db.session.commit()
- return {'data': user.to_dict()}, HTTPStatus.OK
+ def get(self):
+ """Get all available options of SSOs
+ ---
+ tags:
+ - auth
+ responses:
+ 200:
+ description: All options are returned
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Sso'
+ """
+ return make_flask_response([to_dict(sso, with_secret=False) for sso in sso_info_manager.sso_infos])
+
+
+class SelfUserApi(Resource):
+
+ @credentials_required
+ def get(self):
+ """Get current user
+ ---
+ tags:
+ - auth
+ responses:
+ 200:
+ description: User logged in within current session is returned
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.User'
+ 400:
+ description: No user is logged in within current session
+ """
+ user = get_current_user()
+        # Defensive check: a current user should always exist at this point.
+ if user is None:
+ logging.error('No current user.')
+ raise UnauthorizedException('No current user.')
+ return make_flask_response(user.to_dict())
def initialize_auth_apis(api):
api.add_resource(SigninApi, '/auth/signin')
api.add_resource(UsersApi, '/auth/users')
api.add_resource(UserApi, '/auth/users/')
+ api.add_resource(SsoInfosApi, '/auth/sso_infos')
+ api.add_resource(SelfUserApi, '/auth/self')
+
+    # If a schema is used, it has to be appended to schema_manager so that Swagger knows it is available.
+ schema_manager.append(UserParameter)
+ schema_manager.append(SigninParameter)
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/apis_test.py b/web_console_v2/api/fedlearner_webconsole/auth/apis_test.py
new file mode 100644
index 000000000..9e963b99d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/auth/apis_test.py
@@ -0,0 +1,306 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import json
+import unittest
+from http import HTTPStatus
+from unittest.mock import patch
+from datetime import timedelta
+
+from testing.common import BaseTestCase
+from testing.helpers import FakeResponse
+from fedlearner_webconsole.auth.services import UserService
+from fedlearner_webconsole.utils.pp_base64 import base64encode
+from fedlearner_webconsole.utils.const import API_VERSION
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.auth.models import State, User
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.auth_pb2 import Sso, OAuthProtocol, CasProtocol
+from fedlearner_webconsole.auth.third_party_sso import get_user_info_with_cache, SsoInfos, OAuthHandler, CasHandler
+from fedlearner_webconsole.auth.models import Session as SessionTbl
+from envs import Envs
+
+
+class UsersApiTest(BaseTestCase):
+
+ def test_get_all_users(self):
+ deleted_user = User(username='deleted_one', email='who.knows@hhh.com', state=State.DELETED)
+ with db.session_scope() as session:
+ session.add(deleted_user)
+ session.commit()
+
+ resp = self.get_helper('/api/v2/auth/users')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+
+ self.signin_as_admin()
+
+ resp = self.get_helper('/api/v2/auth/users')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(len(self.get_response_data(resp)), 3)
+
+ def test_create_new_user(self):
+ new_user = {
+ 'username': 'fedlearner',
+ 'password': 'fedlearner',
+ 'email': 'hello@bytedance.com',
+ 'role': 'USER',
+ 'name': 'codemonkey',
+ }
+ resp = self.post_helper('/api/v2/auth/users', data=new_user)
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+
+ self.signin_as_admin()
+ illegal_cases = [
+ 'aaaaaaaa', '11111111', '!@#$%^[]', 'aaaA1111', 'AAAa!@#$', '1111!@#-', 'aa11!@', 'fl@123.',
+ 'fl@1234567890abcdefg.'
+ ]
+ legal_case = 'fl@1234.'
+
+ for case in illegal_cases:
+ new_user['password'] = base64encode(case)
+ resp = self.post_helper('/api/v2/auth/users', data=new_user)
+ print(self.get_response_data(resp))
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ new_user['password'] = base64encode(legal_case)
+ resp = self.post_helper('/api/v2/auth/users', data=new_user)
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ self.assertEqual(self.get_response_data(resp).get('username'), 'fedlearner')
+
+ # test_repeat_create
+ resp = self.post_helper('/api/v2/auth/users', data=new_user)
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+
+
+class AuthApiTest(BaseTestCase):
+
+ def test_partial_update_user_info(self):
+ self.signin_as_admin()
+ resp = self.get_helper('/api/v2/auth/users')
+ resp_data = self.get_response_data(resp)
+ user_id = resp_data[0]['id']
+ admin_id = resp_data[1]['id']
+
+ self.signin_helper()
+ resp = self.patch_helper('/api/v2/auth/users/10', data={})
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+
+ resp = self.patch_helper(f'/api/v2/auth/users/{user_id}', data={
+ 'email': 'a_new_email@bytedance.com',
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp).get('email'), 'a_new_email@bytedance.com')
+
+ resp = self.patch_helper(f'/api/v2/auth/users/{admin_id}', data={
+ 'name': 'cannot_modify',
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+
+ # now we are signing in as admin
+ self.signin_as_admin()
+ resp = self.patch_helper(f'/api/v2/auth/users/{user_id}', data={
+ 'role': 'ADMIN',
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp).get('role'), 'ADMIN')
+
+ resp = self.patch_helper(f'/api/v2/auth/users/{user_id}', data={
+ 'password': base64encode('fl@1234.'),
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+
+ def test_delete_user(self):
+ self.signin_as_admin()
+ resp = self.get_helper('/api/v2/auth/users')
+ resp_data = self.get_response_data(resp)
+ user_id = resp_data[0]['id']
+ admin_id = resp_data[1]['id']
+
+ self.signin_helper()
+ resp = self.delete_helper(url=f'/api/v2/auth/users/{user_id}')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+
+ self.signin_as_admin()
+
+ resp = self.delete_helper(url=f'/api/v2/auth/users/{admin_id}')
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ resp = self.delete_helper(url=f'/api/v2/auth/users/{user_id}')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp).get('username'), 'ada')
+
+ def test_get_specific_user(self):
+ resp = self.get_helper(url='/api/v2/auth/users/10086')
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+
+ resp = self.get_helper(url='/api/v2/auth/users/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp).get('username'), 'ada')
+
+ self.signin_as_admin()
+
+ resp = self.get_helper(url='/api/v2/auth/users/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp).get('username'), 'ada')
+
+ resp = self.get_helper(url='/api/v2/auth/users/10086')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_signout(self):
+ self.signin_helper()
+
+ resp = self.delete_helper(url='/api/v2/auth/signin')
+ self.assertEqual(resp.status_code, HTTPStatus.OK, resp.json)
+
+ resp = self.get_helper(url='/api/v2/auth/users/1')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+
+ @patch('fedlearner_webconsole.auth.apis.SsoHandlerFactory.get_handler')
+ @patch('fedlearner_webconsole.auth.third_party_sso.requests.post')
+ @patch('fedlearner_webconsole.auth.third_party_sso.requests.get')
+ def test_signin_oauth(self, mock_request_get, mock_request_post, mock_sso_handler):
+
+ mock_sso_handler.return_value = OAuthHandler(Sso(name='test', oauth=OAuthProtocol()))
+ resp = self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ mock_request_post.return_value = FakeResponse({}, HTTPStatus.OK)
+ resp = self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={'code': 'wrong_code'})
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+
+ mock_request_post.return_value = FakeResponse({'access_token': 'token'}, HTTPStatus.OK)
+ self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={'code': 'right_code'})
+ mock_request_get.assert_called_once()
+ get_user_info_with_cache.cache_clear()
+ mock_request_get.return_value = FakeResponse({'username': 'test', 'email': 'test'}, HTTPStatus.OK)
+ resp = self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={'code': 'right_code'})
+ data = self.get_response_data(resp)
+ self.assertEqual(data['user']['username'], 'test')
+ # test oauth sign in after deleted
+ with db.session_scope() as session:
+ user = UserService(session).get_user_by_username(data['user']['username'])
+ user.state = State.DELETED
+ session.commit()
+ resp = self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={'code': 'right_code'})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ @patch('fedlearner_webconsole.auth.apis.SsoHandlerFactory.get_handler')
+ @patch('fedlearner_webconsole.auth.third_party_sso.requests.get')
+ def test_signin_cas(self, mock_request_get, mock_sso_handler):
+ mock_sso_handler.return_value = CasHandler(Sso(name='test', cas=CasProtocol()))
+ resp = self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ resp = self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={'ticket': 'wrong_ticket'})
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ mock_request_get.assert_called_once()
+ fake_xml = """
+ <cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
+ <cas:authenticationSuccess>
+ <cas:user>test3</cas:user>
+ </cas:authenticationSuccess>
+ </cas:serviceResponse>
+ """
+ mock_request_get.return_value = FakeResponse(None, HTTPStatus.OK, fake_xml)
+ resp = self.post_helper(url='/api/v2/auth/signin?sso_name=test', data={'ticket': 'right_code'})
+ data = self.get_response_data(resp)
+ self.assertEqual(data['user']['username'], 'test3')
+
+
+class SsoInfosApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with open(f'{Envs.BASE_DIR}/testing/test_data/test_sso.json', encoding='utf-8') as f:
+ sso_infos_dict = json.load(f)
+ self.patch_ssoinfos = patch('fedlearner_webconsole.auth.third_party_sso.Envs.SSO_INFOS',
+ json.dumps(sso_infos_dict))
+ self.patch_ssoinfos.start()
+
+ def tearDown(self):
+ self.patch_ssoinfos.stop()
+
+ def test_get_sso_infos(self):
+ with patch('fedlearner_webconsole.auth.apis.sso_info_manager', SsoInfos()):
+ resp = self.get_helper(url='/api/v2/auth/sso_infos')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ self.assertTrue(data[0].get('oauth'))
+ self.assertEqual(data[0]['oauth'].get('secret'), '')
+
+
+class SelfUserApiTest(BaseTestCase):
+
+ def test_get_self_user(self):
+ resp = self.get_helper(url='/api/v2/auth/self')
+ self.assertEqual(self.get_response_data(resp)['name'], 'ada')
+ self.signout_helper()
+ resp = self.get_helper(url='/api/v2/auth/self')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ self.assertEqual('failed to find x-pc-auth or authorization within headers', resp.json.get('message'))
+
+
+class StrictSignInServiceTest(BaseTestCase):
+
+ def test_sign_in(self):
+ self.post_helper(f'{API_VERSION}/auth/signin', data={'username': 'ada', 'password': base64encode('fl@.')})
+ self.post_helper(f'{API_VERSION}/auth/signin', data={'username': 'ada', 'password': base64encode('fl@.')})
+ resp = self.post_helper(f'{API_VERSION}/auth/signin',
+ data={
+ 'username': 'ada',
+ 'password': base64encode('fl@.')
+ })
+ self.assertStatus(resp, HTTPStatus.BAD_REQUEST)
+ resp = self.post_helper(f'{API_VERSION}/auth/signin',
+ data={
+ 'username': 'ada',
+ 'password': base64encode('fl@12345.')
+ })
+ self.assertStatus(resp, HTTPStatus.FORBIDDEN)
+ self.assertEqual('Account is locked', resp.json['message'])
+
+ def test_banned_time(self):
+ self.post_helper(f'{API_VERSION}/auth/signin', data={'username': 'ada', 'password': base64encode('fl@.')})
+ self.post_helper(f'{API_VERSION}/auth/signin', data={'username': 'ada', 'password': base64encode('fl@.')})
+ self.post_helper(f'{API_VERSION}/auth/signin', data={'username': 'ada', 'password': base64encode('fl@.')})
+ resp = self.post_helper(f'{API_VERSION}/auth/signin',
+ data={
+ 'username': 'ada',
+ 'password': base64encode('fl@12345.')
+ })
+ self.assertStatus(resp, HTTPStatus.FORBIDDEN)
+ with db.session_scope() as session:
+ session.query(User).filter(User.username == 'ada').first().last_sign_in_at = now() - timedelta(minutes=31)
+ session.commit()
+ resp = self.post_helper(f'{API_VERSION}/auth/signin',
+ data={
+ 'username': 'ada',
+ 'password': base64encode('fl@12345.')
+ })
+ self.assertStatus(resp, HTTPStatus.OK)
+
+ def test_change_password(self):
+ with db.session_scope() as session:
+ user_id = UserService(session).get_user_by_username('ada').id
+ self.assertIsNotNone(session.query(SessionTbl).filter(SessionTbl.user_id == user_id).first())
+ self.signin_as_admin()
+ self.patch_helper(f'{API_VERSION}/auth/users/{user_id}', data={'password': base64encode('flfl123123.')})
+ with db.session_scope() as session:
+ self.assertIsNone(session.query(SessionTbl).filter(SessionTbl.user_id == user_id).first())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/models.py b/web_console_v2/api/fedlearner_webconsole/auth/models.py
index cad09bc5c..4778c2f0b 100644
--- a/web_console_v2/api/fedlearner_webconsole/auth/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/auth/models.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
# coding: utf-8
import enum
+
from passlib.apps import custom_app_context as pwd_context
from sqlalchemy.sql.schema import UniqueConstraint, Index
from sqlalchemy.sql import func
@@ -24,14 +25,16 @@
class Role(enum.Enum):
- USER = 'user'
- ADMIN = 'admin'
+ USER = 'USER'
+ ADMIN = 'ADMIN'
+# yapf: disable
MUTABLE_ATTRS_MAPPER = {
Role.USER: ('password', 'name', 'email'),
Role.ADMIN: ('password', 'role', 'name', 'email')
}
+# yapf: enable
class State(enum.Enum):
@@ -42,19 +45,27 @@ class State(enum.Enum):
@to_dict_mixin(ignores=['password', 'state'])
class User(db.Model):
__tablename__ = 'users_v2'
- __table_args__ = (UniqueConstraint('username', name='uniq_username'),
- default_table_args('This is webconsole user table'))
- id = db.Column(db.Integer, primary_key=True, comment='user id')
+ __table_args__ = (UniqueConstraint('username',
+ name='uniq_username'), default_table_args('This is webconsole user table'))
+ id = db.Column(db.Integer, primary_key=True, comment='user id', autoincrement=True)
username = db.Column(db.String(255), comment='unique name of user')
password = db.Column(db.String(255), comment='user password after encode')
- role = db.Column(db.Enum(Role, native_enum=False),
+ role = db.Column(db.Enum(Role, native_enum=False, create_constraint=False, length=21),
default=Role.USER,
comment='role of user')
name = db.Column(db.String(255), comment='name of user')
email = db.Column(db.String(255), comment='email of user')
- state = db.Column(db.Enum(State, native_enum=False),
+ state = db.Column(db.Enum(State, native_enum=False, create_constraint=False, length=21),
default=State.ACTIVE,
comment='state of user')
+ sso_name = db.Column(db.String(255), comment='sso_name')
+ last_sign_in_at = db.Column(db.DateTime(timezone=True),
+ nullable=True,
+ comment='the last time when user tries to sign in')
+ failed_sign_in_attempts = db.Column(db.Integer,
+ nullable=False,
+ default=0,
+ comment='failed sign in attempts since last successful sign in')
def set_password(self, password):
self.password = pwd_context.hash(password)
@@ -63,17 +74,12 @@ def verify_password(self, password):
return pwd_context.verify(password, self.password)
+@to_dict_mixin(ignores=['expired_at', 'created_at'])
class Session(db.Model):
__tablename__ = 'session_v2'
- __table_args__ = (Index('idx_jti', 'jti'),
- default_table_args('This is webconsole session table'))
- id = db.Column(db.Integer,
- primary_key=True,
- autoincrement=True,
- comment='session id')
+ __table_args__ = (Index('idx_jti', 'jti'), default_table_args('This is webconsole session table'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='session id')
jti = db.Column(db.String(64), comment='JWT jti')
- expired_at = db.Column(db.DateTime(timezone=True),
- comment='expired time, for db automatically clear')
- created_at = db.Column(db.DateTime(timezone=True),
- server_default=func.now(),
- comment='created at')
+ user_id = db.Column(db.Integer, nullable=False, comment='for whom the session is created')
+ expired_at = db.Column(db.DateTime(timezone=True), comment='expired time, for db automatically clear')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created at')
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/services.py b/web_console_v2/api/fedlearner_webconsole/auth/services.py
new file mode 100644
index 000000000..e0b41b226
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/auth/services.py
@@ -0,0 +1,110 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+from typing import List, Optional
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.auth.models import State, User, Session as SessionTbl, Role
+from fedlearner_webconsole.iam.client import create_iams_for_user
+from fedlearner_webconsole.utils.const import SIGN_IN_INTERVAL_SECONDS, MAX_SIGN_IN_ATTEMPTS
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+
+
+class UserService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def _filter_deleted_user(self, user: User) -> Optional[User]:
+ if not user or user.state == State.DELETED:
+ return None
+ return user
+
+ def get_user_by_id(self, user_id: int, filter_deleted=False) -> Optional[User]:
+ user = self._session.query(User).filter_by(id=user_id).first()
+ if filter_deleted:
+ return self._filter_deleted_user(user)
+ return user
+
+ def get_user_by_username(self, username: str, filter_deleted=False) -> Optional[User]:
+ user = self._session.query(User).filter_by(username=username).first()
+ if filter_deleted:
+ return self._filter_deleted_user(user)
+ return user
+
+ def get_all_users(self, filter_deleted=False) -> List[User]:
+ if filter_deleted:
+ return self._session.query(User).filter_by(state=State.ACTIVE).all()
+ return self._session.query(User).all()
+
+ def delete_user(self, user: User) -> User:
+ user.state = State.DELETED
+ return user
+
+ def create_user_if_not_exists(self,
+ username: str,
+ email: str,
+ name: Optional[str] = None,
+ role: Role = Role.USER,
+ sso_name: Optional[str] = None,
+ password: Optional[str] = None) -> User:
+ user = self.get_user_by_username(username)
+ if user is None:
+ user = User(username=username, name=name, email=email, state=State.ACTIVE, role=role, sso_name=sso_name)
+ if password is not None:
+ user.set_password(password)
+ self._session.add(user)
+ create_iams_for_user(user)
+ return user
+
+
+class SessionService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def get_session_by_jti(self, jti: str) -> Optional[SessionTbl]:
+ return self._session.query(SessionTbl).filter_by(jti=jti).first()
+
+ def delete_session(self, session_obj: SessionTbl) -> Optional[SessionTbl]:
+ if session_obj is None:
+ logging.warning('deleting a non-existent session...')
+ return None
+ self._session.delete(session_obj)
+ return session_obj
+
+ def delete_session_by_user_id(self, user_id: int) -> Optional[SessionTbl]:
+ session_obj = self._session.query(SessionTbl).filter(SessionTbl.user_id == user_id).first()
+ return self.delete_session(session_obj)
+
+
+class StrictSignInService(object):
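+ """Guards password sign-in: an account is locked once failed_sign_in_attempts reaches
+ MAX_SIGN_IN_ATTEMPTS, and unlocks again SIGN_IN_INTERVAL_SECONDS after the last attempt."""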
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def can_sign_in(self, user: User):
+ if user.last_sign_in_at is None or \
+ to_timestamp(now()) - to_timestamp(user.last_sign_in_at) > SIGN_IN_INTERVAL_SECONDS:
+ return True
+ return user.failed_sign_in_attempts < MAX_SIGN_IN_ATTEMPTS
+
+ def update(self, user: User, is_signed_in: bool = True):
+ user.last_sign_in_at = now()
+ if is_signed_in:
+ user.failed_sign_in_attempts = 0
+ else:
+ user.failed_sign_in_attempts += 1
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/services_test.py b/web_console_v2/api/fedlearner_webconsole/auth/services_test.py
new file mode 100644
index 000000000..b4a810361
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/auth/services_test.py
@@ -0,0 +1,125 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from fedlearner_webconsole.auth.models import State, User, Session as SessionTbl
+from fedlearner_webconsole.auth.services import UserService, SessionService
+
+from fedlearner_webconsole.db import db
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class UserServiceTest(NoWebServerTestCase):
+
+ def test_get_user_by_id(self):
+ # case1: a user id that does not exist
+ nonexistent_uid = 9999
+ with db.session_scope() as session:
+ self.assertIsNone(UserService(session).get_user_by_id(nonexistent_uid))
+
+ # case2: deleted one
+ with db.session_scope() as session:
+ deleted_user = User(username='deleted_one', email='who.knows@hhh.com', state=State.DELETED)
+ session.add(deleted_user)
+ session.commit()
+ with db.session_scope() as session:
+ self.assertIsNone(UserService(session).get_user_by_id(deleted_user.id, filter_deleted=True))
+
+ # case3: a real one
+ with db.session_scope() as session:
+ real_user = User(username='real_one', email='who.knows@hhh.com', state=State.ACTIVE)
+ session.add(real_user)
+ session.commit()
+ with db.session_scope() as session:
+ self.assertEqual(UserService(session).get_user_by_id(real_user.id).username, 'real_one')
+
+ def test_get_user_by_username(self):
+ # case1: a username that does not exist
+ nonexistent_username = 'none_existed'
+ with db.session_scope() as session:
+ self.assertIsNone(UserService(session).get_user_by_username(nonexistent_username))
+
+ # case2: deleted one
+ with db.session_scope() as session:
+ deleted_user = User(username='deleted_one', email='who.knows@hhh.com', state=State.DELETED)
+ session.add(deleted_user)
+ session.commit()
+ with db.session_scope() as session:
+ self.assertIsNone(UserService(session).get_user_by_username(deleted_user.username, filter_deleted=True))
+
+ # case3: a real one
+ with db.session_scope() as session:
+ real_user = User(username='real_one', email='who.knows@hhh.com', state=State.ACTIVE)
+ session.add(real_user)
+ session.commit()
+ with db.session_scope() as session:
+ self.assertEqual(UserService(session).get_user_by_username(real_user.username).id, 2)
+
+ def test_get_all_users(self):
+ with db.session_scope() as session:
+ session.add_all([
+ User(username='real_one', email='who.knows@hhh.com', state=State.ACTIVE),
+ User(username='deleted_one', email='who.knows@hhh.com', state=State.DELETED)
+ ])
+ session.commit()
+ with db.session_scope() as session:
+ self.assertEqual(len(UserService(session).get_all_users()), 2)
+ self.assertEqual(len(UserService(session).get_all_users(filter_deleted=True)), 1)
+
+ def test_delete_user(self):
+ with db.session_scope() as session:
+ user = User(username='real_one', email='who.knows@hhh.com', state=State.ACTIVE)
+ session.add(user)
+ session.commit()
+ with db.session_scope() as session:
+ deleted_user = UserService(session).delete_user(user)
+ session.commit()
+ self.assertEqual(deleted_user.state, State.DELETED)
+
+
+class SessionServiceTest(NoWebServerTestCase):
+
+ def test_get_session_by_jti(self):
+ jti = 'test'
+ with db.session_scope() as session:
+ session.add(SessionTbl(jti=jti, user_id=1))
+ session.commit()
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti(jti)
+ self.assertEqual(session_obj.jti, jti)
+ session_obj = SessionService(session).get_session_by_jti('fjeruif')
+ self.assertIsNone(session_obj)
+
+ def test_delete_session(self):
+ jti = 'test'
+ with db.session_scope() as session:
+ session.add(SessionTbl(jti=jti, user_id=1))
+ session.commit()
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti(jti)
+ session_obj = SessionService(session).delete_session(session_obj)
+ self.assertEqual(session_obj.jti, jti)
+ session.commit()
+ with db.session_scope() as session:
+ self.assertIsNone(SessionService(session).get_session_by_jti(jti))
+
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti('dfas')
+ session_obj = SessionService(session).delete_session(session_obj)
+ self.assertIsNone(session_obj)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/third_party_sso.py b/web_console_v2/api/fedlearner_webconsole/auth/third_party_sso.py
new file mode 100644
index 000000000..137ba84ec
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/auth/third_party_sso.py
@@ -0,0 +1,376 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import enum
+import json
+import logging
+import uuid
+from abc import ABCMeta, abstractmethod
+from datetime import timedelta, timezone
+from typing import Optional
+from collections import namedtuple
+from http import HTTPStatus
+from functools import wraps
+from urllib.parse import urlencode
+
+import requests
+import jwt
+from flask import request, g
+import xmltodict
+from google.protobuf.json_format import ParseDict
+from config import Config
+from envs import Envs
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.const import SSO_HEADER
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.utils.decorators.lru_cache import lru_cache
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+from fedlearner_webconsole.utils.flask_utils import set_current_user
+from fedlearner_webconsole.utils.pp_base64 import base64decode
+from fedlearner_webconsole.proto.auth_pb2 import SigninParameter, Sso
+from fedlearner_webconsole.exceptions import UnauthorizedException, InvalidArgumentException, NoAccessException
+from fedlearner_webconsole.auth.services import UserService, SessionService, StrictSignInService
+from fedlearner_webconsole.auth.models import Session, State, User
+
+UserInfo = namedtuple('UserInfo', ['username', 'email'])
+
+
+class SsoProtocol(enum.Enum):
+ OAUTH = 'oauth'
+ CAS = 'cas'
+
+
+def _generate_jwt_session(username: str, user_id: int, session: Session) -> str:
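+ """Issues a JWT for the user and records a Session row keyed by the token's jti,
+ so the token can be revoked on signout; callers are expected to commit the session."""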
+ delta = timedelta(seconds=Config.JWT_ACCESS_TOKEN_EXPIRES)
+ expire_time = now(timezone.utc) + delta
+ jti = str(uuid.uuid4())
+ token = jwt.encode(
+ {
+ 'username': username,
+ 'exp': expire_time,
+ 'jti': jti
+ },
+ key=Config.JWT_SECRET_KEY,
+ )
+ session_obj = Session(jti=jti, user_id=user_id, expired_at=expire_time)
+ session.add(session_obj)
+ # PyJWT changed the return type of encode() across versions (bytes vs. str), so normalize to str
+ if isinstance(token, bytes):
+ token = token.decode()
+ return token
+
+
+def _signout_jwt_session():
+ if hasattr(g, 'jti'):
+ jti = g.jti
+ else:
+ raise UnauthorizedException('Not sign in with jwt.')
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti(jti)
+ SessionService(session).delete_session(session_obj)
+ session.commit()
+
+
+def _validate_jwt_session(credentials: str) -> str:
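+ """Decodes the JWT, rejects tokens whose jti has no Session row (revoked or signed out)
+ or whose exp has passed, and returns the username carried in the token."""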
+ time_now = to_timestamp(now(timezone.utc))
+ decoded_token = jwt.decode(credentials, Config.JWT_SECRET_KEY, algorithms='HS256')
+ expire_time = decoded_token.get('exp')
+ jti = decoded_token.get('jti')
+ username = decoded_token.get('username')
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti(jti)
+ if session_obj is None:
+ raise UnauthorizedException('No session.')
+ if expire_time < time_now:
+ raise UnauthorizedException('Token has expired.')
+ # Store jti on the request context so that signout can find the session to remove.
+ g.jti = jti
+ return username
+
+
+class SsoHandler(metaclass=ABCMeta):
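+ """Base class for sign-in backends: subclasses implement signin/signout plus
+ check_credentials, which turns request credentials into a username."""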
+
+ def __init__(self, sso):
+ self.sso = sso
+
+ @abstractmethod
+ def signin(self, signin_parameter: SigninParameter) -> dict:
+ pass
+
+ @abstractmethod
+ def signout(self):
+ pass
+
+ @abstractmethod
+ def check_credentials(self, credentials) -> str:
+ """
+ Check credentials and return the username.
+ """
+
+ def check_credentials_and_set_current_user(self, credentials):
+ try:
+ username = self.check_credentials(credentials)
+ except Exception as err:
+ raise UnauthorizedException(str(err)) from err
+ with db.session_scope() as session:
+ user = UserService(session).get_user_by_username(username, filter_deleted=True)
+ if user is None:
+ raise UnauthorizedException(f'User {username} not found.')
+ set_current_user(user)
+
+ @classmethod
+ def check_user_validity(cls, user: User):
+ if user.state == State.DELETED:
+ error_msg = f'user: {user.username} has been deleted'
+ logging.error(error_msg)
+ raise InvalidArgumentException(error_msg)
+
+
+class OAuthHandler(SsoHandler):
+
+ def get_access_token(self, code: str) -> str:
+ try:
+ r = requests.post(self.sso.oauth.access_token_url,
+ data={
+ 'code': code,
+ 'client_id': self.sso.oauth.client_id,
+ 'client_secret': self.sso.oauth.secret,
+ 'redirect_uri': self.sso.oauth.redirect_uri,
+ 'grant_type': 'authorization_code'
+ })
+ except Exception as e:
+ error_msg = f'Get access_token failed from sso: {self.sso.name}: {str(e)}.'
+ logging.error(error_msg)
+ raise UnauthorizedException(error_msg) from e
+ if r.status_code != HTTPStatus.OK:
+ error_msg = f'Get access_token failed from sso: {self.sso.name}: {r.json()}.'
+ logging.error(error_msg)
+ raise UnauthorizedException(error_msg)
+ access_token = r.json().get('access_token')
+ if access_token is None:
+ error_msg = f'Get access_token failed from sso: ' \
+ f'{self.sso.name}: no access_token in response.'
+ logging.error(error_msg)
+ raise UnauthorizedException(error_msg)
+ return access_token
+
+ def get_user_info(self, access_token: str) -> UserInfo:
+ user_info = get_user_info_with_cache(self.sso.name, self.sso.oauth.user_info_url, access_token,
+ self.sso.oauth.username_key, self.sso.oauth.email_key)
+ return user_info
+
+ def signin(self, signin_parameter: SigninParameter) -> dict:
+ code = signin_parameter.code
+ if code == '':
+ raise InvalidArgumentException('OAuth code is not found')
+ access_token = self.get_access_token(code)
+
+ user_info = self.get_user_info(access_token)
+
+ with db.session_scope() as session:
+ user = UserService(session).create_user_if_not_exists(username=user_info.username,
+ email=user_info.email,
+ sso_name=self.sso.name,
+ name=user_info.username)
+ self.check_user_validity(user)
+ StrictSignInService(session).update(user, is_signed_in=True)
+ session.commit()
+ return {'user': user.to_dict(), 'access_token': access_token}
+
+ def signout(self):
+ get_user_info_with_cache.cache_clear()
+
+ def check_credentials(self, credentials):
+ user_info = get_user_info_with_cache(self.sso.name, self.sso.oauth.user_info_url, credentials,
+ self.sso.oauth.username_key, self.sso.oauth.email_key)
+ return user_info.username
+
+
+class JwtHandler(SsoHandler):
+
+ def __init__(self):
+ super().__init__(None)
+
+ def signin(self, signin_parameter: SigninParameter) -> dict:
+ username = signin_parameter.username
+ password = base64decode(signin_parameter.password)
+ if username == '' or password == '':
+ raise InvalidArgumentException('username or password is not found')
+ with db.session_scope() as session:
+ user = UserService(session).get_user_by_username(username, filter_deleted=True)
+ if user is None:
+ raise InvalidArgumentException(f'Failed to find user: {username}')
+ strict_service = StrictSignInService(session)
+ if not strict_service.can_sign_in(user):
+ raise NoAccessException('Account is locked')
+ if not user.verify_password(password):
+ logging.warning(f'user {user.username} login failed due to wrong password')
+ emit_store('user.wrong_password', 1)
+ strict_service.update(user, is_signed_in=False)
+ session.commit()
+ raise InvalidArgumentException('Invalid password')
+ token = _generate_jwt_session(username, user.id, session)
+ strict_service.update(user, is_signed_in=True)
+ session.commit()
+ return {'user': user.to_dict(), 'access_token': token}
+
+ def signout(self):
+ _signout_jwt_session()
+
+ def check_credentials(self, credentials: str) -> str:
+ return _validate_jwt_session(credentials)
+
+
+class CasHandler(SsoHandler):
+
+ def _service_validate(self, ticket: str) -> str:
+ params_dict = dict(
+ service=self.sso.cas.service_url,
+ ticket=ticket,
+ )
+ validate_url = f'{self.sso.cas.cas_server_url}' \
+ f'{self.sso.cas.validate_route}?{urlencode(params_dict)}'
+ r = requests.get(validate_url)
+ if r.status_code != HTTPStatus.OK:
+ logging.error(f'Cas sso {self.sso.name} receive Error code {r.status_code}')
+ raise UnauthorizedException('Sso server error.')
+ resp_dict = xmltodict.parse(r.content)
+ if 'cas:authenticationSuccess' in resp_dict['cas:serviceResponse']:
+ resp_data = resp_dict['cas:serviceResponse']['cas:authenticationSuccess']
+ return resp_data['cas:user']
+ logging.error(f'sso: {self.sso.name} CAS returned unexpected result')
+ raise UnauthorizedException('Wrong ticket.')
+
+ def signin(self, signin_parameter: SigninParameter) -> dict:
+ ticket = signin_parameter.ticket
+ if ticket == '':
+ raise InvalidArgumentException('CAS ticket is not found')
+ username = self._service_validate(ticket)
+
+ with db.session_scope() as session:
+ user = UserService(session).create_user_if_not_exists(username=username,
+ name=username,
+ email='',
+ sso_name=self.sso.name)
+ self.check_user_validity(user)
+ session.flush()
+ token = _generate_jwt_session(username, user.id, session)
+ StrictSignInService(session).update(user, is_signed_in=True)
+ session.commit()
+ return {'user': user.to_dict(), 'access_token': token}
+
+ def signout(self):
+ _signout_jwt_session()
+
+ def check_credentials(self, credentials: str) -> str:
+ return _validate_jwt_session(credentials)
+
+
+class SsoInfos:
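+ """Parses the SSO configs in Envs.SSO_INFOS and builds one handler per SSO,
+ plus a JwtHandler registered under the reserved name 'default'."""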
+
+ def __init__(self):
+ try:
+ sso_infos_dict = json.loads(Envs.SSO_INFOS)
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'Failed to parse SSO_INFOS: {str(e)}')
+ sso_infos_dict = []
+ self.sso_infos = []
+ # sso_infos exclude server-side info that must not be visible to the frontend
+ self.sso_handlers = {}
+ for sso in sso_infos_dict:
+ # check the format of sso_infos
+ sso_proto = ParseDict(sso, Sso(), ignore_unknown_fields=True)
+ if sso_proto.name == 'default':
+ logging.error('Sso name should not be \'default\'')
+ self.sso_infos.append(sso_proto)
+ if sso_proto.WhichOneof('protocol') == SsoProtocol.OAUTH.value:
+ self.sso_handlers[sso_proto.name] = OAuthHandler(sso_proto)
+ elif sso_proto.WhichOneof('protocol') == SsoProtocol.CAS.value:
+ self.sso_handlers[sso_proto.name] = CasHandler(sso_proto)
+ else:
+ logging.error(f'SSO {sso_proto.name} does not have supported protocol.')
+ self.sso_handlers['default'] = JwtHandler()
+
+ def get_sso_info(self, name: str) -> Optional[Sso]:
+ for sso in self.sso_infos:
+ if name == sso.name:
+ return sso
+ return None
+
+
+sso_info_manager = SsoInfos()
+
+
+class SsoHandlerFactory:
+
+ @staticmethod
+ def get_handler(sso_name) -> SsoHandler:
+ jwt_handler = sso_info_manager.sso_handlers['default']
+ return sso_info_manager.sso_handlers.get(sso_name, jwt_handler)
+
+
+# Defined at module level rather than as a method so the lru_cache below does not keep handler instances alive (avoiding a memory leak).
+@lru_cache(timeout=600, maxsize=128)
+def get_user_info_with_cache(sso_name: str, user_info_url: str, access_token: str, username_key: str,
+ email_key: str) -> Optional[UserInfo]:
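+ """Fetches the user profile from the SSO's user-info endpoint with a Bearer token;
+ results are cached by the decorator above, and responses wrapped in a top-level
+ 'data' field are unwrapped before the username/email keys are read."""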
+ if not username_key:
+ username_key = 'username'
+ if not email_key:
+ email_key = 'email'
+ try:
+ r = requests.get(user_info_url, headers={'Authorization': f'Bearer {access_token}'})
+ except Exception as e: # pylint: disable=broad-except
+ error_msg = f'Get user_info failed from sso: {sso_name}: {str(e)}.'
+ logging.error(error_msg)
+ raise UnauthorizedException(error_msg) from e
+ if r.status_code != HTTPStatus.OK:
+ error_msg = f'Get user_info failed from sso: {sso_name}: {r.json()}.'
+ logging.error(error_msg)
+ raise UnauthorizedException(error_msg)
+ user_info_dict = r.json()
+ # This is to be compatible with some API response schema with data.
+ if 'data' in user_info_dict:
+ user_info_dict = user_info_dict.get('data')
+ if username_key not in user_info_dict:
+ error_msg = f'Get user_info failed from sso: ' \
+ f'{sso_name}: no {username_key} in response.'
+ logging.error(error_msg)
+ raise UnauthorizedException(error_msg)
+ user_info = UserInfo(username=user_info_dict.get(username_key), email=user_info_dict.get(email_key, ''))
+ return user_info
+
+
+def credentials_required(fn):
+
+ @wraps(fn)
+ def decorator(*args, **kwargs):
+
+ sso_headers = request.headers.get(SSO_HEADER, None)
+ jwt_headers = request.headers.get('Authorization', None)
+ sso_name = None
+
+ if sso_headers is None and jwt_headers is None and Envs.DEBUG:
+ return fn(*args, **kwargs)
+
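+ # The SSO header is expected to be '<sso_name> <scheme> <credentials>', while the
+ # Authorization header is '<scheme> <token>' (e.g. a JWT issued at signin).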
+ if sso_headers:
+ sso_name, _, credentials = sso_headers.split()
+ elif jwt_headers:
+ _, credentials = jwt_headers.split()
+ else:
+ raise UnauthorizedException(f'failed to find {SSO_HEADER} or authorization within headers')
+ SsoHandlerFactory.get_handler(sso_name).check_credentials_and_set_current_user(credentials)
+ return fn(*args, **kwargs)
+
+ return decorator
diff --git a/web_console_v2/api/fedlearner_webconsole/auth/third_party_sso_test.py b/web_console_v2/api/fedlearner_webconsole/auth/third_party_sso_test.py
new file mode 100644
index 000000000..1207c26bf
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/auth/third_party_sso_test.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from http import HTTPStatus
+from unittest.mock import patch
+
+import jwt
+from flask import g
+from datetime import timedelta
+from config import Config
+from envs import Envs
+from testing.common import BaseTestCase
+from testing.helpers import FakeResponse
+from fedlearner_webconsole.auth.services import SessionService
+from fedlearner_webconsole.auth.third_party_sso import credentials_required, \
+ get_user_info_with_cache, SsoInfos, JwtHandler
+from fedlearner_webconsole.exceptions import UnauthorizedException
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.pp_datetime import now
+
+
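+# A minimal endpoint guarded by the decorator so the tests below can exercise credential checks directly.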
+@credentials_required
+def test_some_api():
+ return 1
+
+
+class OauthHandlerTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ session.add(User(username='test'))
+ session.commit()
+ with open(f'{Envs.BASE_DIR}/testing/test_data/test_sso.json', encoding='utf-8') as f:
+ self.patch_ssoinfos = patch('fedlearner_webconsole.auth.third_party_sso.Envs.SSO_INFOS', f.read())
+ self.patch_ssoinfos.start()
+ self.fake_sso_info_manager = SsoInfos()
+ self.patch_manager = patch('fedlearner_webconsole.auth.third_party_sso.sso_info_manager',
+ self.fake_sso_info_manager)
+ self.patch_manager.start()
+
+ def tearDown(self):
+ self.patch_manager.stop()
+ self.patch_ssoinfos.stop()
+ # clear cache to isolate the cache of each test case.
+ get_user_info_with_cache.cache_clear()
+ super().tearDown()
+
+ def test_get_sso_infos(self):
+ self.assertEqual(len(self.fake_sso_info_manager.sso_infos), 2)
+
+ def test_get_sso_info(self):
+ self.assertEqual(self.fake_sso_info_manager.get_sso_info('test').name, 'test')
+ self.assertEqual(self.fake_sso_info_manager.get_sso_info('test').display_name, 'test')
+
+ @patch('fedlearner_webconsole.auth.third_party_sso.requests.get')
+ @patch('fedlearner_webconsole.auth.third_party_sso.request.headers.get')
+ def test_credentials_required(self, mock_headers, mock_request_get):
+ # test not supported sso
+ mock_headers.return_value = 'not_supported_sso oauth access_token'
+ self.assertRaises(UnauthorizedException, test_some_api)
+ mock_request_get.return_value = FakeResponse({'username': 'test', 'email': 'test'}, HTTPStatus.OK)
+ self.assertRaises(UnauthorizedException, test_some_api)
+
+ # test supported sso
+ mock_headers.return_value = 'test oauth access_token'
+ test_some_api()
+
+ @patch('fedlearner_webconsole.auth.third_party_sso.requests.get')
+ @patch('fedlearner_webconsole.auth.third_party_sso.request.headers.get')
+ def test_get_user_info_cache(self, mock_headers, mock_request_get):
+ mock_headers.return_value = 'test oauth access_token'
+ mock_request_get.return_value = FakeResponse({'username': 'test', 'email': 'test'}, HTTPStatus.OK)
+ test_some_api()
+ test_some_api()
+ mock_request_get.assert_called_once()
+ mock_headers.return_value = 'test oauth access_token1'
+ mock_request_get.return_value = FakeResponse({'data': {'username': 'test', 'email': 'test'}}, HTTPStatus.OK)
+ test_some_api()
+ test_some_api()
+ self.assertEqual(mock_request_get.call_count, 2)
+
+
+class JwtHandlerTest(BaseTestCase):
+
+ def test_check_credentials(self):
+ jwt_handler = JwtHandler()
+ self.assertEqual(jwt_handler.check_credentials(self._token), 'ada')
+ jti = jwt.decode(self._token, key=Config.JWT_SECRET_KEY, algorithms='HS256').get('jti')
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti(jti)
+ SessionService(session).delete_session(session_obj=session_obj)
+ session.commit()
+ self.assertRaises(UnauthorizedException, jwt_handler.check_credentials, self._token)
+ self.signin_as_admin()
+ with patch('fedlearner_webconsole.auth.third_party_sso.now') as fake_now:
+ fake_now.return_value = now() + timedelta(seconds=86405)
+ self.assertRaises(UnauthorizedException, jwt_handler.check_credentials, self._token)
+
+ def test_signout(self):
+ jwt_handler = JwtHandler()
+ jti = jwt.decode(self._token, key=Config.JWT_SECRET_KEY, algorithms='HS256').get('jti')
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti(jti)
+ self.assertIsNotNone(session_obj)
+ g.jti = jti
+ jwt_handler.signout()
+ with db.session_scope() as session:
+ session_obj = SessionService(session).get_session_by_jti(jti)
+ self.assertIsNone(session_obj)
+
+
+class CasHandlerTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ session.add(User(username='test'))
+ session.commit()
+ with open(f'{Envs.BASE_DIR}/testing/test_data/test_sso.json', encoding='utf-8') as f:
+ self.patch_ssoinfos = patch('fedlearner_webconsole.auth.third_party_sso.Envs.SSO_INFOS', f.read())
+ self.patch_ssoinfos.start()
+ self.fake_sso_info_manager = SsoInfos()
+ self.patch_manager = patch('fedlearner_webconsole.auth.third_party_sso.sso_info_manager',
+ self.fake_sso_info_manager)
+ self.patch_manager.start()
+
+ def tearDown(self):
+ self.patch_manager.stop()
+ self.patch_ssoinfos.stop()
+ super().tearDown()
+
+ @patch('fedlearner_webconsole.auth.third_party_sso.requests.get')
+ @patch('fedlearner_webconsole.auth.third_party_sso.request.headers.get')
+ def test_credentials_required_cas(self, mock_headers, mock_request_get):
+ mock_request_get.return_value = FakeResponse({'username': 'test', 'email': 'test'}, HTTPStatus.OK)
+ # test supported sso
+ mock_headers.return_value = f'test_cas cas {self._token}'
+ test_some_api()
+ mock_headers.return_value = f'test_cas cas {self._token}aa'
+ self.assertRaises(UnauthorizedException, test_some_api)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/cleanup/BUILD.bazel
new file mode 100644
index 000000000..835bb1f4b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/BUILD.bazel
@@ -0,0 +1,137 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_test",
+ size = "medium",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = ["services.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "services_test",
+ size = "medium",
+ srcs = [
+ "services_test.py",
+ ],
+ imports = ["../.."],
+ main = "services_test.py",
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "cleaner_cronjob_lib",
+ srcs = ["cleaner_cronjob.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "cleaner_cronjob_test",
+ size = "medium",
+ srcs = [
+ "cleaner_cronjob_test.py",
+ ],
+ imports = ["../.."],
+ main = "cleaner_cronjob_test.py",
+ deps = [
+ ":cleaner_cronjob_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "@common_flask_restful//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/__init__.py b/web_console_v2/api/fedlearner_webconsole/cleanup/__init__.py
similarity index 100%
rename from web_console_v2/api/fedlearner_webconsole/setting/__init__.py
rename to web_console_v2/api/fedlearner_webconsole/cleanup/__init__.py
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/apis.py b/web_console_v2/api/fedlearner_webconsole/cleanup/apis.py
new file mode 100644
index 000000000..0991ca90e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/apis.py
@@ -0,0 +1,150 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from flask_restful import Resource, Api
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required, use_args
+from fedlearner_webconsole.utils.flask_utils import make_flask_response, FilterExpField
+from fedlearner_webconsole.cleanup.services import CleanupService
+from marshmallow import Schema, fields
+
+
+class GetCleanupParams(Schema):
+ filter = FilterExpField(required=False, load_default=None)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
+
+
+class CleanupsApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_args(GetCleanupParams(), location='query')
+ def get(self, params: dict):
+ """Get a list of all cleanups
+ ---
+ tags:
+ - cleanup
+ description: get cleanups list
+ parameters:
+ - in: query
+ name: filter
+ schema:
+ type: string
+ required: false
+ - in: query
+ name: page
+ schema:
+ type: integer
+ required: false
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ required: false
+ responses:
+ 200:
+ description: Get cleanups list result
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.CleanupPb'
+ """
+ with db.session_scope() as session:
+ try:
+ pagination = CleanupService(session).get_cleanups(
+ filter_exp=params['filter'],
+ page=params['page'],
+ page_size=params['page_size'],
+ )
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ data = [t.to_proto() for t in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+
+class CleanupApi(Resource):
+
+ @credentials_required
+ @admin_required
+ def get(self, cleanup_id: int):
+ """Get a cleanup by id
+ ---
+ tags:
+ - cleanup
+ description: get details of cleanup
+ parameters:
+ - in: path
+ name: cleanup_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: Get details of cleanup
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.CleanupPb'
+ 404:
+ description: The cleanup with specified ID is not found
+ """
+ with db.session_scope() as session:
+ cleanup = CleanupService(session).get_cleanup(cleanup_id)
+ return make_flask_response(cleanup)
+
+
+class CleanupCancelApi(Resource):
+
+ @credentials_required
+ @admin_required
+ def post(self, cleanup_id: int):
+ """Get a cleanup by id
+ ---
+ tags:
+ - cleanup
+ description: change the state of cleanup
+ parameters:
+ - in: path
+ name: cleanup_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: The Cleanup's state has been updated
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.CleanupPb'
+ 400:
+ description: The state param in the request body is invalid
+ 404:
+ description: The cleanup with specified ID is not found
+ """
+ with db.session_scope() as session:
+ cleanup = CleanupService(session).cancel_cleanup_by_id(cleanup_id)
+ session.commit()
+ return make_flask_response(cleanup)
+
+
+def initialize_cleanup_apis(api: Api):
+ api.add_resource(CleanupsApi, '/cleanups')
+ api.add_resource(CleanupApi, '/cleanups/<int:cleanup_id>')
+ api.add_resource(CleanupCancelApi, '/cleanups/<int:cleanup_id>:cancel')
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/apis_test.py b/web_console_v2/api/fedlearner_webconsole/cleanup/apis_test.py
new file mode 100644
index 000000000..2194ee6de
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/apis_test.py
@@ -0,0 +1,180 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import urllib.parse
+from http import HTTPStatus
+from datetime import datetime, timezone
+from testing.common import BaseTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.cleanup.models import ResourceType, CleanupState, Cleanup
+from fedlearner_webconsole.dataset.models import Dataset, DatasetKindV2, DatasetType
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupPayload, CleanupPb
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.proto import to_dict
+
+
+class CleanupsApiTest(BaseTestCase):
+ _TARGET_START_AT = datetime(2022, 2, 22, 10, 10, 12, tzinfo=timezone.utc)
+ _CREATED_AT = datetime(2022, 2, 22, 3, 3, 4, tzinfo=timezone.utc)
+ _CREATED_AT_2 = datetime(2022, 3, 22, 3, 3, 4, tzinfo=timezone.utc)
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ dataset1 = Dataset(id=1,
+ name='default_dataset',
+ dataset_type=DatasetType.PSI,
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ path='/data/default_dataset/')
+ self.default_payload = CleanupPayload(paths=['/Major333/test_path/a.csv'])
+ cleanup1 = Cleanup(id=1,
+ state=CleanupState.WAITING,
+ created_at=self._CREATED_AT,
+ updated_at=self._CREATED_AT,
+ target_start_at=self._TARGET_START_AT,
+ resource_id=dataset1.id,
+ resource_type=ResourceType(Dataset).name,
+ payload=self.default_payload)
+ cleanup2 = Cleanup(id=2,
+ state=CleanupState.CANCELED,
+ created_at=self._CREATED_AT,
+ updated_at=self._CREATED_AT,
+ target_start_at=self._TARGET_START_AT,
+ resource_id=dataset1.id,
+ resource_type=ResourceType(Dataset).name,
+ payload=self.default_payload)
+ with db.session_scope() as session:
+ session.add(dataset1)
+ session.add(cleanup1)
+ session.add(cleanup2)
+ session.commit()
+ self.signin_as_admin()
+
+ def test_get_without_filter_and_pagination(self):
+ response = self.get_helper('/api/v2/cleanups')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+
+ def test_get_with_pagination(self):
+ response = self.get_helper('/api/v2/cleanups?page=1&page_size=1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['id'], 1)
+
+ def test_get_with_invalid_filter(self):
+ response = self.get_helper('/api/v2/cleanups?filter=invalid')
+ self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
+
+ def test_get_with_filter(self):
+ filter_exp = urllib.parse.quote('(and(resource_type="DATASET")(state="WAITING"))')
+ response = self.get_helper(f'/api/v2/cleanups?filter={filter_exp}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['id'], 1)
+
+
+class CleanupApiTest(BaseTestCase):
+ _TARGET_START_AT = datetime(2022, 2, 22, 10, 10, 12, tzinfo=timezone.utc)
+ _CREATED_AT = datetime(2022, 2, 22, 3, 3, 4, tzinfo=timezone.utc)
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ dataset1 = Dataset(id=1,
+ name='default_dataset',
+ dataset_type=DatasetType.PSI,
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ path='/data/default_dataset/')
+ self.default_payload = CleanupPayload(paths=['/Major333/test_path/a.csv'])
+ cleanup1 = Cleanup(id=1,
+ state=CleanupState.WAITING,
+ created_at=self._CREATED_AT,
+ updated_at=self._CREATED_AT,
+ target_start_at=self._TARGET_START_AT,
+ resource_id=dataset1.id,
+ resource_type=ResourceType(Dataset).name,
+ payload=self.default_payload)
+ with db.session_scope() as session:
+ session.add(dataset1)
+ session.add(cleanup1)
+ session.commit()
+
+ def test_get(self):
+ expected_cleanup_proto = CleanupPb(id=1,
+ state='WAITING',
+ completed_at=None,
+ resource_id=1,
+ resource_type='DATASET',
+ payload=self.default_payload,
+ target_start_at=to_timestamp(self._TARGET_START_AT),
+ updated_at=to_timestamp(self._CREATED_AT),
+ created_at=to_timestamp(self._CREATED_AT))
+ self.signin_as_admin()
+ response = self.get_helper(f'/api/v2/cleanups/{expected_cleanup_proto.id}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ cleanup = self.get_response_data(response)
+ self.assertEqual(cleanup, to_dict(expected_cleanup_proto))
+
+
+class CleanupCancelApiTest(BaseTestCase):
+ _TARGET_START_AT = datetime(2022, 2, 22, 10, 10, 12, tzinfo=timezone.utc)
+ _CREATED_AT = datetime(2022, 2, 22, 3, 3, 4, tzinfo=timezone.utc)
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ dataset1 = Dataset(id=1,
+ name='default_dataset',
+ dataset_type=DatasetType.PSI,
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ path='/data/default_dataset/')
+ self.default_payload = CleanupPayload(paths=['/Major333/test_path/a.csv'])
+ cleanup1 = Cleanup(id=1,
+ state=CleanupState.WAITING,
+ created_at=self._CREATED_AT,
+ updated_at=self._CREATED_AT,
+ target_start_at=self._TARGET_START_AT,
+ resource_id=dataset1.id,
+ resource_type=ResourceType(Dataset).name,
+ payload=self.default_payload)
+ with db.session_scope() as session:
+ session.add(dataset1)
+ session.add(cleanup1)
+ session.commit()
+
+ def test_cleanup_waiting_cancel(self):
+ self.signin_as_admin()
+ response = self.post_helper('/api/v2/cleanups/1:cancel', {})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ response = self.get_helper('/api/v2/cleanups/1')
+ cancelled_cleanup = self.get_response_data(response)
+ self.assertEqual(cancelled_cleanup['state'], 'CANCELED')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/cleaner_cronjob.py b/web_console_v2/api/fedlearner_webconsole/cleanup/cleaner_cronjob.py
new file mode 100644
index 000000000..0ba84225e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/cleaner_cronjob.py
@@ -0,0 +1,119 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple, List
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, CleanupCronJobOutput
+from datetime import timezone, datetime
+from fedlearner_webconsole.cleanup.models import Cleanup, CleanupState
+
+
+class CleanupCronJob(IRunnerV2):
+
+ def __init__(self) -> None:
+ super().__init__()
+ self._file_manager = FileManager()
+
+ def _get_current_utc_time(self):
+ return datetime.now(tz=timezone.utc)
+
+ def _execute_cleanup(self, cleanup: Cleanup):
+ for path in cleanup.payload.paths:
+ if self._file_manager.exists(path):
+ self._file_manager.remove(path)
+ cleanup.state = CleanupState.SUCCEEDED
+ cleanup.completed_at = self._get_current_utc_time()
+
+ def _get_waiting_ids(self) -> List[int]:
+ with db.session_scope() as session:
+ current_time = self._get_current_utc_time()
+ waiting_ids = session.query(Cleanup.id).filter(Cleanup.state == CleanupState.WAITING).filter(
+ Cleanup.target_start_at <= current_time).all()
+ logging.info(f'Collected waiting cleanup ids: {waiting_ids}')
+ # unwrap query result
+ return [cleanup_id for cleanup_id, *_ in waiting_ids]
+
+ def _sweep_waiting_cleanups(self, waiting_list: List[int]):
+ logging.info(f'will sweep the cleanup ids:{waiting_list}')
+ for cleanup_id in waiting_list:
+ with db.session_scope() as session:
+ logging.info(f'will sweep the waiting cleanup:{cleanup_id}')
+ current_time = self._get_current_utc_time()
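+ # Row lock (SELECT ... FOR UPDATE) so the cancel API and other workers cannot change this cleanup concurrently.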
+ cleanup = session.query(Cleanup).populate_existing().with_for_update().get(cleanup_id)
+ try:
+ if cleanup and cleanup.state == CleanupState.WAITING and \
+ cleanup.target_start_at.replace(tzinfo=timezone.utc) <= current_time:
+ cleanup.state = CleanupState.RUNNING
+ # Release the lock
+ session.commit()
+ else:
+ logging.warning(f'While sweeping the waiting cleanup list, \
+ cleanup:{cleanup_id} was changed or canceled concurrently; it has been skipped.')
+ # Release the lock
+ session.rollback()
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'Cleanup:{cleanup.id} failed, error_msg: {str(e)}')
+ cleanup.state = CleanupState.FAILED
+ session.commit()
+
+ def _get_running_ids(self) -> List[int]:
+ with db.session_scope() as session:
+ running_ids = session.query(Cleanup.id).filter(Cleanup.state == CleanupState.RUNNING).all()
+ logging.info(f'Collected running cleanup ids: {running_ids}')
+ # unwrap query result
+ return [cleanup_id for cleanup_id, *_ in running_ids]
+
+ def _sweep_running_cleanups(self, running_list: List[int]) -> Tuple[List[int], List[int]]:
+ logging.info(f'will sweep the cleanup ids:{running_list}')
+ succeeded_cleanup_ids = []
+ failed_cleanup_ids = []
+ for cleanup_id in running_list:
+ with db.session_scope() as session:
+ logging.info(f'will sweep the running cleanup:{cleanup_id}')
+ cleanup = session.query(Cleanup).populate_existing().with_for_update().get(cleanup_id)
+ try:
+ if cleanup and cleanup.state == CleanupState.RUNNING:
+ self._execute_cleanup(cleanup)
+ # Release the lock
+ session.commit()
+ succeeded_cleanup_ids.append(cleanup.id)
+ else:
+ logging.warning(f'While sweeping the running cleanup list, \
+ cleanup:{cleanup_id} was changed or canceled concurrently; it has been skipped.')
+ # Release the lock
+ session.rollback()
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'Cleanup:{cleanup.id} failed, error_msg: {str(e)}')
+ cleanup.state = CleanupState.FAILED
+ session.commit()
+ failed_cleanup_ids.append(cleanup.id)
+ return succeeded_cleanup_ids, failed_cleanup_ids
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
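+ # One pass: promote due WAITING cleanups to RUNNING, then execute every RUNNING cleanup and report successes and failures.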
+ try:
+ waiting_ids = self._get_waiting_ids()
+ self._sweep_waiting_cleanups(waiting_ids)
+ running_ids = self._get_running_ids()
+ succeeded_cleanup_ids, failed_cleanup_ids = self._sweep_running_cleanups(running_ids)
+ return RunnerStatus.DONE, RunnerOutput(cleanup_cron_job_output=CleanupCronJobOutput(
+ succeeded_cleanup_ids=succeeded_cleanup_ids, failed_cleanup_ids=failed_cleanup_ids))
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'Cleanup cronjob failed, error_msg: {str(e)}')
+ return RunnerStatus.FAILED, RunnerOutput(error_message=str(e))
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/cleaner_cronjob_test.py b/web_console_v2/api/fedlearner_webconsole/cleanup/cleaner_cronjob_test.py
new file mode 100644
index 000000000..d99f8c91b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/cleaner_cronjob_test.py
@@ -0,0 +1,93 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import timezone, datetime
+from unittest.mock import MagicMock, patch
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.fake_time_patcher import FakeTimePatcher
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput
+
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.cleanup.models import Cleanup, CleanupState, ResourceType
+from fedlearner_webconsole.cleanup.cleaner_cronjob import CleanupCronJob
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupPayload
+
+
+@patch('fedlearner_webconsole.utils.file_manager.FileManager.exists')
+@patch('fedlearner_webconsole.utils.file_manager.FileManager.remove')
+class CleanupCronJobTest(NoWebServerTestCase):
+ _CLEANUP_ID = 1
+
+ def setUp(self):
+ super().setUp()
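+ # Freeze the clock in 2012 so the 1999 target_start_at below is always due when the cronjob runs.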
+ self.time_patcher = FakeTimePatcher()
+ self.time_patcher.start(datetime(2012, 1, 14, 12, 0, 5, tzinfo=timezone.utc))
+ self.default_payload = CleanupPayload(paths=['/Major333/test_path/a.csv'])
+ with db.session_scope() as session:
+ self.default_cleanup = Cleanup(id=1,
+ state=CleanupState.WAITING,
+ target_start_at=datetime(1999, 3, 1, tzinfo=timezone.utc),
+ resource_id=1,
+ resource_type=ResourceType(Dataset).name,
+ payload=self.default_payload)
+ session.add(self.default_cleanup)
+ session.commit()
+
+ def tearDown(self):
+ self.time_patcher.stop()
+ super().tearDown()
+
+ def test_run_failed_alone(self, mock_remove: MagicMock, mock_exists: MagicMock):
+ # The file always exists
+ mock_exists.return_value = True
+ # Make removal fail
+ mock_remove.side_effect = RuntimeError('fake error')
+
+ runner = CleanupCronJob()
+ runner_input = RunnerInput()
+ runner_context = RunnerContext(index=0, input=runner_input)
+
+ status, output = runner.run(runner_context)
+ self.assertEqual(status, RunnerStatus.DONE)
+ expected_cleanup_status = CleanupState.FAILED
+ with db.session_scope() as session:
+ cleanup = session.query(Cleanup).get(1)
+ self.assertEqual(expected_cleanup_status, cleanup.state)
+
+ def test_run_success_alone(self, mock_remove: MagicMock, mock_exists: MagicMock):
+ # The file always exists
+ mock_exists.return_value = True
+ # Make removal succeed
+ mock_remove.reset_mock(side_effect=True)
+
+ runner = CleanupCronJob()
+ runner_input = RunnerInput()
+ runner_context = RunnerContext(index=0, input=runner_input)
+
+ status, output = runner.run(runner_context)
+ self.assertEqual(status, RunnerStatus.DONE)
+ expected_cleanup_status = CleanupState.SUCCEEDED
+ with db.session_scope() as session:
+ cleanup = session.query(Cleanup).get(1)
+ self.assertEqual(expected_cleanup_status, cleanup.state)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/models.py b/web_console_v2/api/fedlearner_webconsole/cleanup/models.py
new file mode 100644
index 000000000..a11d73dd5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/models.py
@@ -0,0 +1,88 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+
+from sqlalchemy.sql import func
+from google.protobuf import text_format
+
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupPayload, CleanupPb
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobStage
+from fedlearner_webconsole.mmgr.models import Model
+from fedlearner_webconsole.algorithm.models import Algorithm
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+
+
+# Centralized registration of the resource kinds a cleanup can target; each member maps to its
+# model class, so callers can tag a Cleanup with e.g. ResourceType(Dataset).name.
+class ResourceType(enum.Enum):
+ DATASET = Dataset
+ DATASET_JOB = DatasetJob
+ DATASET_JOB_STAGE = DatasetJobStage
+ MODEL = Model
+ ALGORITHM = Algorithm
+ NO_RESOURCE = None
+
+
+class CleanupState(enum.Enum):
+ WAITING = 'WAITING'
+ RUNNING = 'RUNNING'
+ SUCCEEDED = 'SUCCEEDED'
+ FAILED = 'FAILED'
+ CANCELED = 'CANCELED'
+
+
+class Cleanup(db.Model):
+ __tablename__ = 'cleanups_v2'
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ state = db.Column(db.Enum(CleanupState, native_enum=False, length=64, create_constraint=False),
+ default=CleanupState.WAITING,
+ comment='state')
+ target_start_at = db.Column(db.DateTime(timezone=True), comment='target_start_at')
+ completed_at = db.Column(db.DateTime(timezone=True), comment='completed_at')
+ resource_id = db.Column(db.Integer, comment='resource_id')
+ resource_type = db.Column(db.Enum(ResourceType, native_enum=False, length=64, create_constraint=False),
+ comment='resource_type')
+ _payload = db.Column(db.Text(), name='payload', comment='the underlying resources that need to be cleaned up')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created_at')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ server_default=func.now(),
+ server_onupdate=func.now(),
+ comment='updated_at')
+
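+ # The payload is persisted in the 'payload' text column as protobuf text format and surfaced as a CleanupPayload message.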
+ @property
+ def payload(self) -> CleanupPayload:
+ if not self._payload:
+ return CleanupPayload()
+ return text_format.Parse(self._payload, CleanupPayload())
+
+ @payload.setter
+ def payload(self, payload: CleanupPayload):
+ self._payload = text_format.MessageToString(payload)
+
+ @property
+ def is_cancellable(self):
+ return self.state in [CleanupState.CANCELED, CleanupState.WAITING]
+
+ def to_proto(self) -> CleanupPb:
+ return CleanupPb(id=self.id,
+ state=self.state.name,
+ target_start_at=to_timestamp(self.target_start_at),
+ completed_at=to_timestamp(self.completed_at) if self.completed_at else None,
+ resource_id=self.resource_id,
+ resource_type=self.resource_type.name,
+ payload=self.payload,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at))
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/models_test.py b/web_console_v2/api/fedlearner_webconsole/cleanup/models_test.py
new file mode 100644
index 000000000..86049a4a3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/models_test.py
@@ -0,0 +1,95 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.cleanup.models import Cleanup, CleanupState, ResourceType
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupPayload, CleanupPb
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+
+
+class CleanupTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.default_payload = CleanupPayload(paths=['/Major333/test_path/a.csv'])
+ default_cleanup = Cleanup(id=1,
+ state=CleanupState.WAITING,
+ created_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ updated_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ target_start_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ resource_id=100,
+ resource_type=ResourceType.DATASET.name,
+ payload=self.default_payload)
+ cleanup_without_resource = Cleanup(id=2,
+ state=CleanupState.WAITING,
+ created_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ updated_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ target_start_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ resource_id=100,
+ resource_type=ResourceType.NO_RESOURCE.name,
+ payload=self.default_payload)
+ running_cleanup = Cleanup(id=3,
+ state=CleanupState.RUNNING,
+ created_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ updated_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ target_start_at=datetime(2022, 2, 22, tzinfo=timezone.utc),
+ resource_id=100,
+ resource_type=ResourceType.NO_RESOURCE.name,
+ payload=self.default_payload)
+ with db.session_scope() as session:
+ session.add(default_cleanup)
+ session.add(cleanup_without_resource)
+ session.add(running_cleanup)
+ session.commit()
+
+ def test_payload(self):
+ with db.session_scope() as session:
+ cleanup = session.query(Cleanup).get(1)
+ self.assertEqual(cleanup.payload, self.default_payload)
+ cleanup.payload = CleanupPayload(paths=['/Major333/test_path/b.csv'])
+ session.add(cleanup)
+ session.commit()
+ with db.session_scope() as session:
+ cleanup = session.query(Cleanup).get(1)
+ self.assertEqual(['/Major333/test_path/b.csv'], cleanup.payload.paths)
+
+ def test_cancellable(self):
+ with db.session_scope() as session:
+ default_cleanup = session.query(Cleanup).get(1)
+ cleanup_without_resource = session.query(Cleanup).get(2)
+ running_cleanup = session.query(Cleanup).get(3)
+ self.assertTrue(default_cleanup.is_cancellable)
+ self.assertTrue(cleanup_without_resource.is_cancellable)
+ self.assertFalse(running_cleanup.is_cancellable)
+
+ def test_to_proto(self):
+ expected_cleanup_proto = CleanupPb(id=1,
+ state='WAITING',
+ target_start_at=to_timestamp(datetime(2022, 2, 22, tzinfo=timezone.utc)),
+ resource_id=100,
+ resource_type='DATASET',
+ payload=self.default_payload,
+ updated_at=to_timestamp(datetime(2022, 2, 22, tzinfo=timezone.utc)),
+ created_at=to_timestamp(datetime(2022, 2, 22, tzinfo=timezone.utc)))
+ with db.session_scope() as session:
+ cleanup_proto = session.query(Cleanup).get(1).to_proto()
+ self.assertEqual(cleanup_proto, expected_cleanup_proto)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/services.py b/web_console_v2/api/fedlearner_webconsole/cleanup/services.py
new file mode 100644
index 000000000..57124ad6e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/services.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Optional
+from datetime import datetime, timezone
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.exceptions import InvalidArgumentException, NotFoundException
+from fedlearner_webconsole.cleanup.models import Cleanup, CleanupState
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupParameter, CleanupPb
+from fedlearner_webconsole.proto.filtering_pb2 import FilterOp, FilterExpression
+from fedlearner_webconsole.utils.paginate import Pagination, paginate
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder
+
+
+class CleanupService():
+
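+ # Fields the list API's filter expression may reference; each currently supports only equality.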
+ FILTER_FIELDS = {
+ 'state': SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'resource_type': SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'resource_id': SupportedField(type=FieldType.NUMBER, ops={FilterOp.EQUAL: None}),
+ }
+
+ def __init__(self, session: Session):
+ self._session = session
+ self._filter_builder = FilterBuilder(model_class=Cleanup, supported_fields=self.FILTER_FIELDS)
+
+ def get_cleanups(self,
+ page: Optional[int] = None,
+ page_size: Optional[int] = None,
+ filter_exp: Optional[FilterExpression] = None) -> Pagination:
+ query = self._session.query(Cleanup)
+ if filter_exp:
+ query = self._filter_builder.build_query(query, filter_exp)
+ query = query.order_by(Cleanup.created_at.desc())
+ return paginate(query, page, page_size)
+
+ def get_cleanup(self, cleanup_id: int = 0) -> CleanupPb:
+ cleanup = self._session.query(Cleanup).get(cleanup_id)
+ if not cleanup:
+ raise NotFoundException(f'Failed to find cleanup: {cleanup_id}')
+ return cleanup.to_proto()
+
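+ # target_start_at in the parameter is a unix timestamp (seconds); store it as a timezone-aware UTC datetime.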
+ def create_cleanup(self, cleanup_parameter: CleanupParameter) -> Cleanup:
+ cleanup = Cleanup(
+ state=CleanupState.WAITING,
+ resource_id=cleanup_parameter.resource_id,
+ resource_type=cleanup_parameter.resource_type,
+ target_start_at=datetime.fromtimestamp(cleanup_parameter.target_start_at, tz=timezone.utc),
+ payload=cleanup_parameter.payload,
+ )
+ self._session.add(cleanup)
+ return cleanup
+
+ def _cancel_cleanup(self, cleanup: Cleanup) -> CleanupPb:
+ if not cleanup.is_cancellable:
+ error_msg = f'cleanup: {cleanup.id} can not be canceled'
+ logging.error(error_msg)
+ raise InvalidArgumentException(error_msg)
+ cleanup.state = CleanupState.CANCELED
+ return cleanup.to_proto()
+
+ def cancel_cleanup_by_id(self, cleanup_id: int = 0) -> CleanupPb:
+ # apply exclusive lock on cleanup to avoid race condition on updating its state
+ cleanup = self._session.query(Cleanup).populate_existing().with_for_update().filter(
+ Cleanup.id == cleanup_id).first()
+ if not cleanup:
+ error_msg = f'there is no cleanup with cleanup_id:{cleanup_id}'
+ logging.error(error_msg)
+ raise InvalidArgumentException(error_msg)
+ return self._cancel_cleanup(cleanup)
diff --git a/web_console_v2/api/fedlearner_webconsole/cleanup/services_test.py b/web_console_v2/api/fedlearner_webconsole/cleanup/services_test.py
new file mode 100644
index 000000000..516d99c34
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/cleanup/services_test.py
@@ -0,0 +1,135 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.cleanup.services import CleanupService
+from fedlearner_webconsole.cleanup.models import Cleanup, CleanupState, ResourceType
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupPayload, CleanupParameter, CleanupPb
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterExpressionKind, SimpleExpression, FilterOp
+from fedlearner_webconsole.dataset.models import Dataset, DatasetType, DatasetKindV2
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+
+
+class CleanupServiceTest(NoWebServerTestCase):
+ _TARGET_START_AT = datetime(2022, 2, 22, 10, 10, 12, tzinfo=timezone.utc)
+ _CREATED_AT = datetime(2022, 2, 22, 3, 3, 4, tzinfo=timezone.utc)
+
+ def setUp(self):
+ super().setUp()
+ self.default_payload = CleanupPayload(paths=['/Major333/test_path/a.csv'])
+ self.default_dataset = Dataset(
+ id=100,
+ uuid=resource_uuid(),
+ name='dataset_1',
+ dataset_type=DatasetType.PSI,
+ project_id=100,
+ dataset_kind=DatasetKindV2.RAW,
+ path='/data/dataset_1/',
+ )
+ default_cleanup = Cleanup(id=1,
+ state=CleanupState.WAITING,
+ created_at=self._CREATED_AT,
+ updated_at=self._CREATED_AT,
+ target_start_at=self._TARGET_START_AT,
+ resource_id=100,
+ resource_type=ResourceType.DATASET.name,
+ payload=self.default_payload)
+ default_cleanup_2 = Cleanup(id=2,
+ state=CleanupState.WAITING,
+ created_at=self._CREATED_AT,
+ updated_at=self._CREATED_AT,
+ target_start_at=self._TARGET_START_AT,
+ resource_id=100,
+ resource_type=ResourceType.NO_RESOURCE.name,
+ payload=self.default_payload)
+ with db.session_scope() as session:
+ session.add(self.default_dataset)
+ session.add(default_cleanup)
+ session.add(default_cleanup_2)
+ session.commit()
+
+ def test_get_cleanups(self):
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='state',
+ op=FilterOp.EQUAL,
+ string_value='WAITING',
+ ))
+ with db.session_scope() as session:
+ service = CleanupService(session)
+ paginations = service.get_cleanups(filter_exp=filter_exp, page=1, page_size=2)
+ cleanup_ids = [cleanup.id for cleanup in paginations.get_items()]
+ self.assertEqual(cleanup_ids, [1, 2])
+
+ def test_get_cleanup(self):
+ expected_cleanup_display_proto = CleanupPb(id=1,
+ state='WAITING',
+ target_start_at=to_timestamp(self._TARGET_START_AT),
+ resource_id=100,
+ resource_type='DATASET',
+ payload=self.default_payload,
+ updated_at=to_timestamp(self._CREATED_AT),
+ created_at=to_timestamp(self._CREATED_AT))
+ with db.session_scope() as session:
+ service = CleanupService(session)
+ cleanup = service.get_cleanup(cleanup_id=1)
+ self.assertEqual(cleanup, expected_cleanup_display_proto)
+
+ def test_get_cleanup_without_resource_type(self):
+ expected_cleanup_display_proto = CleanupPb(id=2,
+ state='WAITING',
+ target_start_at=to_timestamp(self._TARGET_START_AT),
+ resource_id=100,
+ resource_type='NO_RESOURCE',
+ payload=self.default_payload,
+ updated_at=to_timestamp(self._CREATED_AT),
+ created_at=to_timestamp(self._CREATED_AT))
+ with db.session_scope() as session:
+ service = CleanupService(session)
+ cleanup = service.get_cleanup(cleanup_id=2)
+ self.assertEqual(cleanup, expected_cleanup_display_proto)
+
+ def test_create_cleanup(self):
+ cleanup_parm = CleanupParameter(resource_id=1011,
+ resource_type='DATASET',
+ target_start_at=to_timestamp(self._TARGET_START_AT),
+ payload=self.default_payload)
+ with db.session_scope() as session:
+ cleanup = CleanupService(session).create_cleanup(cleanup_parm)
+ session.commit()
+ cleanup_id = cleanup.id
+ with db.session_scope() as session:
+ created_cleanup: Cleanup = session.query(Cleanup).get(cleanup_id)
+ self.assertEqual(created_cleanup.resource_type, ResourceType.DATASET)
+ self.assertEqual(created_cleanup.resource_id, 1011)
+ self.assertEqual(to_timestamp(created_cleanup.target_start_at), to_timestamp(self._TARGET_START_AT))
+ self.assertEqual(created_cleanup.payload, self.default_payload)
+
+ def test_cancel_cleanup_by_id(self):
+ with db.session_scope() as session:
+ service = CleanupService(session)
+ service.cancel_cleanup_by_id(1)
+ session.commit()
+ with db.session_scope() as session:
+ cleanup = session.query(Cleanup).get(1)
+ self.assertEqual(cleanup.state, CleanupState.CANCELED)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/composer/BUILD.bazel
new file mode 100644
index 000000000..460a2c722
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/BUILD.bazel
@@ -0,0 +1,253 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "composer_service_lib",
+ srcs = ["composer_service.py"],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_croniter//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "composer_service_lib_test",
+ srcs = [
+ "composer_service_test.py",
+ ],
+ imports = ["../.."],
+ main = "composer_service_test.py",
+ deps = [
+ ":common_lib",
+ ":composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "common_lib",
+ srcs = [
+ "context.py",
+ "interface.py",
+ "models.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:mixins_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_croniter//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "context_test",
+ srcs = [
+ "context_test.py",
+ ],
+ imports = ["../.."],
+ main = "context_test.py",
+ deps = [
+ ":common_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "models_test",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":common_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "runner_lib",
+ srcs = [
+ "runner.py",
+ ],
+ imports = ["../.."],
+ visibility = ["//visibility:private"],
+ deps = [
+ ":common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/cleanup:cleaner_cronjob_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:batch_stats_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/scheduler",
+ "//web_console_v2/api/fedlearner_webconsole/job:scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:cronjob_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:project_scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/serving:runners_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:runners_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:cronjob_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_scheduler_lib",
+ ],
+)
+
+py_library(
+ name = "composer_lib",
+ srcs = [
+ "composer.py",
+ "context.py",
+ "op_locker.py",
+ "pipeline.py",
+ "strategy.py",
+ "thread_reaper.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ ":runner_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_time_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio_health_checking/grpc_health/v1:grpc_health",
+ "@com_github_grpc_grpc//src/python/grpcio_health_checking/grpc_health/v1:health_py_pb2",
+ "@com_github_grpc_grpc//src/python/grpcio_health_checking/grpc_health/v1:health_py_pb2_grpc",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "composer_test",
+ srcs = [
+ "composer_test.py",
+ ],
+ imports = ["../.."],
+ main = "composer_test.py",
+ deps = [
+ ":composer_lib",
+ ":composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/composer",
+ ],
+)
+
+py_test(
+ name = "op_locker_test",
+ srcs = [
+ "op_locker_test.py",
+ ],
+ imports = ["../.."],
+ main = "op_locker_test.py",
+ deps = [
+ ":common_lib",
+ ":composer_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "pipeline_test",
+ srcs = [
+ "pipeline_test.py",
+ ],
+ imports = ["../.."],
+ main = "pipeline_test.py",
+ deps = [
+ ":common_lib",
+ ":composer_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/composer",
+ ],
+)
+
+py_test(
+ name = "strategy_test",
+ srcs = [
+ "strategy_test.py",
+ ],
+ imports = ["../.."],
+ main = "strategy_test.py",
+ deps = [
+ ":common_lib",
+ ":composer_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/composer",
+ ],
+)
+
+py_test(
+ name = "thread_reaper_test",
+ srcs = [
+ "thread_reaper_test.py",
+ ],
+ # It's unpredictable whether the first runner keeps running when the same runner is submitted again.
+ # Ref: web_console_v2/api/fedlearner_webconsole/composer/thread_reaper_test.py:ThreadReaperTest.test_submit
+ flaky = True,
+ imports = ["../.."],
+ main = "thread_reaper_test.py",
+ deps = [
+ ":common_lib",
+ ":composer_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/composer",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ ":composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/apis.py b/web_console_v2/api/fedlearner_webconsole/composer/apis.py
new file mode 100644
index 000000000..cb7aec460
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/apis.py
@@ -0,0 +1,246 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from flask_restful import Resource, Api
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.models import ItemStatus, SchedulerItem
+from fedlearner_webconsole.composer.composer_service import SchedulerItemService, SchedulerRunnerService
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required
+from fedlearner_webconsole.utils.flask_utils import make_flask_response, FilterExpField
+from fedlearner_webconsole.exceptions import InvalidArgumentException, NotFoundException
+from webargs.flaskparser import use_kwargs, use_args
+from marshmallow import Schema, fields, validate
+
+
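+# Query-string schemas for the list endpoints; `filter` accepts the console's filter-expression
+# syntax, e.g. (and(is_cron=true)(status="OFF")), as exercised in apis_test.py.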
+class ListSchedulerItemsParams(Schema):
+ filter = FilterExpField(required=False, load_default=None)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
+
+
+class ListSchedulerRunnersParams(Schema):
+ filter = FilterExpField(required=False, load_default=None)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
+
+
+class SchedulerItemsApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_args(ListSchedulerItemsParams(), location='query')
+ def get(self, params: dict):
+ """Get a list of all scheduler items.
+ ---
+ tags:
+ - composer
+ description: Get a list of all scheduler items.
+ parameters:
+ - in: query
+ name: filter
+ schema:
+ type: string
+ required: false
+ - in: query
+ name: page
+ schema:
+ type: integer
+ required: false
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ required: false
+ responses:
+ 200:
+ description: Get a list of all scheduler items result
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SchedulerItemPb'
+ """
+ with db.session_scope() as session:
+ try:
+ pagination = SchedulerItemService(session).get_scheduler_items(
+ filter_exp=params['filter'],
+ page=params['page'],
+ page_size=params['page_size'],
+ )
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ data = [t.to_proto() for t in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+
+class SchedulerItemApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_args(ListSchedulerRunnersParams(), location='query')
+ def get(self, params: dict, item_id: int):
+ """Get all scheduler runners by item_id
+ ---
+ tags:
+ - composer
+ description: Get all scheduler runners by item_id
+ parameters:
+ - in: path
+ name: item_id
+ schema:
+ type: integer
+ required: true
+ description: The ID of the scheduler item.
+ - in: query
+ name: filter
+ schema:
+ type: string
+ required: false
+ - in: query
+ name: page
+ schema:
+ type: integer
+ required: false
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ required: false
+ responses:
+ 200:
+ description: Get all scheduler runners by item_id
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SchedulerRunnerPb'
+ """
+ with db.session_scope() as session:
+ try:
+ pagination = SchedulerRunnerService(session).get_scheduler_runners(filter_exp=params['filter'],
+ item_id=item_id,
+ page=params['page'],
+ page_size=params['page_size'])
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ data = [t.to_proto() for t in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+ @credentials_required
+ @admin_required
+ @use_kwargs(
+ {'status': fields.Str(required=True, validate=validate.OneOf([ItemStatus.ON.name, ItemStatus.OFF.name]))},
+ location='json')
+ def patch(self, item_id: int, status: str):
+ """Change status of a scheduler item
+ ---
+ tags:
+ - composer
+ description: change SchedulerItem status
+ parameters:
+ - in: path
+ name: item_id
+ required: true
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ status:
+ type: string
+ responses:
+ 200:
+ description: The SchedulerItem's status has been updated
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SchedulerItemPb'
+ 400:
+ description: The param of state in the request body is invalid
+ 404:
+ description: The scheduleritem with specified ID is not found
+ """
+ with db.session_scope() as session:
+ scheduler_item = session.query(SchedulerItem).filter_by(id=item_id).first()
+ if not scheduler_item:
+ raise NotFoundException(f'Failed to find scheduler_item: {item_id}')
+ try:
+ scheduler_item.status = ItemStatus[status].value
+ session.commit()
+ except ValueError as e:
+ raise InvalidArgumentException(f'Invalid argument for Status: {status}') from e
+ return make_flask_response(scheduler_item.to_proto())
+
+
+class SchedulerRunnersApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_args(ListSchedulerRunnersParams(), location='query')
+ def get(self, params: dict):
+ """Get a list of all scheduler runners
+ ---
+ tags:
+ - composer
+ description: get scheduler runners list
+ parameters:
+ - in: query
+ name: filter
+ schema:
+ type: string
+ required: false
+ - in: query
+ name: page
+ schema:
+ type: integer
+ required: false
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ required: false
+ responses:
+ 200:
+ description: Get scheduler runners list result
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SchedulerRunnerPb'
+ """
+ with db.session_scope() as session:
+ try:
+ pagination = SchedulerRunnerService(session).get_scheduler_runners(filter_exp=params['filter'],
+ page=params['page'],
+ page_size=params['page_size'])
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ data = [t.to_proto() for t in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+
+def initialize_composer_apis(api: Api):
+ api.add_resource(SchedulerItemsApi, '/scheduler_items')
+ api.add_resource(SchedulerItemApi, '/scheduler_items/<int:item_id>')
+ api.add_resource(SchedulerRunnersApi, '/scheduler_runners')
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/apis_test.py b/web_console_v2/api/fedlearner_webconsole/composer/apis_test.py
new file mode 100644
index 000000000..965534e25
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/apis_test.py
@@ -0,0 +1,236 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import urllib.parse
+from http import HTTPStatus
+from datetime import datetime
+from testing.common import BaseTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.models import SchedulerItem, SchedulerRunner, ItemStatus, RunnerStatus
+import json
+from fedlearner_webconsole.utils.proto import to_json
+from fedlearner_webconsole.proto import composer_pb2
+
+
+class SchedulerItemsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ # ids from 1001 to 1004 work as cron_job with status "ON"
+ scheduler_item_off = SchedulerItem(id=1001,
+ name='test_item_off',
+ status=ItemStatus.OFF.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ scheduler_item_on = SchedulerItem(id=1002,
+ name='test_item_on',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ scheduler_item_on_cron = SchedulerItem(id=1003,
+ name='test_item_on_cron',
+ cron_config='*/20 * * * *',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ scheduler_item_off_cron = SchedulerItem(id=1004,
+ name='test_item_off_cron',
+ cron_config='*/20 * * * *',
+ status=ItemStatus.OFF.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ scheduler_item_for_id_test = SchedulerItem(id=201,
+ name='scheduler_item_for_id_test',
+ cron_config='*/20 * * * *',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+
+ with db.session_scope() as session:
+ session.add(scheduler_item_on)
+ session.add(scheduler_item_off)
+ session.add(scheduler_item_on_cron)
+ session.add(scheduler_item_off_cron)
+ session.add(scheduler_item_for_id_test)
+ session.commit()
+ self.signin_as_admin()
+
+ def test_get_with_pagination(self):
+ response = self.get_helper('/api/v2/scheduler_items?page=1&page_size=1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['id'], 1004)
+
+ def test_get_with_invalid_filter(self):
+ response = self.get_helper('/api/v2/scheduler_items?filter=invalid')
+ self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
+
+ def test_get_with_id_filter(self):
+ filter_exp = urllib.parse.quote('(id=201)')
+ response = self.get_helper(f'/api/v2/scheduler_items?filter={filter_exp}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'scheduler_item_for_id_test')
+
+ def test_get_with_three_filter(self):
+ filter_exp = urllib.parse.quote('(and(is_cron=true)(status="OFF")(name~="item_off_cron"))')
+ response = self.get_helper(f'/api/v2/scheduler_items?filter={filter_exp}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['status'], ItemStatus.OFF.name)
+ self.assertNotEqual(data[0]['cron_config'], '')
+
+ def test_get_with_single_filter(self):
+ filter_exp = urllib.parse.quote('(status="OFF")')
+ response = self.get_helper(f'/api/v2/scheduler_items?filter={filter_exp}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['status'], ItemStatus.OFF.name)
+
+
+class SchedulerItemApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ scheduler_item_on = SchedulerItem(id=100,
+ name='test_item_on',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ self.default_scheduler_item = scheduler_item_on
+ self.default_context = json.dumps({'outputs': {'0': {'job_scheduler_output': {}}}})
+ self.default_output = to_json(composer_pb2.RunnerOutput(error_message='error1'))
+ runner_init = SchedulerRunner(id=0,
+ item_id=100,
+ status=RunnerStatus.INIT.value,
+ context=self.default_context,
+ output=self.default_output)
+ runner_running_1 = SchedulerRunner(id=1,
+ item_id=100,
+ status=RunnerStatus.RUNNING.value,
+ context=self.default_context,
+ output=self.default_output)
+ runner_running_2 = SchedulerRunner(id=2,
+ item_id=100,
+ status=RunnerStatus.RUNNING.value,
+ context=self.default_context,
+ output=self.default_output)
+
+ with db.session_scope() as session:
+ session.add(scheduler_item_on)
+ session.add(runner_init)
+ session.add(runner_running_1)
+ session.add(runner_running_2)
+ session.commit()
+ self.signin_as_admin()
+
+ def test_get_runners_without_pagination(self):
+ response = self.get_helper(f'/api/v2/scheduler_items/{self.default_scheduler_item.id}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 3)
+
+ def test_get_with_pagination(self):
+ response = self.get_helper(f'/api/v2/scheduler_items/{self.default_scheduler_item.id}?page=1&page_size=1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+
+ def test_change_scheduleritem_status(self):
+ get_response = self.patch_helper(f'/api/v2/scheduler_items/{self.default_scheduler_item.id}',
+ data={'status': 'OFF'})
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ scheduler_item = session.query(SchedulerItem).get(self.default_scheduler_item.id)
+ self.assertEqual(scheduler_item.status, ItemStatus.OFF.value)
+
+ def test_change_scheduleritem_status_with_invalid_argument(self):
+ get_response = self.patch_helper(f'/api/v2/scheduler_items/{self.default_scheduler_item.id}',
+ data={'status': 'RUNNING'})
+ self.assertEqual(get_response.status_code, HTTPStatus.BAD_REQUEST)
+ with db.session_scope() as session:
+ scheduler_item = session.query(SchedulerItem).get(self.default_scheduler_item.id)
+ self.assertEqual(scheduler_item.status, ItemStatus.ON.value)
+
+
+class SchedulerRunnersApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ scheduler_item_on = SchedulerItem(id=100,
+ name='test_item_on',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ self.default_scheduler_item = scheduler_item_on
+ self.default_context = json.dumps({'outputs': {'0': {'job_scheduler_output': {}}}})
+ self.default_output = to_json(composer_pb2.RunnerOutput(error_message='error1'))
+ runner_init = SchedulerRunner(id=0,
+ item_id=100,
+ status=RunnerStatus.INIT.value,
+ context=self.default_context,
+ output=self.default_output)
+ runner_running = SchedulerRunner(id=1,
+ item_id=100,
+ status=RunnerStatus.RUNNING.value,
+ context=self.default_context)
+ runner_running_2 = SchedulerRunner(id=2,
+ item_id=100,
+ status=RunnerStatus.RUNNING.value,
+ context=self.default_context)
+ runner_done = SchedulerRunner(id=3, item_id=100, status=RunnerStatus.DONE.value)
+
+ with db.session_scope() as session:
+ session.add(scheduler_item_on)
+ session.add(runner_init)
+ session.add(runner_running)
+ session.add(runner_running_2)
+ session.add(runner_done)
+ session.commit()
+ self.signin_as_admin()
+
+ def test_get_without_filter_and_pagination(self):
+ response = self.get_helper('/api/v2/scheduler_runners')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 4)
+
+ def test_get_with_pagination(self):
+ response = self.get_helper('/api/v2/scheduler_runners?page=1&page_size=1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+
+ def test_get_with_invalid_filter(self):
+ response = self.get_helper('/api/v2/scheduler_runners?filter=invalid')
+ self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
+
+ def test_get_with_filter(self):
+ filter_exp = urllib.parse.quote('(status="RUNNING")')
+ response = self.get_helper(f'/api/v2/scheduler_runners?filter={filter_exp}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['status'], RunnerStatus.RUNNING.name)
+ self.assertEqual(data[1]['status'], RunnerStatus.RUNNING.name)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/composer.py b/web_console_v2/api/fedlearner_webconsole/composer/composer.py
index e0040ba99..e23a0eec1 100644
--- a/web_console_v2/api/fedlearner_webconsole/composer/composer.py
+++ b/web_console_v2/api/fedlearner_webconsole/composer/composer.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,69 +13,53 @@
# limitations under the License.
# coding: utf-8
-
-import json
import logging
-import time
import threading
import traceback
from datetime import datetime
-from typing import List, Optional
+from envs import Envs
from sqlalchemy import func
from sqlalchemy.engine import Engine
+from fedlearner_webconsole.composer.strategy import SingletonStrategy
+from fedlearner_webconsole.proto import composer_pb2
+from fedlearner_webconsole.proto.composer_pb2 import PipelineContextData
+from fedlearner_webconsole.utils import pp_time
from fedlearner_webconsole.db import get_session
from fedlearner_webconsole.composer.runner import global_runner_fn
-from fedlearner_webconsole.composer.runner_cache import RunnerCache
-from fedlearner_webconsole.composer.interface import IItem
-from fedlearner_webconsole.composer.models import Context, decode_context, \
- ContextEncoder, SchedulerItem, ItemStatus, SchedulerRunner, RunnerStatus
+from fedlearner_webconsole.composer.models import SchedulerItem, ItemStatus, SchedulerRunner, RunnerStatus
+from fedlearner_webconsole.composer.pipeline import PipelineExecutor
from fedlearner_webconsole.composer.op_locker import OpLocker
from fedlearner_webconsole.composer.thread_reaper import ThreadReaper
+import grpc
+from concurrent import futures
+from grpc_health.v1 import health
+from grpc_health.v1 import health_pb2_grpc
class ComposerConfig(object):
+
def __init__(
self,
runner_fn: dict,
- name='default_name',
- worker_num=10,
+ name: str = 'default_name',
+ worker_num: int = 20,
):
"""Config for composer
Args:
- runner_fn: runner functions
- name: composer name
- worker_num: number of worker doing heavy job
+ runner_fn (dict): runner functions
+ name (str): composer name
+ worker_num (int): number of worker doing heavy job
"""
self.runner_fn = runner_fn
self.name = name
self.worker_num = worker_num
-class Pipeline(object):
- def __init__(self, name: str, deps: List[str], meta: dict):
- """Define the deps of scheduler item
-
- Fields:
- name: pipeline name
- deps: items to be processed in order
- meta: additional info
- """
- self.name = name
- self.deps = deps
- self.meta = meta
-
-
-class PipelineEncoder(json.JSONEncoder):
- def default(self, obj):
- return obj.__dict__
-
-
class Composer(object):
- # attributes that you can patch
- MUTABLE_ITEM_KEY = ['interval_time', 'retry_cnt']
+ LOOP_INTERVAL = 5
def __init__(self, config: ComposerConfig):
"""Composer
@@ -85,18 +69,37 @@ def __init__(self, config: ComposerConfig):
"""
self.config = config
self.name = config.name
- self.runner_fn = config.runner_fn
self.db_engine = None
self.thread_reaper = ThreadReaper(worker_num=config.worker_num)
- self.runner_cache = RunnerCache(runner_fn=config.runner_fn)
+ self.pipeline_executor = PipelineExecutor(
+ thread_reaper=self.thread_reaper,
+ db_engine=self.db_engine,
+ runner_fns=config.runner_fn,
+ )
self.lock = threading.Lock()
self._stop = False
+ self._loop_thread = None
+ self._grpc_server_thread = None
def run(self, db_engine: Engine):
self.db_engine = db_engine
+ self.pipeline_executor.db_engine = db_engine
logging.info(f'[composer] starting {self.name}...')
- loop = threading.Thread(target=self._loop, args=[], daemon=True)
- loop.start()
+ self._loop_thread = threading.Thread(target=self._loop, args=[], daemon=True)
+ self._loop_thread.start()
+ self._grpc_server_thread = threading.Thread(target=self._run, args=[], daemon=True)
+ self._grpc_server_thread.start()
+
+ def wait_for_termination(self):
+ self._loop_thread.join()
+ self._grpc_server_thread.join()
+
+ def _run(self):
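+ # Serve only the standard gRPC health-checking service on Envs.COMPOSER_LISTEN_PORT, so external probes (e.g. a liveness check) can verify the composer process is alive.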
+ grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
+ health_pb2_grpc.add_HealthServicer_to_server(health.HealthServicer(), grpc_server)
+ grpc_server.add_insecure_port(f'[::]:{Envs.COMPOSER_LISTEN_PORT}')
+ grpc_server.start()
+ grpc_server.wait_for_termination()
def _loop(self):
while True:
@@ -111,300 +114,106 @@ def _loop(self):
self._check_init_runners()
self._check_running_runners()
except Exception as e: # pylint: disable=broad-except
- logging.error(f'[composer] something wrong, exception: {e}, '
- f'trace: {traceback.format_exc()}')
- time.sleep(5)
+ logging.error(f'[composer] something went wrong, exception: {e}, trace: {traceback.format_exc()}')
+ pp_time.sleep(self.LOOP_INTERVAL)
def stop(self):
logging.info(f'[composer] stopping {self.name}...')
with self.lock:
self._stop = True
-
- def collect(self,
- name: str,
- items: List[IItem],
- metadata: dict,
- interval: int = -1):
- """Collect scheduler item
-
- Args:
- name: item name, should be unique
- items: specify dependencies
- metadata: pass metadata to share with item dependencies each other
- interval: if value is -1, it's run-once job, or run
- every interval time in seconds
- """
- if len(name) == 0:
- return
- valid_interval = interval == -1 or interval >= 10
- if not valid_interval: # seems non-sense if interval is less than 10
- raise ValueError('interval should not less than 10 if not -1')
- with get_session(self.db_engine) as session:
- # check name if exists
- existed = session.query(SchedulerItem).filter_by(name=name).first()
- if existed:
- return
- item = SchedulerItem(
- name=name,
- pipeline=PipelineEncoder().encode(
- self._build_pipeline(name, items, metadata)),
- interval_time=interval,
- )
- session.add(item)
- try:
- session.commit()
- except Exception as e: # pylint: disable=broad-except
- logging.error(f'[composer] failed to create scheduler_item, '
- f'name: {name}, exception: {e}')
- session.rollback()
-
- def finish(self, name: str):
- """Finish item
-
- Args:
- name: item name
- """
- with get_session(self.db_engine) as session:
- existed = session.query(SchedulerItem).filter_by(
- name=name, status=ItemStatus.ON.value).first()
- if not existed:
- return
- existed.status = ItemStatus.OFF.value
- try:
- session.commit()
- except Exception as e: # pylint: disable=broad-except
- logging.error(f'[composer] failed to finish scheduler_item, '
- f'name: {name}, exception: {e}')
- session.rollback()
-
- def get_item_status(self, name: str) -> Optional[ItemStatus]:
- """Get item status
-
- Args:
- name: item name
- """
- with get_session(self.db_engine) as session:
- existed = session.query(SchedulerItem).filter(
- SchedulerItem.name == name).first()
- if not existed:
- return None
- return ItemStatus(existed.status)
-
- def patch_item_attr(self, name: str, key: str, value: str):
- """ patch item args
-
- Args:
- name (str): name of this item
- key (str): key you want to update
- value (str): value you wnat to set
-
- Returns:
- Raise if some check violates
- """
- if key not in self.__class__.MUTABLE_ITEM_KEY:
- raise ValueError(f'fail to change attribute {key}')
-
- with get_session(self.db_engine) as session:
- item: SchedulerItem = session.query(SchedulerItem).filter(
- SchedulerItem.name == name).first()
- if not item:
- raise ValueError(f'cannot find item {name}')
- setattr(item, key, value)
- session.add(item)
- try:
- session.commit()
- except Exception as e: # pylint: disable=broad-except
- logging.error(f'[composer] failed to patch item attr, '
- f'name: {name}, exception: {e}')
- session.rollback()
-
- def get_recent_runners(self,
- name: str,
- count: int = 10) -> List[SchedulerRunner]:
- """Get recent runners order by created_at in desc
-
- Args:
- name: item name
- count: the number of runners
- """
- with get_session(self.db_engine) as session:
- runners = session.query(SchedulerRunner).join(
- SchedulerItem,
- SchedulerItem.id == SchedulerRunner.item_id).filter(
- SchedulerItem.name == name).order_by(
- SchedulerRunner.created_at.desc()).limit(count)
- if not runners:
- return []
- return runners
+ if self._loop_thread is not None:
+ self._loop_thread.join(timeout=self.LOOP_INTERVAL * 2)
def _check_items(self):
with get_session(self.db_engine) as session:
- items = session.query(SchedulerItem).filter_by(
- status=ItemStatus.ON.value).all()
+ items = session.query(SchedulerItem).filter_by(status=ItemStatus.ON.value).all()
for item in items:
- if not item.need_run():
+ if not SingletonStrategy(session).should_run(item):
+ continue
+
+ pipeline: composer_pb2.Pipeline = item.get_pipeline()
+ if pipeline.version != 2:
+ logging.error(f'[Composer] Invalid pipeline in item {item.id}')
+ item.status = ItemStatus.OFF.value
+ session.commit()
continue
- # NOTE: use `func.now()` to let sqlalchemy handles
+ runner = SchedulerRunner(item_id=item.id)
+ runner.set_pipeline(pipeline)
+ runner.set_context(PipelineContextData())
+ session.add(runner)
+
+            # NOTE: use sqlalchemy's `func.now()` to let it handle
# the timezone.
item.last_run_at = func.now()
- if item.interval_time < 0:
+ if not item.cron_config:
# finish run-once item automatically
item.status = ItemStatus.OFF.value
- pp = Pipeline(**(json.loads(item.pipeline)))
- context = Context(data=pp.meta,
- internal={},
- db_engine=self.db_engine)
- runner = SchedulerRunner(
- item_id=item.id,
- pipeline=item.pipeline,
- context=ContextEncoder().encode(context),
- )
- session.add(runner)
- try:
- logging.info(
- f'[composer] insert runner, item_id: {item.id}')
- session.commit()
- except Exception as e: # pylint: disable=broad-except
- logging.error(
- f'[composer] failed to create scheduler_runner, '
- f'item_id: {item.id}, exception: {e}')
- session.rollback()
+
+ logging.info(f'[composer] insert runner, item_id: {item.id}')
+ session.commit()
def _check_init_runners(self):
with get_session(self.db_engine) as session:
- init_runners = session.query(SchedulerRunner).filter_by(
- status=RunnerStatus.INIT.value).all()
+ init_runner_ids = session.query(SchedulerRunner.id).filter_by(status=RunnerStatus.INIT.value).all()
# TODO: support priority
- for runner in init_runners:
- # if thread_reaper is full, skip this round and
- # wait next checking
- if self.thread_reaper.is_full():
- return
- lock_name = f'check_init_runner_{runner.id}_lock'
- check_lock = OpLocker(lock_name, self.db_engine).try_lock()
- if not check_lock:
- logging.error(f'[composer] failed to lock, '
- f'ignore current init_runner_{runner.id}')
- continue
- pipeline = Pipeline(**(json.loads(runner.pipeline)))
- context = decode_context(val=runner.context,
- db_engine=self.db_engine)
- # find the first job in pipeline
- first = pipeline.deps[0]
+ for runner_id, *_ in init_runner_ids:
+ # if thread_reaper is full, skip this round and
+ # wait next checking
+ if self.thread_reaper.is_full():
+                    logging.info('[composer] thread_reaper is full now, waiting for other items to finish')
+ return
+ lock_name = f'check_init_runner_{runner_id}_lock'
+ check_lock = OpLocker(lock_name, self.db_engine).try_lock()
+ if not check_lock:
+ logging.error(f'[composer] failed to lock, ignore current init_runner_{runner_id}')
+ continue
+ with get_session(self.db_engine) as session:
+ runner: SchedulerRunner = session.query(SchedulerRunner).get(runner_id)
# update status
runner.start_at = func.now()
runner.status = RunnerStatus.RUNNING.value
- output = json.loads(runner.output)
- output[first] = {'status': RunnerStatus.RUNNING.value}
- runner.output = json.dumps(output)
- # record current running job
- context.set_internal('current', first)
- runner.context = ContextEncoder().encode(context)
- # start runner
- runner_fn = self.runner_cache.find_runner(runner.id, first)
- self.thread_reaper.enqueue(name=lock_name,
- fn=runner_fn,
- context=context)
+ pipeline: composer_pb2.Pipeline = runner.get_pipeline()
+ if pipeline.version != 2:
+ logging.error(f'[Composer] Invalid pipeline in runner {runner.id}')
+ runner.status = RunnerStatus.FAILED.value
+ session.commit()
+ continue
try:
- logging.info(
- f'[composer] update runner, status: {runner.status}, '
- f'pipeline: {runner.pipeline}, '
- f'output: {output}, context: {runner.context}')
+ logging.info(f'[composer] update runner, status: {runner.status}, '
+ f'pipeline: {runner.pipeline}, '
+ f'context: {runner.context}')
if check_lock.is_latest_version() and \
check_lock.update_version():
session.commit()
else:
- logging.error(f'[composer] {lock_name} is outdated, '
- f'ignore updates to database')
+ logging.error(f'[composer] {lock_name} is outdated, ignore updates to database')
except Exception as e: # pylint: disable=broad-except
- logging.error(f'[composer] failed to update init runner'
- f'status, exception: {e}')
+ logging.error(f'[composer] failed to update init runner status, exception: {e}')
session.rollback()
def _check_running_runners(self):
with get_session(self.db_engine) as session:
- running_runners = session.query(SchedulerRunner).filter_by(
- status=RunnerStatus.RUNNING.value).all()
- for runner in running_runners:
- if self.thread_reaper.is_full():
- return
- lock_name = f'check_running_runner_{runner.id}_lock'
- check_lock = OpLocker(lock_name, self.db_engine).try_lock()
- if not check_lock:
- logging.error(f'[composer] failed to lock, '
- f'ignore current running_runner_{runner.id}')
- continue
+ running_runner_ids = session.query(SchedulerRunner.id).filter_by(status=RunnerStatus.RUNNING.value).all()
+ for runner_id, *_ in running_runner_ids:
+ if self.thread_reaper.is_full():
+                    logging.info('[composer] thread_reaper is full now, waiting for other items to finish')
+ return
+ lock_name = f'check_running_runner_{runner_id}_lock'
+ check_lock = OpLocker(lock_name, self.db_engine).try_lock()
+ if not check_lock:
+                    logging.error(f'[composer] failed to lock, ignore current running_runner_{runner_id}')
+ continue
+ with get_session(self.db_engine) as session:
# TODO: restart runner if exit unexpectedly
- pipeline = Pipeline(**(json.loads(runner.pipeline)))
- output = json.loads(runner.output)
- context = decode_context(val=runner.context,
- db_engine=self.db_engine)
- current = context.internal['current']
- runner_fn = self.runner_cache.find_runner(runner.id, current)
- # check status of current one
- status, current_output = runner_fn.result(context)
- if status == RunnerStatus.RUNNING:
- continue # ignore
- if status == RunnerStatus.DONE:
- output[current] = {'status': RunnerStatus.DONE.value}
- context.set_internal(f'output_{current}', current_output)
- current_idx = pipeline.deps.index(current)
- if current_idx == len(pipeline.deps) - 1: # all done
- runner.status = RunnerStatus.DONE.value
- runner.end_at = func.now()
- else: # run next one
- next_one = pipeline.deps[current_idx + 1]
- output[next_one] = {
- 'status': RunnerStatus.RUNNING.value
- }
- context.set_internal('current', next_one)
- next_runner_fn = self.runner_cache.find_runner(
- runner.id, next_one)
- self.thread_reaper.enqueue(name=lock_name,
- fn=next_runner_fn,
- context=context)
- elif status == RunnerStatus.FAILED:
- # TODO: abort now, need retry
- output[current] = {'status': RunnerStatus.FAILED.value}
- context.set_internal(f'output_{current}', current_output)
+ runner = session.query(SchedulerRunner).get(runner_id)
+ pipeline = runner.get_pipeline()
+ if pipeline.version != 2:
+ logging.error(f'[Composer] Invalid pipeline in runner {runner.id}')
runner.status = RunnerStatus.FAILED.value
- runner.end_at = func.now()
-
- runner.pipeline = PipelineEncoder().encode(pipeline)
- runner.output = json.dumps(output)
- runner.context = ContextEncoder().encode(context)
-
- updated_db = False
- try:
- logging.info(
- f'[composer] update runner, status: {runner.status}, '
- f'pipeline: {runner.pipeline}, '
- f'output: {output}, context: {runner.context}')
- if check_lock.is_latest_version():
- if check_lock.update_version():
- session.commit()
- updated_db = True
- else:
- logging.error(f'[composer] {lock_name} is outdated, '
- f'ignore updates to database')
- except Exception as e: # pylint: disable=broad-except
- logging.error(f'[composer] failed to update running '
- f'runner status, exception: {e}')
- session.rollback()
-
- # delete useless runner obj in runner cache
- if status in (RunnerStatus.DONE,
- RunnerStatus.FAILED) and updated_db:
- self.runner_cache.del_runner(runner.id, current)
-
- @staticmethod
- def _build_pipeline(name: str, items: List[IItem],
- metadata: dict) -> Pipeline:
- deps = []
- for item in items:
- deps.append(f'{item.type().value}_{item.get_id()}')
- return Pipeline(name=name, deps=deps, meta=metadata)
+ session.commit()
+ continue
+ # If the runner is running, we always try to run it.
+ self.pipeline_executor.run(runner_id)
-composer = Composer(config=ComposerConfig(
- runner_fn=global_runner_fn(), name='scheduler for fedlearner webconsole'))
+composer = Composer(config=ComposerConfig(runner_fn=global_runner_fn(), name='scheduler for fedlearner webconsole'))
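For context when reviewing the loop changes above, here is a hedged usage sketch of how a composer instance is driven; it mirrors the pattern in `composer_test.py` further down, and the runner mapping and scheduler name are illustrative only.

```python
from fedlearner_webconsole.composer.composer import Composer, ComposerConfig
from fedlearner_webconsole.composer.interface import ItemType
from fedlearner_webconsole.db import db
from testing.composer.common import TestRunner  # test-only runner, used here purely as an example

# Illustrative config: map the TASK item type to a runner class.
config = ComposerConfig(runner_fn={ItemType.TASK.value: TestRunner}, name='example scheduler')
composer = Composer(config=config)
# run() appears to start the background loop that ticks _check_items, _check_init_runners
# and _check_running_runners; stop() joins the loop thread (see the LOOP_INTERVAL join above).
composer.run(db_engine=db.engine)
composer.stop()
```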
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/composer_service.py b/web_console_v2/api/fedlearner_webconsole/composer/composer_service.py
new file mode 100644
index 000000000..51c213c1e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/composer_service.py
@@ -0,0 +1,276 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=global-statement
+# coding: utf-8
+
+import logging
+from typing import List, Optional, Tuple
+from sqlalchemy.sql.schema import Column
+from sqlalchemy.sql import func
+from sqlalchemy.orm import Session
+from croniter import croniter
+from fedlearner_webconsole.composer.models import ItemStatus, RunnerStatus
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import SchedulerItem, SchedulerRunner
+from fedlearner_webconsole.proto import composer_pb2
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder
+from fedlearner_webconsole.proto.filtering_pb2 import FilterOp, FilterExpression, SimpleExpression
+from fedlearner_webconsole.utils.paginate import Pagination, paginate
+
+
+def _contains_case_insensitive(exp: SimpleExpression):
+ c: Column = getattr(SchedulerItem, exp.field)
+ return c.ilike(f'%{exp.string_value}%')
+
+
+def _is_cron(exp: SimpleExpression):
+ c: Column = SchedulerItem.cron_config
+ if exp.bool_value:
+ exp.string_value = '*'
+ return c.ilike(f'%{exp.string_value}%')
+ return c
+
+
+def _equal_item_status(exp: SimpleExpression):
+ c: Column = SchedulerItem.status
+ return c == ItemStatus[exp.string_value].value
+
+
+def _equal_runner_status(exp: SimpleExpression):
+ c: Column = SchedulerRunner.status
+ return c == RunnerStatus[exp.string_value].value
+
+
+class ComposerService(object):
+ # attributes that you can patch
+ MUTABLE_ITEM_KEY = ['cron_config', 'retry_cnt']
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def get_item_status(self, name: str) -> Optional[ItemStatus]:
+ """Get item status
+
+ Args:
+ name (str): item name
+
+ Returns:
+ ItemStatus: item status
+ """
+ existed = self._session.query(SchedulerItem).filter(SchedulerItem.name == name).first()
+ if not existed:
+ return None
+ return ItemStatus(existed.status)
+
+ def patch_item_attr(self, name: str, key: str, value: str):
+ """ patch item args
+
+ Args:
+ name (str): name of this item
+ key (str): key you want to update
+ value (str): value you want to set
+
+ Raises:
+            ValueError: if some check fails
+ Exception: if session failed
+ """
+ if key not in self.__class__.MUTABLE_ITEM_KEY:
+ raise ValueError(f'fail to change attribute {key}')
+
+ # TODO(linfan.fine): add validations
+ item: SchedulerItem = self._session.query(SchedulerItem).filter(SchedulerItem.name == name).first()
+ if not item:
+ raise ValueError(f'cannot find item {name}')
+ setattr(item, key, value)
+ self._session.add(item)
+ try:
+ self._session.flush()
+ except Exception as e: # pylint: disable=broad-except
+            logging.error(f'[composer] failed to patch item attr, name: {name}, exception: {e}')
+ raise e
+
+ def collect_v2(self, name: str, items: List[Tuple[ItemType, RunnerInput]], cron_config: Optional[str] = None):
+ """Collect scheduler item.
+
+ Args:
+ name (str): item name, should be unique
+            items (List[Tuple[ItemType, RunnerInput]]): specifies the execution pipeline (in order)
+ cron_config (Optional[str]): a cron expression for running item periodically
+
+ Raises:
+ ValueError: if `cron_config` is invalid
+ Exception: if db session failed
+ """
+ if len(name) == 0:
+ return
+ if cron_config and not croniter.is_valid(cron_config):
+ raise ValueError('invalid cron_config')
+ # check name if exists
+ existed = self._session.query(SchedulerItem.id).filter_by(name=name).first()
+ if existed:
+ logging.warning('SchedulerItem %s already existed', name)
+ return
+ scheduler_item = SchedulerItem(name=name, cron_config=cron_config, created_at=func.now())
+ queue = []
+ for item_type, rinput in items:
+ runner_input = RunnerInput(runner_type=item_type.value)
+ runner_input.MergeFrom(rinput)
+ queue.append(runner_input)
+ pipeline = composer_pb2.Pipeline(version=2, name=name, queue=queue)
+ scheduler_item.set_pipeline(pipeline)
+ self._session.add(scheduler_item)
+ try:
+ self._session.flush()
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'[composer] failed to create scheduler_item, name: {name}, exception: {e}')
+ raise e
+
+ def start(self, name: str):
+ """Enable an OFF scheduler item"""
+ existed = self._session.query(SchedulerItem).filter_by(name=name).first()
+ existed.status = ItemStatus.ON.value
+
+ def finish(self, name: str):
+ """Finish item
+
+ Args:
+ name (str): item name
+
+ Raises:
+ Exception: if db session failed
+ """
+ existed = self._session.query(SchedulerItem).filter_by(name=name, status=ItemStatus.ON.value).first()
+ if not existed:
+ return
+ existed.status = ItemStatus.OFF.value
+ self._session.query(SchedulerRunner).filter(
+ SchedulerRunner.item_id == existed.id,
+ SchedulerRunner.status.in_([RunnerStatus.INIT.value, RunnerStatus.RUNNING.value])).delete()
+ try:
+ self._session.flush()
+ except Exception as e: # pylint: disable=broad-except
+            logging.error(f'[composer] failed to finish scheduler_item, name: {name}, exception: {e}')
+ raise e
+
+ def get_recent_runners(self, name: str, count: int = 10) -> List[SchedulerRunner]:
+ """Get recent runners order by created_at in desc
+
+ Args:
+ name (str): item name
+ count (int): the number of runners
+
+ Returns:
+ List[SchedulerRunner]: list of SchedulerRunner
+ """
+ runners = self._session.query(SchedulerRunner).join(
+ SchedulerItem, SchedulerItem.id == SchedulerRunner.item_id).filter(SchedulerItem.name == name).order_by(
+ SchedulerRunner.created_at.desc()).limit(count).all()
+ if not runners:
+ return []
+ return runners
+
+
+class CronJobService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def start_cronjob(self, item_name: str, items: List[Tuple[ItemType, RunnerInput]], cron_config: str):
+ """Starts a cronjob if cron_config is valid.
+
+ Args:
+ item_name (str): name of scheduler item
+            items (List[Tuple[ItemType, RunnerInput]]): list of item types with their runner inputs
+            cron_config (str): cron expression
+
+ Raises:
+            InvalidArgumentException: if some check fails
+ """
+ if not croniter.is_valid(cron_config):
+ raise InvalidArgumentException(f'cron config {cron_config} is not valid')
+ service = ComposerService(self._session)
+ status = service.get_item_status(name=item_name)
+ # create a cronjob
+ if status is None:
+ service.collect_v2(name=item_name, items=items, cron_config=cron_config)
+ return
+ if status == ItemStatus.OFF:
+ logging.info(f'[start_cronjob] start composer item {item_name}')
+ service.start(name=item_name)
+ # patch a cronjob
+ try:
+ service.patch_item_attr(name=item_name, key='cron_config', value=cron_config)
+ except ValueError as err:
+ emit_store('path_item_attr_error', 1)
+ raise InvalidArgumentException(details=repr(err)) from err
+
+ def stop_cronjob(self, item_name: str):
+ service = ComposerService(self._session)
+ logging.info(f'[start_or_stop_cronjob] finish composer item {item_name}')
+ service.finish(name=item_name)
+
+
+class SchedulerItemService():
+ """ 'is_cron' param means whether should only display cron-jobs. """
+ FILTER_FIELDS = {
+ 'is_cron': SupportedField(type=FieldType.BOOL, ops={FilterOp.EQUAL: _is_cron}),
+ 'status': SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: _equal_item_status}),
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: _contains_case_insensitive}),
+ 'id': SupportedField(type=FieldType.NUMBER, ops={FilterOp.EQUAL: None})
+ }
+
+ def __init__(self, session: Session):
+ self._session = session
+ self._filter_builder = FilterBuilder(model_class=SchedulerItem, supported_fields=self.FILTER_FIELDS)
+
+ def get_scheduler_items(self,
+ page: Optional[int] = None,
+ page_size: Optional[int] = None,
+ filter_exp: Optional[FilterExpression] = None) -> Pagination:
+ query = self._session.query(SchedulerItem)
+ if filter_exp:
+ query = self._filter_builder.build_query(query, filter_exp)
+ query = query.order_by(SchedulerItem.id.desc())
+ return paginate(query, page, page_size)
+
+
+class SchedulerRunnerService():
+ FILTER_FIELDS = {
+ 'status': SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: _equal_runner_status}),
+ }
+
+ def __init__(self, session: Session):
+ self._session = session
+ self._filter_builder = FilterBuilder(model_class=SchedulerRunner, supported_fields=self.FILTER_FIELDS)
+
+ def get_scheduler_runners(self,
+ item_id: Optional[int] = None,
+ page: Optional[int] = None,
+ page_size: Optional[int] = None,
+ filter_exp: Optional[FilterExpression] = None) -> Pagination:
+        # status is indexed, so filtering on it helps optimize the sql query
+        # order by id.desc rather than created_at.desc so the index can be used
+ query = self._session.query(SchedulerRunner).order_by(
+ SchedulerRunner.id.desc()).filter(SchedulerRunner.status > -1)
+ if filter_exp:
+ query = self._filter_builder.build_query(query, filter_exp)
+ if item_id is not None:
+ query = query.filter_by(item_id=item_id)
+
+ return paginate(query, page, page_size)
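Taken together, here is a minimal usage sketch of the new service layer; the item names are made up, and the session handling follows the tests below. Callers build a list of `(ItemType, RunnerInput)` tuples and let the services persist the v2 pipeline.

```python
from fedlearner_webconsole.composer.composer_service import ComposerService, CronJobService
from fedlearner_webconsole.composer.interface import ItemType
from fedlearner_webconsole.db import db
from fedlearner_webconsole.proto.composer_pb2 import RunnerInput

with db.session_scope() as session:
    # Run-once item: no cron_config, so the composer turns it OFF after the first runner.
    ComposerService(session).collect_v2('example_once_item', [(ItemType.TASK, RunnerInput())])
    # Periodic item: cron_config is validated with croniter before it is persisted.
    CronJobService(session).start_cronjob(item_name='example_cron_item',
                                          items=[(ItemType.TASK, RunnerInput())],
                                          cron_config='*/20 * * * *')
    session.commit()
```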
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/composer_service_test.py b/web_console_v2/api/fedlearner_webconsole/composer/composer_service_test.py
new file mode 100644
index 000000000..cb8ed76ad
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/composer_service_test.py
@@ -0,0 +1,279 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+from datetime import datetime
+
+import sys
+
+from fedlearner_webconsole.composer.composer_service import (ComposerService, CronJobService, SchedulerItemService,
+ SchedulerRunnerService)
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import RunnerStatus, SchedulerItem, ItemStatus, SchedulerRunner
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, Pipeline, ModelTrainingCronJobInput
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterExpressionKind, SimpleExpression, FilterOp
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ComposerServiceTest(NoWebServerTestCase):
+
+ def test_collect_v2(self):
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ service.collect_v2('test_item', [(ItemType.TASK, RunnerInput()), (ItemType.TASK, RunnerInput())])
+ session.commit()
+ with db.session_scope() as session:
+ item = session.query(SchedulerItem).filter(SchedulerItem.name == 'test_item').first()
+ self.assertEqual(item.status, ItemStatus.ON.value)
+ self.assertIsNone(item.cron_config)
+ self.assertEqual(
+ item.get_pipeline(),
+ Pipeline(
+ version=2,
+ name='test_item',
+ queue=[RunnerInput(runner_type=ItemType.TASK.value),
+ RunnerInput(runner_type=ItemType.TASK.value)]))
+
+ def test_collect_v2_duplication(self):
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ service.collect_v2('test_item', [(ItemType.TASK, RunnerInput())])
+ session.commit()
+ service.collect_v2('test_item', [(ItemType.TASK, RunnerInput())])
+ session.commit()
+ with db.session_scope() as session:
+ items = session.query(SchedulerItem).filter(SchedulerItem.name == 'test_item').all()
+ self.assertEqual(len(items), 1)
+
+ def test_collect_v2_cron(self):
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ service.collect_v2('test_cron_item', [
+ (ItemType.TASK, RunnerInput()),
+ ], '* * * * * */10')
+ session.commit()
+ with db.session_scope() as session:
+ item = session.query(SchedulerItem).filter(SchedulerItem.name == 'test_cron_item').first()
+ self.assertEqual(item.status, ItemStatus.ON.value)
+ self.assertEqual(item.cron_config, '* * * * * */10')
+ self.assertEqual(
+ item.get_pipeline(),
+ Pipeline(version=2, name='test_cron_item', queue=[RunnerInput(runner_type=ItemType.TASK.value)]))
+
+ def test_finish(self):
+ with db.session_scope() as session:
+ item = SchedulerItem(id=100, name='fake_item', status=ItemStatus.ON.value)
+ runner_1 = SchedulerRunner(id=100, item_id=100, status=RunnerStatus.RUNNING.value)
+ runner_2 = SchedulerRunner(id=101, item_id=100, status=RunnerStatus.DONE.value)
+ runner_3 = SchedulerRunner(id=102, item_id=100, status=RunnerStatus.FAILED.value)
+ runner_4 = SchedulerRunner(id=103, item_id=100, status=RunnerStatus.INIT.value)
+ session.add(item)
+ session.add(runner_1)
+ session.add(runner_2)
+ session.add(runner_3)
+ session.add(runner_4)
+ session.commit()
+
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ service.finish(name='fake_item')
+ session.commit()
+
+ with db.session_scope() as session:
+ item = session.query(SchedulerItem).get(100)
+ runner_1 = session.query(SchedulerRunner).get(100)
+ runner_2 = session.query(SchedulerRunner).get(101)
+ runner_3 = session.query(SchedulerRunner).get(102)
+ runner_4 = session.query(SchedulerRunner).get(103)
+
+ self.assertEqual(item.status, ItemStatus.OFF.value)
+ self.assertIsNone(runner_1)
+ self.assertEqual(runner_2.status, RunnerStatus.DONE.value)
+ self.assertEqual(runner_3.status, RunnerStatus.FAILED.value)
+ self.assertIsNone(runner_4)
+
+ def test_get_recent_runners(self):
+ with db.session_scope() as session:
+ item = SchedulerItem(id=100, name='fake_item', status=ItemStatus.ON.value)
+ runner_1 = SchedulerRunner(id=100,
+ item_id=100,
+ status=RunnerStatus.RUNNING.value,
+ created_at=datetime(2012, 1, 14, 12, 0, 6))
+ runner_2 = SchedulerRunner(id=101,
+ item_id=100,
+ status=RunnerStatus.DONE.value,
+ created_at=datetime(2012, 1, 14, 12, 0, 7))
+ runner_3 = SchedulerRunner(id=102,
+ item_id=100,
+ status=RunnerStatus.FAILED.value,
+ created_at=datetime(2012, 1, 14, 12, 0, 8))
+ runner_4 = SchedulerRunner(id=103,
+ item_id=100,
+ status=RunnerStatus.INIT.value,
+ created_at=datetime(2012, 1, 14, 12, 0, 9))
+ session.add_all([item, runner_1, runner_2, runner_3, runner_4])
+ session.commit()
+
+ with db.session_scope() as session:
+ expect_runners = [runner_4, runner_3, runner_2, runner_1]
+ runners = ComposerService(session).get_recent_runners(name='fake_item', count=10)
+ self.assertEqual(len(runners), 4)
+ for i in range(4):
+ self.assertEqual(runners[i].id, expect_runners[i].id)
+ self.assertEqual(runners[i].status, expect_runners[i].status)
+ self.assertEqual(runners[i].item_id, 100)
+
+ runners = ComposerService(session).get_recent_runners(name='fake_item', count=1)
+ self.assertEqual(len(runners), 1)
+ self.assertEqual(runners[0].id, expect_runners[0].id)
+ self.assertEqual(runners[0].status, expect_runners[0].status)
+ self.assertEqual(runners[0].item_id, 100)
+
+ def test_patch_item_attr(self):
+ test_item_name = 'test'
+ with db.session_scope() as session:
+ scheduler_item = SchedulerItem(name=test_item_name, cron_config='* */1 * * *')
+ session.add(scheduler_item)
+ session.commit()
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ service.patch_item_attr(name=test_item_name, key='cron_config', value='*/20 * * * *')
+ session.commit()
+ with db.session_scope() as session:
+ item = session.query(SchedulerItem).filter(SchedulerItem.name == test_item_name).one()
+ self.assertEqual(item.cron_config, '*/20 * * * *')
+ with self.assertRaises(ValueError):
+ with db.session_scope() as session:
+ ComposerService(session).patch_item_attr(name=test_item_name,
+ key='create_at',
+ value='2021-04-01 00:00:00')
+ session.commit()
+
+
+class CronJobServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ item = SchedulerItem(id=1,
+ name='model_training_cronjob_1',
+ cron_config='*/10 * * * *',
+ status=ItemStatus.ON.value)
+ session.add(item)
+ session.commit()
+
+ def test_start_cronjob(self):
+ with db.session_scope() as session:
+ items = [(ItemType.MODEL_TRAINING_CRON_JOB,
+ RunnerInput(model_training_cron_job_input=ModelTrainingCronJobInput(group_id=1)))]
+ CronJobService(session).start_cronjob('model_training_cronjob_1', items, '*/20 * * * *')
+ CronJobService(session).start_cronjob('model_training_cronjob_2', items, '*/20 * * * *')
+ session.commit()
+ with db.session_scope() as session:
+ item_1 = session.query(SchedulerItem).get(1)
+ self.assertEqual(item_1.cron_config, '*/20 * * * *')
+ item_2 = session.query(SchedulerItem).filter_by(name='model_training_cronjob_2').first()
+ self.assertEqual(item_2.cron_config, '*/20 * * * *')
+
+ def test_stop_cronjob(self):
+ with db.session_scope() as session:
+ CronJobService(session).stop_cronjob('model_training_cronjob_1')
+ session.commit()
+ with db.session_scope() as session:
+ item = session.query(SchedulerItem).get(1)
+ self.assertEqual(item.status, ItemStatus.OFF.value)
+
+
+class SchedulerItemServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ scheduler_item_off = SchedulerItem(id=5,
+ name='test_item_off',
+ status=ItemStatus.OFF.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ scheduler_item_on = SchedulerItem(id=6,
+ name='test_item_on',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ scheduler_item_on_cron = SchedulerItem(id=7,
+ name='test_item_on_cron',
+ cron_config='*/20 * * * *',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+
+ with db.session_scope() as session:
+ session.add(scheduler_item_on)
+ session.add(scheduler_item_off)
+ session.add(scheduler_item_on_cron)
+ session.commit()
+
+ def test_get_scheduler_items(self):
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='is_cron',
+ op=FilterOp.EQUAL,
+ bool_value=1,
+ ))
+ with db.session_scope() as session:
+ service = SchedulerItemService(session)
+ paginations = service.get_scheduler_items(filter_exp=filter_exp, page=1, page_size=7)
+ item_ids = [item.id for item in paginations.get_items()]
+ self.assertEqual(item_ids, [7])
+
+
+class SchedulerRunnerServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ scheduler_item_on = SchedulerItem(id=100,
+ name='test_item_on',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ self.default_scheduler_item = scheduler_item_on
+ runner_init = SchedulerRunner(id=0, item_id=100, status=RunnerStatus.INIT.value)
+ runner_running_1 = SchedulerRunner(id=1, item_id=100, status=RunnerStatus.RUNNING.value)
+ runner_running_2 = SchedulerRunner(id=2, item_id=100, status=RunnerStatus.RUNNING.value)
+
+ with db.session_scope() as session:
+ session.add(scheduler_item_on)
+ session.add(runner_init)
+ session.add(runner_running_1)
+ session.add(runner_running_2)
+ session.commit()
+
+ def test_get_scheduler_runners(self):
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='status',
+ op=FilterOp.EQUAL,
+ string_value='INIT',
+ ))
+ with db.session_scope() as session:
+ service = SchedulerRunnerService(session)
+ paginations = service.get_scheduler_runners(filter_exp=filter_exp, page=1, page_size=7)
+ item_ids = [item.id for item in paginations.get_items()]
+ self.assertEqual(item_ids, [0])
+
+
+if __name__ == '__main__':
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/composer_test.py b/web_console_v2/api/fedlearner_webconsole/composer/composer_test.py
new file mode 100644
index 000000000..a9b4f2c91
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/composer_test.py
@@ -0,0 +1,171 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import logging
+import sys
+import threading
+import unittest
+from datetime import datetime
+
+from fedlearner_webconsole.composer.composer import Composer, ComposerConfig
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import ItemStatus, RunnerStatus, SchedulerItem, \
+ SchedulerRunner
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, PipelineContextData, RunnerOutput
+from testing.composer.common import TestRunner
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.fake_time_patcher import FakeTimePatcher
+
+
+class ComposerV2Test(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.time_patcher = FakeTimePatcher()
+ self.time_patcher.start(datetime(2012, 1, 14, 12, 0, 5))
+
+        self.runner_fn = {
+ ItemType.TASK.value: TestRunner,
+ }
+
+ def tearDown(self):
+ self.time_patcher.stop()
+ super().tearDown()
+
+ def test_multiple_composers(self):
+ logging.info('+++++++++++++++++++++++++++ test multiple composers')
+ cfg = ComposerConfig(runner_fn=self.runner_fn, name='scheduler for normal items')
+ composer1 = Composer(cfg)
+ composer2 = Composer(cfg)
+ c1 = threading.Thread(target=composer1.run, args=[db.engine])
+ c1.start()
+ c2 = threading.Thread(target=composer2.run, args=[db.engine])
+ c2.start()
+ self.time_patcher.interrupt(15)
+ composer1.stop()
+ composer2.stop()
+
+ def test_normal_items(self):
+ logging.info('+++++++++++++++++++++++++++ test normal items')
+ cfg = ComposerConfig(runner_fn=self.runner_fn, name='scheduler for normal items')
+ composer = Composer(config=cfg)
+ composer.run(db_engine=db.engine)
+ with db.session_scope() as session:
+ name = 'normal items'
+ service = ComposerService(session)
+ service.collect_v2(name, [(ItemType.TASK, RunnerInput()), (ItemType.TASK, RunnerInput()),
+ (ItemType.TASK, RunnerInput())])
+ session.commit()
+ self.time_patcher.interrupt(60)
+ with db.session_scope() as session:
+ runners = session.query(SchedulerRunner).all()
+ self.assertEqual(len(runners), 1, 'Should be only 1 runner')
+ self.assertEqual(runners[0].status, RunnerStatus.DONE.value)
+ self.assertEqual(
+ runners[0].get_context(),
+ PipelineContextData(current_runner=2,
+ outputs={
+ 0: RunnerOutput(),
+ 1: RunnerOutput(),
+ 2: RunnerOutput(),
+ }))
+ # Item should be finished
+ item = session.query(SchedulerItem).filter(SchedulerItem.name == 'normal items').first()
+ self.assertEqual(item.status, ItemStatus.OFF.value, 'should finish item')
+ composer.stop()
+
+ def test_failed_items(self):
+ logging.info('+++++++++++++++++++++++++++ test failed items')
+ cfg = ComposerConfig(runner_fn=self.runner_fn, name='scheduler for failed items')
+ composer = Composer(config=cfg)
+ composer.run(db_engine=db.engine)
+ with db.session_scope() as session:
+ name = 'failed items'
+ ComposerService(session).collect_v2(
+ name,
+ [
+ (ItemType.TASK, RunnerInput()),
+ (ItemType.TASK, RunnerInput()),
+ (ItemType.TASK, RunnerInput()),
+ # Failed one
+ (ItemType.TASK, RunnerInput()),
+ (ItemType.TASK, RunnerInput()),
+ ])
+ session.commit()
+ self.time_patcher.interrupt(60)
+ with db.session_scope() as session:
+ runners = session.query(SchedulerRunner).all()
+ self.assertEqual(len(runners), 1, 'Should be only 1 runner')
+ self.assertEqual(runners[0].status, RunnerStatus.FAILED.value)
+ self.assertEqual(
+ runners[0].get_context(),
+ PipelineContextData(current_runner=3,
+ outputs={
+ 0: RunnerOutput(),
+ 1: RunnerOutput(),
+ 2: RunnerOutput(),
+ 3: RunnerOutput(error_message='index is 3')
+ }))
+ # Item should be finished
+ item = session.query(SchedulerItem).filter(SchedulerItem.name == 'failed items').first()
+ self.assertEqual(item.status, ItemStatus.OFF.value, 'should finish item')
+ composer.stop()
+
+ def test_cron_items(self):
+ logging.info('+++++++++++++++++++++++++++ test finishing cron items')
+ cfg = ComposerConfig(runner_fn=self.runner_fn, name='finish normal items')
+ composer = Composer(config=cfg)
+ composer.run(db_engine=db.engine)
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ name = 'cronjob'
+ # test invalid cron
+ self.assertRaises(ValueError,
+ service.collect_v2,
+ name, [
+ (ItemType.TASK, RunnerInput()),
+ ],
+ cron_config='invalid cron')
+
+ service.collect_v2(
+ name,
+ [
+ (ItemType.TASK, RunnerInput()),
+ ],
+ # Every 10 seconds
+ cron_config='* * * * * */10')
+ session.commit()
+ self.assertEqual(1, len(session.query(SchedulerItem).all()))
+        # Interrupt twice since we need two rounds of ticks for the
+        # composer to schedule items under the fake clock
+ self.time_patcher.interrupt(11)
+ self.time_patcher.interrupt(11)
+ with db.session_scope() as session:
+ self.assertEqual(2, len(session.query(SchedulerRunner).all()))
+ service = ComposerService(session)
+ self.assertEqual(RunnerStatus.DONE.value,
+ service.get_recent_runners(name)[-1].status, 'should finish runner')
+ service.finish(name)
+ session.commit()
+ self.assertEqual(ItemStatus.OFF, service.get_item_status(name), 'should finish item')
+ composer.stop()
+
+
+if __name__ == '__main__':
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/context.py b/web_console_v2/api/fedlearner_webconsole/composer/context.py
new file mode 100644
index 000000000..c71b101f0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/context.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=redefined-builtin
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, PipelineContextData, Pipeline
+
+
+class RunnerContext(object):
+
+ def __init__(self, index: int, input: RunnerInput):
+ self._index = index
+ self._input = input
+
+ @property
+ def index(self) -> int:
+ return self._index
+
+ @property
+ def input(self) -> RunnerInput:
+ return self._input
+
+
+class PipelineContext(object):
+
+ def __init__(self, pipeline: Pipeline, data: PipelineContextData):
+ self._pipeline = pipeline
+ self._data = data
+ self._runner_contexts = {}
+
+ @classmethod
+ def build(cls, pipeline: Pipeline, data: PipelineContextData) -> 'PipelineContext':
+ return cls(pipeline=pipeline, data=data)
+
+ def run_next(self):
+ if self._data.current_runner >= len(self._pipeline.queue) - 1:
+ return
+ self._data.current_runner += 1
+
+ def get_current_runner_context(self) -> RunnerContext:
+ runner_idx = self._data.current_runner
+ if runner_idx in self._runner_contexts:
+ return self._runner_contexts[runner_idx]
+ context = RunnerContext(index=runner_idx, input=self._pipeline.queue[runner_idx])
+ self._runner_contexts[runner_idx] = context
+ return context
+
+ @property
+ def data(self) -> PipelineContextData:
+ return self._data
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/context_test.py b/web_console_v2/api/fedlearner_webconsole/composer/context_test.py
new file mode 100644
index 000000000..54259e78f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/context_test.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.composer.context import PipelineContext
+from fedlearner_webconsole.proto.composer_pb2 import Pipeline, RunnerInput, PipelineContextData
+
+
+class PipelineContextTest(unittest.TestCase):
+
+ def test_get_current_runner_context(self):
+ pipeline_context = PipelineContext.build(pipeline=Pipeline(version=2,
+ name='test pipeline',
+ queue=[
+ RunnerInput(runner_type='test type1'),
+ RunnerInput(runner_type='test type2'),
+ ]),
+ data=PipelineContextData())
+ runner_context = pipeline_context.get_current_runner_context()
+ self.assertEqual(runner_context.index, 0)
+ self.assertEqual(runner_context.input.runner_type, 'test type1')
+ pipeline_context.run_next()
+ runner_context = pipeline_context.get_current_runner_context()
+ self.assertEqual(runner_context.index, 1)
+ self.assertEqual(runner_context.input.runner_type, 'test type2')
+ # No effect as whole pipeline already done
+ pipeline_context.run_next()
+ runner_context = pipeline_context.get_current_runner_context()
+ self.assertEqual(runner_context.index, 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/interface.py b/web_console_v2/api/fedlearner_webconsole/composer/interface.py
index f9acdeb88..001eafe79 100644
--- a/web_console_v2/api/fedlearner_webconsole/composer/interface.py
+++ b/web_console_v2/api/fedlearner_webconsole/composer/interface.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,20 +19,38 @@
import enum
from typing import Tuple
-from fedlearner_webconsole.composer.models import Context, RunnerStatus
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput
# NOTE: remember to register new item in `global_runner_fn` \
# which defined in `runner.py`
class ItemType(enum.Enum):
TASK = 'task' # test only
- MEMORY = 'memory'
- WORKFLOW_CRON_JOB = 'workflow_cron_job'
- DATA_PIPELINE = 'data_pipeline'
+ WORKFLOW_CRON_JOB = 'workflow_cron_job' # v2
+ BATCH_STATS = 'batch_stats' # v2
+ SERVING_SERVICE_PARSE_SIGNATURE = 'serving_service_parse_signature' # v2
+ SERVING_SERVICE_QUERY_PARTICIPANT_STATUS = 'serving_service_query_participant_status' # v2
+ SERVING_SERVICE_UPDATE_MODEL = 'serving_service_update_model' # v2
+ SCHEDULE_WORKFLOW = 'schedule_workflow' # v2
+ SCHEDULE_JOB = 'schedule_job' # v2
+ CLEANUP_CRON_JOB = 'cleanup_cron_job' # v2
+ MODEL_TRAINING_CRON_JOB = 'model_training_cron_job' # v2
+ TEE_CREATE_RUNNER = 'tee_create_runner' # v2
+ TEE_RESOURCE_CHECK_RUNNER = 'tee_resource_check_runner' # v2
+ SCHEDULE_PROJECT = 'schedule_project' # v2
+ DATASET_LONG_PERIOD_SCHEDULER = 'dataset_long_period_scheduler' # v2
+ DATASET_SHORT_PERIOD_SCHEDULER = 'dataset_short_period_scheduler' # v2
+ SCHEDULE_MODEL_JOB = 'schedule_model_job' # v2
+ SCHEDULE_MODEL_JOB_GROUP = 'schedule_model_job_group' # v2
+ SCHEDULE_LONG_PERIOD_MODEL_JOB_GROUP = 'schedule_long_period_model_job_group' # v2
# item interface
class IItem(metaclass=ABCMeta):
+
@abstractmethod
def type(self) -> ItemType:
pass
@@ -42,27 +60,16 @@ def get_id(self) -> int:
pass
-# runner interface
-class IRunner(metaclass=ABCMeta):
- @abstractmethod
- def start(self, context: Context):
- """Start runner
-
- Args:
- context: shared in runner. Don't write data to context in this
- method. Only can read data via `context.data`.
- """
+class IRunnerV2(metaclass=ABCMeta):
@abstractmethod
- def result(self, context: Context) -> Tuple[RunnerStatus, dict]:
- """Check runner result
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ """Runs the runner.
- NOTE: You could check runner if is timeout in this method. If it's
- timeout, return `RunnerStatus.FAILED`. Since runners executed by
- `ThreadPoolExecutor` may have some common resources, it's better to
- stop the runner by user instead of `composer`.
+ The implementation should be light, as runners will be executed by `ThreadPoolExecutor`.
Args:
- context: shared in runner. In this method, data can be
- read or written to context via `context.data`.
+ context: immutable context in the runner.
+ Returns:
+ status and the output.
"""
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/models.py b/web_console_v2/api/fedlearner_webconsole/composer/models.py
index 4cf88ac7e..17e8c5c11 100644
--- a/web_console_v2/api/fedlearner_webconsole/composer/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/composer/models.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,17 +16,24 @@
import enum
import json
-import datetime
import logging
-
-from sqlalchemy import UniqueConstraint
+from datetime import timezone, datetime
+from sqlalchemy import UniqueConstraint, Index
from sqlalchemy.engine import Engine
from sqlalchemy.sql import func
+from croniter import croniter
from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.proto import composer_pb2
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.utils.mixins import to_dict_mixin
+from fedlearner_webconsole.utils.proto import to_json, parse_from_json
+from fedlearner_webconsole.proto.composer_pb2 import SchedulerItemPb, SchedulerRunnerPb
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
class Context(object):
+
def __init__(self, data: dict, internal: dict, db_engine: Engine):
self._data = data # user data
self._internal = internal # internal system data
@@ -52,24 +59,21 @@ def db_engine(self) -> Engine:
class ContextEncoder(json.JSONEncoder):
- def default(self, obj) -> dict:
- d = obj.__dict__
- return {
- '_data': d.get('_data', {}),
- '_internal': d.get('_internal', {})
- }
+
+ def default(self, o) -> dict:
+ d = o.__dict__
+ return {'_data': d.get('_data', {}), '_internal': d.get('_internal', {})}
class ContextDecoder(json.JSONDecoder):
+
def __init__(self, db_engine: Engine):
self.db_engine = db_engine
super().__init__(object_hook=self.dict2object)
def dict2object(self, val):
if '_data' in val and '_internal' in val:
- return Context(data=val.get('_data', {}),
- internal=val.get('_internal', {}),
- db_engine=self.db_engine)
+ return Context(data=val.get('_data', {}), internal=val.get('_internal', {}), db_engine=self.db_engine)
return val
@@ -84,37 +88,24 @@ class ItemStatus(enum.Enum):
ON = 1 # need to run
+@to_dict_mixin(extras={'need_run': (lambda si: si.need_run())})
class SchedulerItem(db.Model):
__tablename__ = 'scheduler_item_v2'
- __table_args__ = (UniqueConstraint('name', name='uniq_name'),
- default_table_args('scheduler items'))
- id = db.Column(db.Integer,
- comment='id',
- primary_key=True,
- autoincrement=True)
+ __table_args__ = (
+ UniqueConstraint('name', name='uniq_name'),
+        # idx_status is a common name which may cause a conflict in sqlite
+ Index('idx_item_status', 'status'),
+ default_table_args('scheduler items'),
+ )
+ id = db.Column(db.Integer, comment='id', primary_key=True, autoincrement=True)
name = db.Column(db.String(255), comment='item name', nullable=False)
- pipeline = db.Column(db.Text,
- comment='pipeline',
- nullable=False,
- default='{}')
- status = db.Column(db.Integer,
- comment='item status',
- nullable=False,
- default=ItemStatus.ON.value)
- interval_time = db.Column(db.Integer,
- comment='item run interval in second',
- nullable=False,
- default=-1)
- last_run_at = db.Column(db.DateTime(timezone=True),
- comment='last runner time')
- retry_cnt = db.Column(db.Integer,
- comment='retry count when item is failed',
- nullable=False,
- default=0)
+ pipeline = db.Column(db.Text(16777215), comment='pipeline', nullable=False, default='{}')
+ status = db.Column(db.Integer, comment='item status', nullable=False, default=ItemStatus.ON.value)
+ cron_config = db.Column(db.String(255), comment='cron expression in UTC timezone')
+ last_run_at = db.Column(db.DateTime(timezone=True), comment='last runner time')
+ retry_cnt = db.Column(db.Integer, comment='retry count when item is failed', nullable=False, default=0)
extra = db.Column(db.Text(), comment='extra info')
- created_at = db.Column(db.DateTime(timezone=True),
- comment='created at',
- server_default=func.now())
+ created_at = db.Column(db.DateTime(timezone=True), comment='created at', server_default=func.now())
updated_at = db.Column(db.DateTime(timezone=True),
comment='updated at',
server_default=func.now(),
@@ -122,23 +113,40 @@ class SchedulerItem(db.Model):
deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')
def need_run(self) -> bool:
- # job runs one time
- if self.interval_time == -1 and self.last_run_at is None:
- return True
- if self.interval_time > 0: # cronjob
- if self.last_run_at is None: # never run
- return True
- # compare datetime in utc
- next_run_at = self.last_run_at.replace(
- tzinfo=datetime.timezone.utc) + datetime.timedelta(
- seconds=self.interval_time)
- utc_now = datetime.datetime.now(datetime.timezone.utc)
- logging.debug(f'[composer] item id: {self.id}, '
- f'next_run_at: {next_run_at.timestamp()}, '
- f'utc_now: {utc_now.timestamp()}')
- if next_run_at.timestamp() < utc_now.timestamp():
- return True
- return False
+ if not self.cron_config:
+ # job runs once
+ return self.last_run_at is None
+ # cronjob
+ if self.last_run_at is None: # never run
+ # if there is no start time, croniter will return next run
+ # datetime (UTC) based on create time
+ base = self.created_at.replace(tzinfo=timezone.utc)
+ else:
+ base = self.last_run_at.replace(tzinfo=timezone.utc)
+ next_run_at = croniter(self.cron_config, base).get_next(datetime)
+ utc_now = now(timezone.utc)
+ logging.debug(f'[composer] item id: {self.id}, '
+ f'next_run_at: {next_run_at.timestamp()}, '
+ f'utc_now: {utc_now.timestamp()}')
+ return next_run_at.timestamp() < utc_now.timestamp()
+
+ def set_pipeline(self, proto: composer_pb2.Pipeline):
+ self.pipeline = to_json(proto)
+
+ def get_pipeline(self) -> composer_pb2.Pipeline:
+ return parse_from_json(self.pipeline, composer_pb2.Pipeline())
+
+ def to_proto(self) -> SchedulerItemPb:
+ return SchedulerItemPb(id=self.id,
+ name=self.name,
+ pipeline=self.get_pipeline(),
+ status=ItemStatus(self.status).name,
+ cron_config=self.cron_config,
+ last_run_at=to_timestamp(self.last_run_at) if self.last_run_at else None,
+ retry_cnt=self.retry_cnt,
+ created_at=to_timestamp(self.created_at) if self.created_at else None,
+ updated_at=to_timestamp(self.updated_at) if self.updated_at else None,
+ deleted_at=to_timestamp(self.deleted_at) if self.deleted_at else None)
class RunnerStatus(enum.Enum):
@@ -148,43 +156,62 @@ class RunnerStatus(enum.Enum):
FAILED = 3
+@to_dict_mixin()
class SchedulerRunner(db.Model):
__tablename__ = 'scheduler_runner_v2'
- __table_args__ = (default_table_args('scheduler runners'))
- id = db.Column(db.Integer,
- comment='id',
- primary_key=True,
- autoincrement=True)
+ __table_args__ = (
+        # idx_status is a common name which may cause a conflict in sqlite
+ Index('idx_runner_status', 'status'),
+ Index('idx_runner_item_id', 'item_id'),
+ default_table_args('scheduler runners'),
+ )
+ id = db.Column(db.Integer, comment='id', primary_key=True, autoincrement=True)
item_id = db.Column(db.Integer, comment='item id', nullable=False)
- status = db.Column(db.Integer,
- comment='runner status',
- nullable=False,
- default=RunnerStatus.INIT.value)
- start_at = db.Column(db.DateTime(timezone=True),
- comment='runner start time')
+ status = db.Column(db.Integer, comment='runner status', nullable=False, default=RunnerStatus.INIT.value)
+ start_at = db.Column(db.DateTime(timezone=True), comment='runner start time')
end_at = db.Column(db.DateTime(timezone=True), comment='runner end time')
- pipeline = db.Column(db.Text(),
- comment='pipeline from scheduler item',
- nullable=False,
- default='{}')
- output = db.Column(db.Text(),
- comment='output',
- nullable=False,
- default='{}')
- context = db.Column(db.Text(),
- comment='context',
- nullable=False,
- default='{}')
+ pipeline = db.Column(db.Text(16777215), comment='pipeline from scheduler item', nullable=False, default='{}')
+ output = db.Column(db.Text(), comment='output', nullable=False, default='{}')
+ context = db.Column(db.Text(16777215), comment='context', nullable=False, default='{}')
extra = db.Column(db.Text(), comment='extra info')
- created_at = db.Column(db.DateTime(timezone=True),
- comment='created at',
- server_default=func.now())
+ created_at = db.Column(db.DateTime(timezone=True), comment='created at', server_default=func.now())
updated_at = db.Column(db.DateTime(timezone=True),
comment='updated at',
server_default=func.now(),
onupdate=func.now())
deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')
+ def set_pipeline(self, proto: composer_pb2.Pipeline):
+ self.pipeline = to_json(proto)
+
+ def get_pipeline(self) -> composer_pb2.Pipeline:
+ return parse_from_json(self.pipeline, composer_pb2.Pipeline())
+
+ def set_context(self, proto: composer_pb2.PipelineContextData):
+ self.context = to_json(proto)
+
+ def get_context(self) -> composer_pb2.PipelineContextData:
+ return parse_from_json(self.context, composer_pb2.PipelineContextData())
+
+ def set_output(self, proto: composer_pb2.RunnerOutput):
+ self.output = to_json(proto)
+
+ def get_output(self) -> composer_pb2.RunnerOutput:
+ return parse_from_json(self.output, composer_pb2.RunnerOutput())
+
+ def to_proto(self) -> SchedulerRunnerPb:
+ return SchedulerRunnerPb(id=self.id,
+ item_id=self.item_id,
+ status=RunnerStatus(self.status).name,
+ start_at=to_timestamp(self.start_at) if self.start_at else None,
+ end_at=to_timestamp(self.end_at) if self.end_at else None,
+ pipeline=self.get_pipeline(),
+ output=self.get_output(),
+ context=self.get_context(),
+ created_at=to_timestamp(self.created_at) if self.created_at else None,
+ updated_at=to_timestamp(self.updated_at) if self.updated_at else None,
+ deleted_at=to_timestamp(self.deleted_at) if self.deleted_at else None)
+
class OptimisticLock(db.Model):
__tablename__ = 'optimistic_lock_v2'
@@ -192,15 +219,10 @@ class OptimisticLock(db.Model):
UniqueConstraint('name', name='uniq_name'),
default_table_args('optimistic lock'),
)
- id = db.Column(db.Integer,
- comment='id',
- primary_key=True,
- autoincrement=True)
+ id = db.Column(db.Integer, comment='id', primary_key=True, autoincrement=True)
name = db.Column(db.String(255), comment='lock name', nullable=False)
version = db.Column(db.BIGINT, comment='lock version', nullable=False)
- created_at = db.Column(db.DateTime(timezone=True),
- comment='created at',
- server_default=func.now())
+ created_at = db.Column(db.DateTime(timezone=True), comment='created at', server_default=func.now())
updated_at = db.Column(db.DateTime(timezone=True),
comment='updated at',
server_default=func.now(),
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/models_test.py b/web_console_v2/api/fedlearner_webconsole/composer/models_test.py
new file mode 100644
index 000000000..23746ac85
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/models_test.py
@@ -0,0 +1,195 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+from fedlearner_webconsole.composer.models import SchedulerItem, ItemStatus, RunnerStatus, SchedulerRunner
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto import composer_pb2
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.utils.proto import to_json
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class SchedulerItemTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ self.default_pipeline = composer_pb2.Pipeline(version=2,
+ name='test pipeline',
+ queue=[
+ composer_pb2.RunnerInput(runner_type='test type1'),
+ composer_pb2.RunnerInput(runner_type='test type2'),
+ ])
+ scheduler_item = SchedulerItem(id=5,
+ name='test_item_off',
+ pipeline=to_json(self.default_pipeline),
+ status=ItemStatus.OFF.value,
+ cron_config='* * * * * 15',
+ last_run_at=created_at,
+ retry_cnt=0,
+ created_at=created_at,
+ updated_at=created_at)
+ with db.session_scope() as session:
+ session.add(scheduler_item)
+ session.commit()
+
+ def test_need_run_normal_job(self):
+ with db.session_scope() as session:
+ item = SchedulerItem(name='test normal item')
+ session.commit()
+ # Never run
+ self.assertTrue(item.need_run())
+ item.last_run_at = now()
+ session.commit()
+ self.assertFalse(item.need_run())
+
+ @patch('fedlearner_webconsole.composer.models.now')
+ def test_need_run_cron_job(self, mock_now):
+ with db.session_scope() as session:
+ item = SchedulerItem(
+ name='test cron item',
+ # Runs every 30 minutes
+ cron_config='*/30 * * * *',
+ created_at=datetime(2021, 9, 1, 10, 10))
+ session.commit()
+ # Never run
+ mock_now.return_value = datetime(2021, 9, 1, 10, 20, tzinfo=timezone.utc)
+ self.assertFalse(item.need_run())
+ mock_now.return_value = datetime(2021, 9, 1, 10, 50, tzinfo=timezone.utc)
+ self.assertTrue(item.need_run())
+ # Has been run
+ item.last_run_at = datetime(2021, 9, 1, 10, 10)
+ session.commit()
+ mock_now.return_value = datetime(2021, 9, 1, 10, 11, tzinfo=timezone.utc)
+ self.assertFalse(item.need_run())
+ mock_now.return_value = datetime(2021, 9, 1, 10, 50, tzinfo=timezone.utc)
+ self.assertTrue(item.need_run())
+
+ def test_get_pipeline(self):
+ with db.session_scope() as session:
+ scheduler_item = session.query(SchedulerItem).first()
+ self.assertEqual(self.default_pipeline, scheduler_item.get_pipeline())
+
+ def test_set_pipeline(self):
+ with db.session_scope() as session:
+ scheduler_item = session.query(SchedulerItem).first()
+ pipeline = composer_pb2.Pipeline(name='test1')
+ scheduler_item.set_pipeline(pipeline)
+ self.assertEqual(pipeline, scheduler_item.get_pipeline())
+
+ def test_to_proto(self):
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ with db.session_scope() as session:
+ scheduler_item = session.query(SchedulerItem).first()
+ self.assertEqual(
+ scheduler_item.to_proto(),
+ composer_pb2.SchedulerItemPb(id=5,
+ name='test_item_off',
+ pipeline=self.default_pipeline,
+ status=ItemStatus.OFF.name,
+ cron_config='* * * * * 15',
+ last_run_at=int(created_at.timestamp()),
+ retry_cnt=0,
+ created_at=int(created_at.timestamp()),
+ updated_at=int(created_at.timestamp())))
+
+
+class SchedulerRunnerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.default_context = composer_pb2.PipelineContextData(current_runner=0)
+ self.default_pipeline = composer_pb2.Pipeline(version=2,
+ name='test pipeline',
+ queue=[
+ composer_pb2.RunnerInput(runner_type='test type1'),
+ composer_pb2.RunnerInput(runner_type='test type2'),
+ ])
+ self.default_output = composer_pb2.RunnerOutput(error_message='error1')
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ scheduler_runner = SchedulerRunner(
+ id=5,
+ item_id=1,
+ status=RunnerStatus.INIT.value,
+ start_at=created_at,
+ pipeline=to_json(self.default_pipeline),
+ output=to_json(self.default_output),
+ context=to_json(self.default_context),
+ created_at=created_at,
+ updated_at=created_at,
+ )
+ with db.session_scope() as session:
+ session.add(scheduler_runner)
+ session.commit()
+
+ def test_get_pipeline(self):
+ with db.session_scope() as session:
+ scheduler_runner = session.query(SchedulerRunner).first()
+ self.assertEqual(self.default_pipeline, scheduler_runner.get_pipeline())
+
+ def test_set_pipeline(self):
+ with db.session_scope() as session:
+ scheduler_runner = session.query(SchedulerRunner).first()
+ pipeline = composer_pb2.Pipeline(name='test1')
+ scheduler_runner.set_pipeline(pipeline)
+ self.assertEqual(pipeline, scheduler_runner.get_pipeline())
+
+ def test_get_context(self):
+ with db.session_scope() as session:
+ scheduler_runner = session.query(SchedulerRunner).first()
+ self.assertEqual(self.default_context, scheduler_runner.get_context())
+
+ def test_set_context(self):
+ with db.session_scope() as session:
+ scheduler_runner = session.query(SchedulerRunner).first()
+ context = composer_pb2.PipelineContextData(current_runner=1)
+ scheduler_runner.set_context(context)
+ self.assertEqual(context, scheduler_runner.get_context())
+
+ def test_get_output(self):
+ with db.session_scope() as session:
+ scheduler_runner = session.query(SchedulerRunner).first()
+ self.assertEqual(self.default_output, scheduler_runner.get_output())
+
+ def test_set_output(self):
+ with db.session_scope() as session:
+ scheduler_runner = session.query(SchedulerRunner).first()
+ output = composer_pb2.RunnerOutput(error_message='error2')
+ scheduler_runner.set_output(output)
+ self.assertEqual(output, scheduler_runner.get_output())
+
+ def test_to_proto(self):
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ with db.session_scope() as session:
+ scheduler_runner = session.query(SchedulerRunner).first()
+ self.assertEqual(
+ scheduler_runner.to_proto(),
+ composer_pb2.SchedulerRunnerPb(id=5,
+ item_id=1,
+ status=RunnerStatus.INIT.name,
+ start_at=int(created_at.timestamp()),
+ pipeline=self.default_pipeline,
+ output=self.default_output,
+ context=self.default_context,
+ created_at=int(created_at.timestamp()),
+ updated_at=int(created_at.timestamp())))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/op_locker.py b/web_console_v2/api/fedlearner_webconsole/composer/op_locker.py
index 8b0bdd404..1c4c09e77 100644
--- a/web_console_v2/api/fedlearner_webconsole/composer/op_locker.py
+++ b/web_console_v2/api/fedlearner_webconsole/composer/op_locker.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,11 +23,13 @@
class OpLocker(object):
+
def __init__(self, name: str, db_engine: Engine):
"""Optimistic Lock
Args:
- name: lock name should be unique in same thread
+ name (str): lock name, should be unique within the same thread
+ db_engine (Engine): db engine
"""
self._name = name
self._version = 0
@@ -45,14 +47,12 @@ def version(self) -> int:
def try_lock(self) -> 'OpLocker':
with get_session(self.db_engine) as session:
try:
- lock = session.query(OptimisticLock).filter_by(
- name=self._name).first()
+ lock = session.query(OptimisticLock).filter_by(name=self._name).first()
if lock:
self._has_lock = True
self._version = lock.version
return self
- new_lock = OptimisticLock(name=self._name,
- version=self._version)
+ new_lock = OptimisticLock(name=self._name, version=self._version)
session.add(new_lock)
session.commit()
self._has_lock = True
@@ -67,16 +67,13 @@ def is_latest_version(self) -> bool:
with get_session(self.db_engine) as session:
try:
- new_lock = session.query(OptimisticLock).filter_by(
- name=self._name).first()
+ new_lock = session.query(OptimisticLock).filter_by(name=self._name).first()
if not new_lock:
return False
- logging.info(f'[op_locker] version, current: {self._version}, '
- f'new: {new_lock.version}')
+ logging.info(f'[op_locker] version, current: {self._version}, new: {new_lock.version}')
return self._version == new_lock.version
except Exception as e: # pylint: disable=broad-except
- logging.error(
- f'failed to check lock is conflict, exception: {e}')
+ logging.error(f'failed to check lock is conflict, exception: {e}')
return False
def update_version(self) -> bool:
@@ -86,8 +83,7 @@ def update_version(self) -> bool:
with get_session(self.db_engine) as session:
try:
- lock = session.query(OptimisticLock).filter_by(
- name=self._name).first()
+ lock = session.query(OptimisticLock).filter_by(name=self._name).first()
lock.version = self._version + 1
session.commit()
return True
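The flow above (try_lock, is_latest_version, update_version) is the check-then-commit pattern the composer relies on; below is a minimal sketch of that pattern, assuming only the modules introduced in this change (`stage_changes` is a hypothetical callable that mutates ORM objects inside the session):

```python
# Sketch of the optimistic-lock pattern guarding a commit (see the real retry
# loop in pipeline.py below). `stage_changes` is hypothetical.
from fedlearner_webconsole.composer.op_locker import OpLocker
from fedlearner_webconsole.db import db


def guarded_commit(lock_name: str, stage_changes) -> bool:
    with db.session_scope() as session:
        stage_changes(session)
        lock = OpLocker(lock_name, db.engine).try_lock()
        # Commit only if nobody bumped the lock version since try_lock().
        if lock.is_latest_version() and lock.update_version():
            session.commit()
            return True
        session.rollback()
        return False
```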
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/op_locker_test.py b/web_console_v2/api/fedlearner_webconsole/composer/op_locker_test.py
new file mode 100644
index 000000000..6680e8d45
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/op_locker_test.py
@@ -0,0 +1,47 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import logging
+import sys
+import unittest
+
+from fedlearner_webconsole.composer.models import OptimisticLock
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.op_locker import OpLocker
+from testing.common import BaseTestCase
+
+
+class OpLockTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ STORAGE_ROOT = '/tmp'
+ START_SCHEDULER = False
+
+ def test_lock(self):
+ lock = OpLocker('test', db.engine).try_lock()
+ self.assertEqual(True, lock.is_latest_version(), 'should be latest version')
+
+ # update database version
+ with db.session_scope() as session:
+ new_lock = session.query(OptimisticLock).filter_by(name=lock.name).first()
+ new_lock.version = new_lock.version + 1
+ session.commit()
+ self.assertEqual(False, lock.is_latest_version(), 'should not be latest version')
+
+
+if __name__ == '__main__':
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/pipeline.py b/web_console_v2/api/fedlearner_webconsole/composer/pipeline.py
new file mode 100644
index 000000000..a11cbe7cc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/pipeline.py
@@ -0,0 +1,114 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from concurrent.futures import Future
+from typing import Dict
+
+from sqlalchemy import func
+from sqlalchemy.engine import Engine
+
+from fedlearner_webconsole.composer.context import PipelineContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus, SchedulerRunner
+from fedlearner_webconsole.composer.op_locker import OpLocker
+from fedlearner_webconsole.composer.thread_reaper import ThreadReaper
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto import composer_pb2
+
+
+class PipelineExecutor(object):
+
+ def __init__(self, thread_reaper: ThreadReaper, db_engine: Engine, runner_fns: Dict[str, IRunnerV2]):
+ self.thread_reaper = thread_reaper
+ self.db_engine = db_engine
+ self._runner_fns = runner_fns
+ self._running_workers = {}
+
+ def run(self, runner_id: int) -> bool:
+ """Starts runner by submitting it to the thread reaper."""
+ with db.session_scope() as session:
+ runner: SchedulerRunner = session.query(SchedulerRunner).get(runner_id)
+ pipeline: composer_pb2.Pipeline = runner.get_pipeline()
+ if runner.status not in [RunnerStatus.RUNNING.value]:
+ return False
+ if self.thread_reaper.is_running(runner_id) or self.thread_reaper.is_full():
+ return False
+ pipeline_context = PipelineContext.build(pipeline=pipeline, data=runner.get_context())
+ current_runner_context = pipeline_context.get_current_runner_context()
+ runner_fn = self._runner_fns[current_runner_context.input.runner_type]()
+ return self.thread_reaper.submit(
+ runner_id=runner_id,
+ fn=runner_fn,
+ context=current_runner_context,
+ done_callback=self._runner_done_callback,
+ )
+
+ def _runner_done_callback(self, runner_id: int, fu: Future):
+ """Callback when one runner finishes.
+
+ The callback only updates the runner status; other runners in the pipeline are
+ triggered by the executor in the next round of checks."""
+ with db.session_scope() as session:
+ runner = session.query(SchedulerRunner).get(runner_id)
+ pipeline = runner.get_pipeline()
+ if pipeline.version != 2:
+ return
+ pipeline_context = PipelineContext.build(pipeline=pipeline, data=runner.get_context())
+ current_runner_context = pipeline_context.get_current_runner_context()
+ output = None
+ try:
+ status, output = fu.result()
+ # Defensively confirming the status
+ if status == RunnerStatus.RUNNING:
+ return
+ pipeline_context.data.outputs[current_runner_context.index].MergeFrom(output)
+ if status == RunnerStatus.DONE:
+ if current_runner_context.index == len(pipeline.queue) - 1:
+ # the whole pipeline is done
+ runner.status = RunnerStatus.DONE.value
+ runner.end_at = func.now()
+ else:
+ # mark to run next
+ pipeline_context.run_next()
+ elif status == RunnerStatus.FAILED:
+ runner.status = RunnerStatus.FAILED.value
+ runner.end_at = func.now()
+ except Exception as e: # pylint: disable=broad-except
+ logging.exception(f'[PipelineExecutor] failed to run {runner.id}')
+ runner.status = RunnerStatus.FAILED.value
+ runner.end_at = func.now()
+ pipeline_context.data.outputs[current_runner_context.index].error_message = str(e)
+ runner.set_context(pipeline_context.data)
+
+ logging.info(f'[pipeline-executor] update runner, status: {runner.status}, '
+ f'pipeline: {runner.pipeline}, '
+ f'output: {output}, context: {runner.context}')
+ # Retry 3 times
+ for _ in range(3):
+ try:
+ lock_name = f'update_running_runner_{runner_id}_lock'
+ lock = OpLocker(lock_name, self.db_engine).try_lock()
+ if lock.is_latest_version():
+ if lock.update_version():
+ session.commit()
+ break
+ else:
+ logging.error(f'[composer] {lock_name} is outdated, ignoring updates to database')
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'[composer] failed to update running runner status, exception: {e}')
+ else:
+ # Failed 3 times
+ session.rollback()
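As the tests below illustrate, the executor is wired with a ThreadReaper and a runner-type registry. A minimal wiring sketch under that assumption follows; `NoopRunner` is a hypothetical IRunnerV2 implementation, not part of this change, and the real registry comes from global_runner_fn():

```python
# Minimal wiring sketch for PipelineExecutor, mirroring the test setup below.
from typing import Tuple

from fedlearner_webconsole.composer.context import RunnerContext
from fedlearner_webconsole.composer.interface import IRunnerV2, ItemType
from fedlearner_webconsole.composer.models import RunnerStatus
from fedlearner_webconsole.composer.pipeline import PipelineExecutor
from fedlearner_webconsole.composer.thread_reaper import ThreadReaper
from fedlearner_webconsole.db import db
from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput


class NoopRunner(IRunnerV2):
    # Hypothetical runner: reports success immediately with an empty output.
    def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
        return RunnerStatus.DONE, RunnerOutput()


reaper = ThreadReaper(worker_num=1)
executor = PipelineExecutor(reaper, db.engine, {ItemType.TASK.value: NoopRunner})
# The composer calls this for each SchedulerRunner in RUNNING state; the done
# callback then advances the pipeline or marks the runner DONE/FAILED.
executor.run(runner_id=1)
```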
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/pipeline_test.py b/web_console_v2/api/fedlearner_webconsole/composer/pipeline_test.py
new file mode 100644
index 000000000..33c75cde6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/pipeline_test.py
@@ -0,0 +1,146 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime
+from unittest.mock import patch
+
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import SchedulerRunner, RunnerStatus
+from fedlearner_webconsole.composer.pipeline import PipelineExecutor
+from fedlearner_webconsole.composer.thread_reaper import ThreadReaper
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto import composer_pb2
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, PipelineContextData, RunnerOutput
+from testing.composer.common import TestRunner
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.fake_time_patcher import FakeTimePatcher
+
+_RUNNER_FNS = {
+ ItemType.TASK.value: TestRunner,
+}
+
+
+class PipelineExecutorTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.thread_reaper = ThreadReaper(worker_num=1)
+ self.executor = PipelineExecutor(self.thread_reaper, db.engine, _RUNNER_FNS)
+
+ self.time_patcher = FakeTimePatcher()
+ self.time_patcher.start(datetime(2012, 1, 14, 12, 0, 5))
+
+ def tearDown(self):
+ self.time_patcher.stop()
+ self.thread_reaper.stop(wait=True)
+ super().tearDown()
+
+ def test_run_completed(self):
+ runner = SchedulerRunner(item_id=123, status=RunnerStatus.RUNNING.value)
+ runner.set_pipeline(
+ composer_pb2.Pipeline(version=2, name='test pipeline',
+ queue=[RunnerInput(runner_type=ItemType.TASK.value)]))
+ with db.session_scope() as session:
+ session.add(runner)
+ session.commit()
+ self.executor.run(runner.id)
+ self.time_patcher.interrupt(60)
+ with db.session_scope() as session:
+ runner = session.query(SchedulerRunner).get(runner.id)
+ self.assertEqual(runner.status, RunnerStatus.DONE.value)
+ self.assertEqual(runner.get_context(), PipelineContextData(current_runner=0, outputs={0: RunnerOutput()}))
+
+ def test_run_failed(self):
+ runner = SchedulerRunner(item_id=123, status=RunnerStatus.RUNNING.value)
+ runner.set_pipeline(
+ composer_pb2.Pipeline(
+ version=2,
+ name='test failed pipeline',
+ queue=[
+ RunnerInput(runner_type=ItemType.TASK.value),
+ RunnerInput(runner_type=ItemType.TASK.value),
+ RunnerInput(runner_type=ItemType.TASK.value),
+ # Failed one
+ RunnerInput(runner_type=ItemType.TASK.value),
+ RunnerInput(runner_type=ItemType.TASK.value),
+ ]))
+ runner.set_context(composer_pb2.PipelineContextData(current_runner=3))
+ with db.session_scope() as session:
+ session.add(runner)
+ session.commit()
+ self.executor.run(runner.id)
+ self.time_patcher.interrupt(60)
+ with db.session_scope() as session:
+ runner = session.query(SchedulerRunner).get(runner.id)
+ self.assertEqual(runner.status, RunnerStatus.FAILED.value)
+ self.assertEqual(
+ runner.get_context(),
+ PipelineContextData(current_runner=3, outputs={3: RunnerOutput(error_message='index is 3')}))
+
+ @patch('testing.composer.common.TestRunner.run')
+ def test_run_exception(self, mock_run):
+
+ def fake_run(*args, **kwargs):
+ raise RuntimeError('fake exception')
+
+ mock_run.side_effect = fake_run
+
+ runner = SchedulerRunner(item_id=666, status=RunnerStatus.RUNNING.value)
+ runner.set_pipeline(
+ composer_pb2.Pipeline(
+ version=2,
+ name='test failed pipeline',
+ queue=[
+ # Exception one
+ RunnerInput(runner_type=ItemType.TASK.value),
+ ]))
+ runner.set_context(composer_pb2.PipelineContextData(current_runner=0))
+ with db.session_scope() as session:
+ session.add(runner)
+ session.commit()
+ self.executor.run(runner.id)
+ self.time_patcher.interrupt(60)
+ with db.session_scope() as session:
+ runner = session.query(SchedulerRunner).get(runner.id)
+ self.assertEqual(runner.status, RunnerStatus.FAILED.value)
+ self.assertEqual(
+ runner.get_context(),
+ PipelineContextData(current_runner=0, outputs={0: RunnerOutput(error_message='fake exception')}))
+
+ def test_run_second_runner(self):
+ runner = SchedulerRunner(item_id=123, status=RunnerStatus.RUNNING.value)
+ runner.set_pipeline(
+ composer_pb2.Pipeline(version=2,
+ name='test running pipeline',
+ queue=[
+ RunnerInput(runner_type=ItemType.TASK.value),
+ RunnerInput(runner_type=ItemType.TASK.value),
+ RunnerInput(runner_type=ItemType.TASK.value),
+ ]))
+ runner.set_context(composer_pb2.PipelineContextData(current_runner=1))
+ with db.session_scope() as session:
+ session.add(runner)
+ session.commit()
+ self.executor.run(runner.id)
+ self.time_patcher.interrupt(60)
+ with db.session_scope() as session:
+ runner = session.query(SchedulerRunner).get(runner.id)
+ self.assertEqual(runner.status, RunnerStatus.RUNNING.value)
+ self.assertEqual(runner.get_context(), PipelineContextData(current_runner=2, outputs={1: RunnerOutput()}))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/runner.py b/web_console_v2/api/fedlearner_webconsole/composer/runner.py
index 46d87ec20..2807e7ab3 100644
--- a/web_console_v2/api/fedlearner_webconsole/composer/runner.py
+++ b/web_console_v2/api/fedlearner_webconsole/composer/runner.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,78 +13,45 @@
# limitations under the License.
# coding: utf-8
-import datetime
import logging
-import random
import sys
-import time
-from typing import Tuple
-from fedlearner_webconsole.composer.interface import IItem, IRunner, ItemType
-from fedlearner_webconsole.composer.models import Context, RunnerStatus, \
- SchedulerRunner
-from fedlearner_webconsole.dataset.data_pipeline import DataPipelineRunner
-from fedlearner_webconsole.db import get_session
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.dataset.batch_stats import BatchStatsRunner
+from fedlearner_webconsole.dataset.scheduler.dataset_long_period_scheduler import DatasetLongPeriodScheduler
+from fedlearner_webconsole.dataset.scheduler.dataset_short_period_scheduler import DatasetShortPeriodScheduler
+from fedlearner_webconsole.job.scheduler import JobScheduler
+from fedlearner_webconsole.project.project_scheduler import ScheduleProjectRunner
+from fedlearner_webconsole.serving.runners import ModelSignatureParser, QueryParticipantStatusRunner, UpdateModelRunner
from fedlearner_webconsole.workflow.cronjob import WorkflowCronJob
-
-
-class MemoryItem(IItem):
- def __init__(self, task_id: int):
- self.id = task_id
-
- def type(self) -> ItemType:
- return ItemType.MEMORY
-
- def get_id(self) -> int:
- return self.id
-
-
-class MemoryRunner(IRunner):
- def __init__(self, task_id: int):
- """Runner Example
-
- Args:
- task_id: required
- """
- self.task_id = task_id
- self._start_at = None
-
- def start(self, context: Context):
- # NOTE: in this method, context.data can only be getter,
- # don't modify context
- data = context.data.get(str(self.task_id), 'EMPTY')
- logging.info(f'[memory_runner] {self.task_id} started, data: {data}')
- self._start_at = datetime.datetime.utcnow()
-
- def result(self, context: Context) -> Tuple[RunnerStatus, dict]:
- time.sleep(2)
- now = datetime.datetime.utcnow()
- timeout = random.randint(0, 10)
- # mock timeout
- if self._start_at is not None and self._start_at + datetime.timedelta(
- seconds=timeout) < now:
- # kill runner
- logging.info(f'[memory_runner] {self.task_id} is timeout, '
- f'start at: {self._start_at}')
- return RunnerStatus.FAILED, {}
-
- # use `get_session` to query database
- with get_session(context.db_engine) as session:
- count = session.query(SchedulerRunner).count()
- # write data to context
- context.set_data(f'is_done_{self.task_id}', {
- 'status': 'OK',
- 'count': count
- })
- return RunnerStatus.DONE, {}
+from fedlearner_webconsole.workflow.workflow_scheduler import ScheduleWorkflowRunner
+from fedlearner_webconsole.cleanup.cleaner_cronjob import CleanupCronJob
+from fedlearner_webconsole.mmgr.cronjob import ModelTrainingCronJob
+from fedlearner_webconsole.mmgr.scheduler import ModelJobSchedulerRunner, ModelJobGroupSchedulerRunner, \
+ ModelJobGroupLongPeriodScheduler
+from fedlearner_webconsole.tee.runners import TeeCreateRunner, TeeResourceCheckRunner
def global_runner_fn():
# register runner_fn
runner_fn = {
- ItemType.MEMORY.value: MemoryRunner,
ItemType.WORKFLOW_CRON_JOB.value: WorkflowCronJob,
- ItemType.DATA_PIPELINE.value: DataPipelineRunner,
+ ItemType.BATCH_STATS.value: BatchStatsRunner,
+ ItemType.SERVING_SERVICE_PARSE_SIGNATURE.value: ModelSignatureParser,
+ ItemType.SERVING_SERVICE_QUERY_PARTICIPANT_STATUS.value: QueryParticipantStatusRunner,
+ ItemType.SERVING_SERVICE_UPDATE_MODEL.value: UpdateModelRunner,
+ ItemType.SCHEDULE_WORKFLOW.value: ScheduleWorkflowRunner,
+ ItemType.SCHEDULE_JOB.value: JobScheduler,
+ ItemType.CLEANUP_CRON_JOB.value: CleanupCronJob,
+ ItemType.MODEL_TRAINING_CRON_JOB.value: ModelTrainingCronJob,
+ ItemType.TEE_CREATE_RUNNER.value: TeeCreateRunner,
+ ItemType.TEE_RESOURCE_CHECK_RUNNER.value: TeeResourceCheckRunner,
+ ItemType.SCHEDULE_PROJECT.value: ScheduleProjectRunner,
+ ItemType.DATASET_LONG_PERIOD_SCHEDULER.value: DatasetLongPeriodScheduler,
+ ItemType.DATASET_SHORT_PERIOD_SCHEDULER.value: DatasetShortPeriodScheduler,
+ ItemType.SCHEDULE_MODEL_JOB.value: ModelJobSchedulerRunner,
+ ItemType.SCHEDULE_MODEL_JOB_GROUP.value: ModelJobGroupSchedulerRunner,
+ ItemType.SCHEDULE_LONG_PERIOD_MODEL_JOB_GROUP.value: ModelJobGroupLongPeriodScheduler,
}
for item in ItemType:
if item.value in runner_fn or item == ItemType.TASK:
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/runner_cache.py b/web_console_v2/api/fedlearner_webconsole/composer/runner_cache.py
deleted file mode 100644
index bd93e8bac..000000000
--- a/web_console_v2/api/fedlearner_webconsole/composer/runner_cache.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import logging
-import threading
-
-from fedlearner_webconsole.composer.interface import IRunner
-
-
-class RunnerCache(object):
- def __init__(self, runner_fn: dict):
- self._lock = threading.Lock()
- self._cache = {} # id:name => obj
- self.runner_fn = runner_fn
-
- def find_runner(self, runner_id: int, runner_name: str) -> IRunner:
- """Find runner
-
- Args:
- runner_id: id in runner table
- runner_name: {item_type}_{item_id}
- """
- with self._lock:
- key = self.cache_key(runner_id, runner_name)
- obj = self._cache.get(key, None)
- if obj:
- return obj
- item_type, item_id = runner_name.rsplit('_', 1)
- if item_type not in self.runner_fn:
- logging.error(
- f'failed to find item_type {item_type} in runner_fn, '
- f'please register it in global_runner_fn')
- raise ValueError(f'unknown item_type {item_type} in runner')
- obj = self.runner_fn[item_type](int(item_id))
- self._cache[key] = obj
- return obj
-
- def del_runner(self, runner_id: int, runner_name: str):
- """Delete runner
-
- Args:
- runner_id: id in runner table
- runner_name: {item_type}_{item_id}
- """
- with self._lock:
- key = self.cache_key(runner_id, runner_name)
- del self._cache[key]
-
- @staticmethod
- def cache_key(runner_id: int, runner_name: str) -> str:
- return f'{runner_id}:{runner_name}'
-
- @property
- def data(self) -> dict:
- with self._lock:
- return self._cache
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/strategy.py b/web_console_v2/api/fedlearner_webconsole/composer/strategy.py
new file mode 100644
index 000000000..87173a36d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/strategy.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC, abstractmethod
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.composer.models import SchedulerItem, SchedulerRunner, RunnerStatus
+
+
+class RunnerStrategy(ABC):
+
+ def __init__(self, session: Session):
+ self.session = session
+
+ @abstractmethod
+ def should_run(self, item: SchedulerItem) -> bool:
+ """Checks if the scheduler item should run or not."""
+
+
+class SingletonStrategy(RunnerStrategy):
+ """A strategy to make sure there is only one runner instance for the scheduler item.
+
+ 1. For a normal scheduler item, it behaves the same as the default strategy.
+ 2. For cron jobs, at most one ongoing (INIT or RUNNING) runner is allowed per item.
+ """
+
+ def should_run(self, item: SchedulerItem) -> bool:
+ if not item.need_run():
+ return False
+
+ if item.cron_config:
+ ongoing_count = self.session.query(func.count(SchedulerRunner.id)).filter(
+ SchedulerRunner.item_id == item.id,
+ SchedulerRunner.status.in_([RunnerStatus.INIT.value, RunnerStatus.RUNNING.value])).scalar()
+ if ongoing_count > 0:
+ return False
+
+ return True
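A minimal sketch of how a scheduling loop might consult this strategy before spawning runners, based on the usage in the test below; the surrounding loop is illustrative, not the composer's actual code:

```python
# Illustrative loop; only SingletonStrategy.should_run comes from this change.
from fedlearner_webconsole.composer.models import RunnerStatus, SchedulerItem, SchedulerRunner
from fedlearner_webconsole.composer.strategy import SingletonStrategy
from fedlearner_webconsole.db import db

with db.session_scope() as session:
    strategy = SingletonStrategy(session)
    for item in session.query(SchedulerItem).all():
        if strategy.should_run(item):
            # For cron items this is reached only when no INIT/RUNNING runner exists.
            session.add(SchedulerRunner(item_id=item.id, status=RunnerStatus.INIT.value))
    session.commit()
```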
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/strategy_test.py b/web_console_v2/api/fedlearner_webconsole/composer/strategy_test.py
new file mode 100644
index 000000000..03f65a786
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/strategy_test.py
@@ -0,0 +1,78 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime
+from unittest.mock import patch, Mock
+
+from fedlearner_webconsole.composer.models import SchedulerItem, ItemStatus, RunnerStatus, SchedulerRunner
+from fedlearner_webconsole.composer.strategy import SingletonStrategy
+from fedlearner_webconsole.db import db
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class SingletonStrategyTest(NoWebServerTestCase):
+
+ @patch.object(SchedulerItem, 'need_run')
+ def test_should_run_normal_item(self, mock_need_run: Mock):
+ with db.session_scope() as session:
+ item = SchedulerItem(id=123,
+ name='test normal item',
+ status=ItemStatus.ON.value,
+ created_at=datetime(2021, 9, 1, 10, 10))
+ strategy = SingletonStrategy(session)
+ mock_need_run.return_value = True
+ self.assertTrue(strategy.should_run(item))
+ # No need to run
+ mock_need_run.return_value = False
+ self.assertFalse(strategy.should_run(item))
+
+ @patch.object(SchedulerItem, 'need_run')
+ def test_should_run_cron_item(self, mock_need_run: Mock):
+ item_id = 123123
+ runner_id = 7644
+ with db.session_scope() as session:
+ item = SchedulerItem(
+ id=item_id,
+ name='test cron item',
+ # Runs every 30 minutes
+ cron_config='*/30 * * * *',
+ created_at=datetime(2022, 1, 1, 10, 0))
+ session.add(item)
+ session.commit()
+ with db.session_scope() as session:
+ strategy = SingletonStrategy(session)
+ mock_need_run.return_value = False
+ self.assertFalse(strategy.should_run(item))
+ mock_need_run.return_value = True
+ self.assertTrue(strategy.should_run(item))
+ runner = SchedulerRunner(id=runner_id, item_id=item_id, status=RunnerStatus.RUNNING.value)
+ session.add(runner)
+ session.commit()
+ with db.session_scope() as session:
+ # Already one running runner, so no new one will be generated.
+ item = session.query(SchedulerItem).get(item_id)
+ self.assertFalse(strategy.should_run(item))
+ runner = session.query(SchedulerRunner).get(runner_id)
+ runner.status = RunnerStatus.DONE.value
+ session.commit()
+ with db.session_scope() as session:
+ # All runners are done, so a new one will be there.
+ item = session.query(SchedulerItem).get(item_id)
+ self.assertTrue(strategy.should_run(item))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/thread_reaper.py b/web_console_v2/api/fedlearner_webconsole/composer/thread_reaper.py
index e63a1ae69..b97510228 100644
--- a/web_console_v2/api/fedlearner_webconsole/composer/thread_reaper.py
+++ b/web_console_v2/api/fedlearner_webconsole/composer/thread_reaper.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@
import threading
from concurrent.futures import Future
from concurrent.futures.thread import ThreadPoolExecutor
+from typing import Callable, Optional
-from fedlearner_webconsole.composer.models import Context
-from fedlearner_webconsole.composer.interface import IRunner
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
class ThreadReaper(object):
+
def __init__(self, worker_num: int):
"""ThreadPool with battery
@@ -33,26 +35,51 @@ def __init__(self, worker_num: int):
self.lock = threading.RLock()
self.worker_num = worker_num
self.running_worker_num = 0
+ self._running_workers = {}
self._thread_pool = ThreadPoolExecutor(max_workers=worker_num)
- def enqueue(self, name: str, fn: IRunner, context: Context) -> bool:
+ def is_running(self, runner_id: int) -> bool:
+ with self.lock:
+ return runner_id in self._running_workers
+
+ def submit(self,
+ runner_id: int,
+ fn: IRunnerV2,
+ context: RunnerContext,
+ done_callback: Optional[Callable[[int, Future], None]] = None) -> bool:
if self.is_full():
return False
- logging.info(f'[thread_reaper] enqueue {name}')
+
+ def full_done_callback(fu: Future):
+ # The order matters, as we need to update the status last.
+ if done_callback:
+ done_callback(runner_id, fu)
+ self._track_status(runner_id, fu)
+
+ logging.info(f'[thread_reaper] enqueue {runner_id}')
with self.lock:
+ if runner_id in self._running_workers:
+ logging.warning(f'[thread_reaper] {runner_id} already enqueued')
+ return False
self.running_worker_num += 1
- fu = self._thread_pool.submit(fn.start, context=context)
- fu.add_done_callback(self._track_status)
+ self._running_workers[runner_id] = fn
+ fu = self._thread_pool.submit(fn.run, context=context)
+ fu.add_done_callback(full_done_callback)
return True
- def _track_status(self, fu: Future):
+ def _track_status(self, runner_id: int, fu: Future):
with self.lock:
self.running_worker_num -= 1
- logging.info(f'this job is done, result: {fu.result()}')
+ # Safely removing
+ self._running_workers.pop(runner_id, None)
+ try:
+ logging.info(f'------Job {runner_id} is done------')
+ logging.info(f'result: {fu.result()}')
+ except Exception as e: # pylint: disable=broad-except
+ logging.info(f'error: {str(e)}')
if self.running_worker_num < 0:
- logging.error(
- f'[thread_reaper] something wrong, should be non-negative, '
- f'val: f{self.running_worker_num}')
+ logging.error(f'[thread_reaper] something wrong, should be non-negative, '
+ f'val: {self.running_worker_num}')
def is_full(self) -> bool:
with self.lock:
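The new submit/is_running API replaces enqueue; a short sketch of driving it follows, using the TestRunner helper from the tests below. The done callback receives the runner id and the Future, matching submit()'s signature:

```python
# Sketch of the ThreadReaper API introduced above.
from concurrent.futures import Future

from fedlearner_webconsole.composer.context import RunnerContext
from fedlearner_webconsole.composer.thread_reaper import ThreadReaper
from fedlearner_webconsole.proto.composer_pb2 import RunnerInput
from testing.composer.common import TestRunner


def on_done(runner_id: int, fu: Future):
    # fu.result() re-raises any exception thrown inside the runner's run().
    print(f'runner {runner_id} finished with {fu.result()}')


reaper = ThreadReaper(worker_num=2)
reaper.submit(runner_id=123, fn=TestRunner(), context=RunnerContext(0, RunnerInput()),
              done_callback=on_done)
assert reaper.is_running(123)
reaper.stop(wait=True)
```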
diff --git a/web_console_v2/api/fedlearner_webconsole/composer/thread_reaper_test.py b/web_console_v2/api/fedlearner_webconsole/composer/thread_reaper_test.py
new file mode 100644
index 000000000..7d2126a7c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/composer/thread_reaper_test.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import logging
+from concurrent.futures import Future
+
+import sys
+import unittest
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.thread_reaper import ThreadReaper
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput
+from testing.composer.common import TestRunner
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.fake_time_patcher import FakeTimePatcher
+
+
+class ThreadReaperTest(NoWebServerTestCase):
+
+ class Config(NoWebServerTestCase.Config):
+ STORAGE_ROOT = '/tmp'
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ self.fake_time_patcher = FakeTimePatcher()
+ self.fake_time_patcher.start()
+
+ def tearDown(self):
+ self.fake_time_patcher.stop()
+ return super().tearDown()
+
+ def test_submit(self):
+ thread_reaper = ThreadReaper(worker_num=2)
+ runner = TestRunner()
+ submitted = thread_reaper.submit(
+ runner_id=123,
+ fn=runner,
+ context=RunnerContext(0, RunnerInput()),
+ )
+ self.assertTrue(submitted)
+ self.assertTrue(thread_reaper.is_running(123))
+ self.assertFalse(thread_reaper.is_full())
+ # Submit again
+ submitted = thread_reaper.submit(
+ runner_id=123,
+ fn=runner,
+ context=RunnerContext(0, RunnerInput()),
+ )
+ self.assertFalse(submitted)
+ self.assertFalse(thread_reaper.is_full())
+ submitted = thread_reaper.submit(
+ runner_id=3333,
+ fn=runner,
+ context=RunnerContext(1, RunnerInput()),
+ )
+ self.assertTrue(submitted)
+ self.assertTrue(thread_reaper.is_full())
+ self.fake_time_patcher.interrupt(5)
+ self.assertFalse(thread_reaper.is_running(123))
+ self.assertFalse(thread_reaper.is_full())
+ thread_reaper.stop(wait=True)
+
+ def test_submit_with_exception(self):
+ thread_reaper = ThreadReaper(worker_num=1)
+ error = None
+ runner_id = None
+
+ def done_callback(rid: int, fu: Future):
+ nonlocal error, runner_id
+ try:
+ runner_id = rid
+ fu.result()
+ except RuntimeError as e:
+ error = str(e)
+
+ runner = TestRunner(with_exception=True)
+ thread_reaper.submit(runner_id=123,
+ fn=runner,
+ context=RunnerContext(1, RunnerInput()),
+ done_callback=done_callback)
+
+ self.fake_time_patcher.interrupt(5)
+ self.assertEqual(runner.context.index, 1)
+ self.assertEqual(runner_id, 123)
+ self.assertEqual(error, 'fake error')
+ thread_reaper.stop(wait=True)
+
+
+if __name__ == '__main__':
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/dataset/BUILD.bazel
new file mode 100644
index 000000000..f3f575f95
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/BUILD.bazel
@@ -0,0 +1,446 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "batch_stats_lib",
+ srcs = [
+ "batch_stats.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":data_path_lib",
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:hooks_lib",
+ ],
+)
+
+py_test(
+ name = "batch_stats_lib_test",
+ size = "small",
+ srcs = [
+ "batch_stats_test.py",
+ ],
+ imports = ["../.."],
+ main = "batch_stats_test.py",
+ deps = [
+ ":batch_stats_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "controllers_lib",
+ srcs = [
+ "controllers.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:resource_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:transaction_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "controllers_lib_test",
+ size = "small",
+ srcs = [
+ "controllers_test.py",
+ ],
+ imports = ["../.."],
+ main = "controllers_test.py",
+ deps = [
+ ":controllers_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_library(
+ name = "delete_dependency_lib",
+ srcs = ["delete_dependency.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ ],
+)
+
+py_test(
+ name = "delete_dependency_lib_test",
+ size = "small",
+ srcs = [
+ "delete_dependency_test.py",
+ ],
+ imports = ["../.."],
+ main = "delete_dependency_test.py",
+ deps = [
+ ":delete_dependency_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "common_lib",
+ srcs = [
+ "consts.py",
+ "dataset_directory.py",
+ "meta_data.py",
+ "util.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "@common_python_dateutil//:pkg",
+ "@common_python_slugify//:pkg",
+ ],
+)
+
+py_test(
+ name = "dataset_directory_test",
+ size = "small",
+ srcs = [
+ "dataset_directory_test.py",
+ ],
+ imports = ["../.."],
+ main = "dataset_directory_test.py",
+ deps = [
+ ":common_lib",
+ ],
+)
+
+py_test(
+ name = "meta_data_test",
+ size = "small",
+ srcs = [
+ "meta_data_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "meta_data_test.py",
+ deps = [
+ ":common_lib",
+ "//web_console_v2/api:envs_lib",
+ ],
+)
+
+py_test(
+ name = "util_test",
+ size = "small",
+ srcs = [
+ "util_test.py",
+ ],
+ imports = ["../.."],
+ main = "util_test.py",
+ deps = [
+ ":common_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ visibility = ["//visibility:public"],
+ deps = [
+ "common_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "medium",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "metrics_lib",
+ srcs = [
+ "metrics.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ ],
+)
+
+py_test(
+ name = "metrics_lib_test",
+ size = "small",
+ srcs = [
+ "metrics_test.py",
+ ],
+ imports = ["../.."],
+ main = "metrics_test.py",
+ deps = [
+ ":metrics_lib",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = [
+ "auth_service.py",
+ "services.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ ":delete_dependency_lib",
+ ":filter_funcs_lib",
+ ":metrics_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/cleanup:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/cleanup:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "services_lib_test",
+ size = "medium",
+ srcs = [
+ "services_test.py",
+ ],
+ imports = ["../.."],
+ main = "services_test.py",
+ deps = [
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ ],
+)
+
+py_test(
+ name = "auth_service_test",
+ size = "small",
+ srcs = [
+ "auth_service_test.py",
+ ],
+ imports = ["../.."],
+ main = "auth_service_test.py",
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ ":controllers_lib",
+ ":filter_funcs_lib",
+ ":local_controllers_lib",
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:resource_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:sorting_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "large",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_library(
+ name = "local_controllers_lib",
+ srcs = ["local_controllers.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "local_controllers_lib_test",
+ size = "medium",
+ srcs = [
+ "local_controllers_test.py",
+ ],
+ imports = ["../.."],
+ main = "local_controllers_test.py",
+ deps = [
+ ":local_controllers_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "data_path_lib",
+ srcs = ["data_path.py"],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ ],
+)
+
+py_test(
+ name = "data_path_lib_test",
+ size = "small",
+ srcs = [
+ "data_path_test.py",
+ ],
+ imports = ["../.."],
+ main = "data_path_test.py",
+ deps = [
+ ":data_path_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "filter_funcs_lib",
+ srcs = ["filter_funcs.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "filter_funcs_lib_test",
+ size = "small",
+ srcs = ["filter_funcs_test.py"],
+ imports = ["../.."],
+ main = "filter_funcs_test.py",
+ deps = [
+ ":filter_funcs_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/apis.py b/web_console_v2/api/fedlearner_webconsole/dataset/apis.py
index 865f41a26..f195b2f50 100644
--- a/web_console_v2/api/fedlearner_webconsole/dataset/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/apis.py
@@ -1,236 +1,2242 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
-# coding: utf-8
# pylint: disable=raise-missing-from
-import os
+from datetime import timedelta
+import logging
+from typing import Any, Dict, Optional, List
+from urllib.parse import urlparse
-from datetime import datetime, timezone
from http import HTTPStatus
+from flask_restful import Resource, Api
+from webargs.flaskparser import use_kwargs, use_args
+from marshmallow.exceptions import ValidationError
+from marshmallow import post_load, validate, fields
+from marshmallow.schema import Schema
+from google.protobuf.json_format import ParseDict, ParseError
+from envs import Envs
-from flask import current_app, request
-from flask_restful import Resource, Api, reqparse
-from slugify import slugify
-
-from fedlearner_webconsole.dataset.models import (Dataset, DatasetType,
- BatchState, DataBatch)
-from fedlearner_webconsole.dataset.services import DatasetService
-from fedlearner_webconsole.exceptions import (InvalidArgumentException,
- NotFoundException)
-from fedlearner_webconsole.db import db_handler as db
-from fedlearner_webconsole.proto import dataset_pb2
-from fedlearner_webconsole.scheduler.scheduler import scheduler
-from fedlearner_webconsole.utils.decorators import jwt_required
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.dataset.controllers import DatasetJobController
+from fedlearner_webconsole.dataset.job_configer.dataset_job_configer import DatasetJobConfiger
+from fedlearner_webconsole.dataset.job_configer.base_configer import set_variable_value_to_job_config
+from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController
+from fedlearner_webconsole.dataset.models import (DataBatch, DataSource, DataSourceType, Dataset,
+ DatasetJobSchedulerState, ResourceState, DatasetJob, DatasetJobKind,
+ DatasetJobStage, DatasetJobState, ImportType, StoreFormat,
+ DatasetType, DatasetSchemaChecker, DatasetKindV2, DatasetFormat)
+from fedlearner_webconsole.dataset.services import (DatasetJobService, DatasetService, DataSourceService,
+ DatasetJobStageService)
+from fedlearner_webconsole.dataset.util import get_export_dataset_name, add_default_url_scheme, is_streaming_folder, \
+ CronInterval
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.dataset.filter_funcs import dataset_auth_status_filter_op_in, dataset_format_filter_op_in, \
+ dataset_format_filter_op_equal, dataset_publish_frontend_filter_op_equal
+from fedlearner_webconsole.exceptions import InvalidArgumentException, MethodNotAllowedException, NoAccessException, \
+ NotFoundException
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs, TimeRange
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterOp
+from fedlearner_webconsole.proto.review_pb2 import TicketDetails, TicketType
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+from fedlearner_webconsole.rpc.v2.job_service_client import JobServiceClient
+from fedlearner_webconsole.rpc.v2.resource_service_client import ResourceServiceClient
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required, input_validator
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils import filtering, sorting
+from fedlearner_webconsole.utils.flask_utils import FilterExpField, make_flask_response
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.utils.file_tree import FileTreeBuilder
+from fedlearner_webconsole.workflow.models import WorkflowExternalState
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.flag.models import Flag
+
+_DEFAULT_DATA_SOURCE_PREVIEW_FILE_NUM = 3
+
+
+def _path_authority_validator(path: str):
+ """Validate data_source path
+ this func is used to forbid access to the local filesystem
+ 1. if path is not nfs, pass
+ 2. if path is nfs and belongs to STORAGE_ROOT, pass
+ 3. if path is nfs but doesn't belong to STORAGE_ROOT, raise ValidationError
+ """
+ path = path.strip()
+ authority_path = add_default_url_scheme(Envs.STORAGE_ROOT)
+ if not authority_path.endswith('/'):
+ authority_path += '/'
+ validate_path = add_default_url_scheme(path)
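+ # only local file paths are checked against STORAGE_ROOT; other schemes (e.g. hdfs) pass through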
+ if _parse_data_source_url(validate_path).type != DataSourceType.FILE.value:
+ return
+ if not validate_path.startswith(authority_path):
+ raise ValidationError(f'no access to unauthorized path {validate_path}!')
+
+
+def _export_path_validator(path: str):
+ path = path.strip()
+ if len(path) == 0:
+ raise ValidationError('export path is empty!')
+ fm = FileManager()
+ if not fm.can_handle(path):
+ raise ValidationError('cannot handle export path!')
+ if not fm.isdir(path):
+ raise ValidationError('export path does not exist!')
+ _path_authority_validator(path)
+
+
+def _parse_data_source_url(data_source_url: str) -> dataset_pb2.DataSource:
+ data_source_url = data_source_url.strip()
+ data_source_url = add_default_url_scheme(data_source_url)
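+ # normalize the url so the urlparse call below always sees an explicit scheme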
+ url_parser = urlparse(data_source_url)
+ data_source_type = url_parser.scheme
+ # source_type must be in DataSourceType
+ if data_source_type not in [o.value for o in DataSourceType]:
+ raise ValidationError(f'{data_source_type} is not a supported data_source type')
+ return dataset_pb2.DataSource(
+ type=data_source_type,
+ url=data_source_url,
+ is_user_upload=False,
+ is_user_export=False,
+ )
+
+
+def _validate_data_source(data_source_url: str, dataset_type: DatasetType):
+ fm = FileManager()
+ if not fm.can_handle(path=data_source_url):
+ raise InvalidArgumentException(f'invalid data_source_url: {data_source_url}')
+ if not fm.isdir(path=data_source_url):
+ raise InvalidArgumentException(f'cannot connect to data_source_url: {data_source_url}')
+ if dataset_type == DatasetType.STREAMING:
+ res, message = is_streaming_folder(data_source_url)
+ if not res:
+ raise InvalidArgumentException(message)
+
-_FORMAT_ERROR_MESSAGE = '{} is empty'
+class DatasetJobConfigParameter(Schema):
+ dataset_uuid = fields.Str(required=False)
+ dataset_id = fields.Integer(required=False)
+ variables = fields.List(fields.Dict())
+
+ @post_load
+ def make_dataset_job_config(self, item: Dict[str, Any], **kwargs) -> dataset_pb2.DatasetJobConfig:
+ del kwargs # this variable is not needed for now
-def _get_dataset_path(dataset_name):
- root_dir = current_app.config.get('STORAGE_ROOT')
- prefix = datetime.now().strftime('%Y%m%d_%H%M%S')
- # Builds a path for dataset according to the dataset name
- # Example: '/data/dataset/20210305_173312_test-dataset
- return f'{root_dir}/dataset/{prefix}_{slugify(dataset_name)[:32]}'
+ try:
+ dataset_job_config = dataset_pb2.DatasetJobConfig()
+ return ParseDict(item, dataset_job_config)
+ except ParseError as err:
+ raise ValidationError(message='failed to convert dataset_job_config',
+ field_name='global_configs',
+ data=err.args)
+
+
+class DatasetJobParameter(Schema):
+ global_configs = fields.Dict(required=True, keys=fields.Str(), values=fields.Nested(DatasetJobConfigParameter()))
+ dataset_job_kind = fields.Str(required=False,
+ validate=validate.OneOf([o.value for o in DatasetJobKind]),
+ load_default='')
+
+ @post_load
+ def make_dataset_job(self, item: Dict[str, Any], **kwargs) -> dataset_pb2.DatasetJob:
+ del kwargs # this variable is not needed for now
+
+ global_configs = item['global_configs']
+ global_configs_pb = DatasetJobGlobalConfigs()
+ for domain_name, job_config in global_configs.items():
+ global_configs_pb.global_configs[get_pure_domain_name(domain_name)].MergeFrom(job_config)
+
+ return dataset_pb2.DatasetJob(kind=item['dataset_job_kind'], global_configs=global_configs_pb)
+
+
+class DatasetJobVariablesParameter(Schema):
+ variables = fields.List(fields.Dict())
+
+ @post_load
+ def make_dataset_job_config(self, item: Dict[str, Any], **kwargs) -> dataset_pb2.DatasetJobConfig:
+ del kwargs # this variable is not needed for now
+
+ try:
+ dataset_job_config = dataset_pb2.DatasetJobConfig()
+ return ParseDict(item, dataset_job_config)
+ except ParseError as err:
+ raise ValidationError(message='failed to convert dataset_job_config',
+ field_name='dataset_job_config',
+ data=err.args)
class DatasetApi(Resource):
- @jwt_required()
- def get(self, dataset_id):
+
+ @credentials_required
+ def get(self, dataset_id: int):
+ """Get dataset details
+ ---
+ tags:
+ - dataset
+ description: get details of dataset
+ parameters:
+ - in: path
+ required: true
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: get details of dataset
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
with db.session_scope() as session:
- dataset = session.query(Dataset).get(dataset_id)
- if dataset is None:
- raise NotFoundException(
- f'Failed to find dataset: {dataset_id}')
- return {'data': dataset.to_dict()}
-
- @jwt_required()
- def patch(self, dataset_id: int):
- parser = reqparse.RequestParser()
- parser.add_argument('name',
- type=str,
- required=False,
- help='dataset name')
- parser.add_argument('comment',
- type=str,
- required=False,
- help='dataset comment')
- parser.add_argument('comment')
- data = parser.parse_args()
+ dataset = DatasetService(session).get_dataset(dataset_id)
+ # TODO(liuhehan): this commit is a lazy update of dataset store_format, remove it after release 2.4
+ session.commit()
+ return make_flask_response(dataset)
+
+ @input_validator
+ @credentials_required
+ @emits_event()
+ @use_kwargs({'comment': fields.Str(required=False, load_default=None)})
+ def patch(self, dataset_id: int, comment: Optional[str]):
+ """Change dataset info
+ ---
+ tags:
+ - dataset
+ description: change dataset info
+ parameters:
+ - in: path
+ required: true
+ name: dataset_id
+ schema:
+ type: integer
+ requestBody:
+ required: false
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ comment:
+ type: string
+ responses:
+ 200:
+ description: change dataset info
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
with db.session_scope() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
- raise NotFoundException(
- f'Failed to find dataset: {dataset_id}')
- if data['name']:
- dataset.name = data['name']
- if data['comment']:
- dataset.comment = data['comment']
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ if comment:
+ dataset.comment = comment
session.commit()
- return {'data': dataset.to_dict()}, HTTPStatus.OK
+ return make_flask_response(dataset.to_proto())
+
+ @credentials_required
+ @emits_event()
+ def delete(self, dataset_id: int):
+ """Delete dataset
+ ---
+ tags:
+ - dataset
+ description: delete dataset
+ parameters:
+ - in: path
+ required: true
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 204:
+ description: deleted dataset result
+ """
+ with db.session_scope() as session:
+ # acquire an exclusive lock on this row to
+ # ensure the state is modified correctly in concurrent scenarios.
+ dataset = session.query(Dataset).with_for_update().populate_existing().get(dataset_id)
+ if not dataset:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ DatasetService(session).cleanup_dataset(dataset)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
class DatasetPreviewApi(Resource):
- def get(self, dataset_id: int):
+
+ @credentials_required
+ @use_kwargs({
+ 'batch_id': fields.Integer(required=True),
+ }, location='query')
+ def get(self, dataset_id: int, batch_id: int):
+ """Get dataset preview
+ ---
+ tags:
+ - dataset
+ description: get dataset preview
+ parameters:
+ - in: path
+ required: true
+ name: dataset_id
+ schema:
+ type: integer
+ - in: query
+ name: batch_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: dataset preview info
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ dtypes:
+ type: array
+ items:
+ type: object
+ properties:
+ key:
+ type: string
+ value:
+ type: string
+ sample:
+ type: array
+ items:
+ type: array
+ items:
+ anyOf:
+ - type: string
+ - type: integer
+ - type: number
+ num_example:
+ type: integer
+ metrics:
+ type: object
+ images:
+ type: array
+ items:
+ type: object
+ properties:
+ created_at:
+ type: string
+ file_name:
+ type: string
+ name:
+ type: string
+ height:
+ type: string
+ width:
+ type: string
+ path:
+ type: string
+ """
if dataset_id <= 0:
raise NotFoundException(f'Failed to find dataset: {dataset_id}')
with db.session_scope() as session:
- data = DatasetService(session).get_dataset_preview(dataset_id)
- return {'data': data}
+ data = DatasetService(session).get_dataset_preview(dataset_id, batch_id)
+ return make_flask_response(data)
+
+class DatasetLedgerApi(Resource):
-class DatasetMetricsApi(Resource):
def get(self, dataset_id: int):
- if dataset_id <= 0:
- raise NotFoundException(f'Failed to find dataset: {dataset_id}')
- name = request.args.get('name', None)
- if not name:
- raise InvalidArgumentException(f'required params name')
+ """Get dataset ledger
+ ---
+ tags:
+ - dataset
+ description: get dataset ledger
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: get dataset ledger page
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetLedger'
+ """
+ return make_flask_response(data={}, status=HTTPStatus.NO_CONTENT)
+
+
+class DatasetExportApi(Resource):
+
+ @credentials_required
+ @use_kwargs({
+ 'export_path': fields.Str(required=True, validate=_export_path_validator),
+ 'batch_id': fields.Integer(required=False, load_default=None)
+ })
+ def post(self, dataset_id: int, export_path: str, batch_id: Optional[int]):
+ """Export dataset
+ ---
+ tags:
+ - dataset
+ description: Export dataset
+ parameters:
+ - in: path
+ required: true
+ name: dataset_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ export_path:
+ type: string
+ required: true
+ batch_id:
+ type: integer
+ required: false
+ responses:
+ 201:
+ description: Export dataset
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ export_dataset_id:
+ type: integer
+ dataset_job_id:
+ type: integer
+ """
+ export_path = _parse_data_source_url(export_path).url
with db.session_scope() as session:
- data = DatasetService(session).feature_metrics(name, dataset_id)
- return {'data': data}
+ input_dataset: Dataset = session.query(Dataset).get(dataset_id)
+ if not input_dataset:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ export_index = session.query(DatasetJob).filter(DatasetJob.kind == DatasetJobKind.EXPORT).filter(
+ DatasetJob.input_dataset_id == dataset_id).count()
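+ # count existing export jobs of this dataset so the new export dataset gets a unique, indexed name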
+ if batch_id:
+ data_batch = session.query(DataBatch).filter(DataBatch.dataset_id == dataset_id).filter(
+ DataBatch.id == batch_id).first()
+ if data_batch is None:
+ raise NotFoundException(f'Failed to find data_batch {batch_id} in dataset {dataset_id}')
+ data_batches = [data_batch]
+ export_dataset_name = get_export_dataset_name(index=export_index,
+ input_dataset_name=input_dataset.name,
+ input_data_batch_name=data_batch.batch_name)
+ else:
+ data_batches = input_dataset.data_batches
+ export_dataset_name = get_export_dataset_name(index=export_index, input_dataset_name=input_dataset.name)
+ dataset_job_config = dataset_pb2.DatasetJobConfig(dataset_uuid=input_dataset.uuid)
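+ # keep UNKNOWN store_format as-is; otherwise the exported dataset is written as CSV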
+ store_format = StoreFormat.UNKNOWN.value if input_dataset.store_format == StoreFormat.UNKNOWN \
+ else StoreFormat.CSV.value
+ dataset_parameter = dataset_pb2.DatasetParameter(name=export_dataset_name,
+ type=input_dataset.dataset_type.value,
+ project_id=input_dataset.project.id,
+ kind=DatasetKindV2.EXPORTED.value,
+ format=DatasetFormat(input_dataset.dataset_format).name,
+ is_published=False,
+ store_format=store_format,
+ auth_status=AuthStatus.AUTHORIZED.name,
+ path=export_path)
+ output_dataset = DatasetService(session=session).create_dataset(dataset_parameter=dataset_parameter)
+ session.flush()
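+ # the coordinator's own config is keyed by its pure domain name in the global configs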
+ global_configs = DatasetJobGlobalConfigs()
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ global_configs.global_configs[pure_domain_name].MergeFrom(dataset_job_config)
+ export_dataset_job = DatasetJobService(session).create_as_coordinator(project_id=input_dataset.project_id,
+ kind=DatasetJobKind.EXPORT,
+ output_dataset_id=output_dataset.id,
+ global_configs=global_configs)
+ session.flush()
+ for data_batch in reversed(data_batches):
+ # skip non-succeeded data_batch
+ if not data_batch.is_available():
+ continue
+ DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage_as_coordinator(
+ dataset_job_id=export_dataset_job.id,
+ global_configs=export_dataset_job.get_global_configs(),
+ event_time=data_batch.event_time)
+
+ session.commit()
+ return make_flask_response(data={
+ 'export_dataset_id': output_dataset.id,
+ 'dataset_job_id': export_dataset_job.id
+ },
+ status=HTTPStatus.OK)
+
+
+class DatasetStateFixtApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_kwargs({
+ 'force':
+ fields.Str(required=False, load_default=None, validate=validate.OneOf([o.value for o in DatasetJobState]))
+ })
+ def post(self, dataset_id: int, force: str):
+ """fix dataset state
+ ---
+ tags:
+ - dataset
+ description: fix dataset state
+ parameters:
+ - in: path
+ required: true
+ name: dataset_id
+ schema:
+ type: integer
+ requestBody:
+ required: false
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ force:
+ type: string
+ responses:
+ 200:
+ description: fix dataset state successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(dataset_id)
+ if not dataset:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ if force:
+ dataset.parent_dataset_job.state = DatasetJobState(force)
+ else:
+ workflow_state = dataset.parent_dataset_job.workflow.get_state_for_frontend()
+ # if workflow is completed, restart the batch stats task
+ if workflow_state == WorkflowExternalState.COMPLETED:
+ item_name = dataset.parent_dataset_job.get_context().batch_stats_item_name
+ runners = ComposerService(session).get_recent_runners(item_name, count=1)
+ # This is a hack to restart the composer runner, see details in job_scheduler.py
+ if len(runners) > 0:
+ runners[0].status = RunnerStatus.INIT.value
+ dataset.parent_dataset_job.state = DatasetJobState.RUNNING
+ elif workflow_state in (WorkflowExternalState.FAILED, WorkflowExternalState.STOPPED,
+ WorkflowExternalState.INVALID):
+ dataset.parent_dataset_job.state = DatasetJobState.FAILED
+ session.commit()
+ return make_flask_response(data=dataset.to_proto(), status=HTTPStatus.OK)
+
+
+class DatasetParameter(Schema):
+ name = fields.Str(required=True)
+ dataset_type = fields.Str(required=False,
+ load_default=DatasetType.PSI.value,
+ validate=validate.OneOf([o.value for o in DatasetType]))
+ comment = fields.Str(required=False)
+ project_id = fields.Int(required=True)
+ kind = fields.Str(required=False,
+ load_default=DatasetKindV2.RAW.value,
+ validate=validate.OneOf([o.value for o in DatasetKindV2]))
+ dataset_format = fields.Str(required=True, validate=validate.OneOf([o.name for o in DatasetFormat]))
+ need_publish = fields.Bool(required=False, load_default=False)
+ value = fields.Int(required=False, load_default=0, validate=[validate.Range(min=100, max=10000)])
+ schema_checkers = fields.List(fields.Str(validate=validate.OneOf([o.value for o in DatasetSchemaChecker])))
+ is_published = fields.Bool(required=False, load_default=False)
+ import_type = fields.Str(required=False,
+ load_default=ImportType.COPY.value,
+ validate=validate.OneOf([o.value for o in ImportType]))
+ store_format = fields.Str(required=False,
+ load_default=StoreFormat.TFRECORDS.value,
+ validate=validate.OneOf([o.value for o in StoreFormat]))
+
+ @post_load
+ def make_dataset_parameter(self, item: Dict[str, str], **kwargs) -> dataset_pb2.DatasetParameter:
+ return dataset_pb2.DatasetParameter(name=item.get('name'),
+ type=item.get('dataset_type'),
+ comment=item.get('comment'),
+ project_id=item.get('project_id'),
+ kind=item.get('kind'),
+ format=item.get('dataset_format'),
+ need_publish=item.get('need_publish'),
+ value=item.get('value'),
+ is_published=item.get('is_published'),
+ schema_checkers=item.get('schema_checkers'),
+ import_type=item.get('import_type'),
+ store_format=item.get('store_format'))
class DatasetsApi(Resource):
- @jwt_required()
- def get(self):
- parser = reqparse.RequestParser()
- parser.add_argument('project',
- type=int,
- required=False,
- help='project')
- data = parser.parse_args()
- with db.session_scope() as session:
- datasets = DatasetService(session).get_datasets(
- project_id=int(data['project'] or 0))
- return {'data': [d.to_dict() for d in datasets]}
-
- @jwt_required()
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('name',
- required=True,
- type=str,
- help=_FORMAT_ERROR_MESSAGE.format('name'))
- parser.add_argument('dataset_type',
- required=True,
- type=DatasetType,
- help=_FORMAT_ERROR_MESSAGE.format('dataset_type'))
- parser.add_argument('comment', type=str)
- parser.add_argument('project_id',
- required=True,
- type=int,
- help=_FORMAT_ERROR_MESSAGE.format('project_id'))
- body = parser.parse_args()
- name = body.get('name')
- dataset_type = body.get('dataset_type')
- comment = body.get('comment')
- project_id = body.get('project_id')
+ FILTER_FIELDS = {
+ 'name':
+ filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'project_id':
+ filtering.SupportedField(type=filtering.FieldType.NUMBER, ops={FilterOp.EQUAL: None}),
+ 'uuid':
+ filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'dataset_kind':
+ filtering.SupportedField(type=filtering.FieldType.STRING, ops={
+ FilterOp.IN: None,
+ FilterOp.EQUAL: None
+ }),
+ 'dataset_format':
+ filtering.SupportedField(type=filtering.FieldType.STRING,
+ ops={
+ FilterOp.IN: dataset_format_filter_op_in,
+ FilterOp.EQUAL: dataset_format_filter_op_equal
+ }),
+ 'is_published':
+ filtering.SupportedField(type=filtering.FieldType.BOOL, ops={FilterOp.EQUAL: None}),
+ 'dataset_type':
+ filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'publish_frontend_state':
+ filtering.SupportedField(type=filtering.FieldType.STRING,
+ ops={FilterOp.EQUAL: dataset_publish_frontend_filter_op_equal}),
+ 'auth_status':
+ filtering.SupportedField(type=filtering.FieldType.STRING,
+ ops={FilterOp.IN: dataset_auth_status_filter_op_in}),
+ }
+
+ SORTER_FIELDS = ['created_at']
+
+ def __init__(self):
+ self._filter_builder = filtering.FilterBuilder(model_class=Dataset, supported_fields=self.FILTER_FIELDS)
+ self._sorter_builder = sorting.SorterBuilder(model_class=Dataset, supported_fields=self.SORTER_FIELDS)
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'page':
+ fields.Integer(required=False, load_default=1),
+ 'page_size':
+ fields.Integer(required=False, load_default=10),
+ 'dataset_job_kind':
+ fields.String(required=False, load_default=None),
+ 'state_frontend':
+ fields.List(
+ fields.String(
+ required=False, load_default=None, validate=validate.OneOf([o.value for o in ResourceState]))),
+ 'filter_exp':
+ FilterExpField(required=False, load_default=None, data_key='filter'),
+ 'sorter_exp':
+ fields.String(required=False, load_default=None, data_key='order_by'),
+ 'cron_interval':
+ fields.String(
+ required=False, load_default=None, validate=validate.OneOf([o.value for o in CronInterval])),
+ },
+ location='query')
+ def get(self,
+ page: int,
+ page_size: int,
+ dataset_job_kind: Optional[str] = None,
+ state_frontend: Optional[List[str]] = None,
+ filter_exp: Optional[FilterExpression] = None,
+ sorter_exp: Optional[str] = None,
+ cron_interval: Optional[str] = None):
+ """Get datasets list
+ ---
+ tags:
+ - dataset
+ description: get datasets list
+ parameters:
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: dataset_job_kind
+ schema:
+ type: string
+ - in: query
+ name: state_frontend
+ schema:
+ type: array
+ collectionFormat: multi
+ items:
+ type: string
+ enum: [PENDING, PROCESSING, SUCCEEDED, FAILED]
+ - in: query
+ name: filter
+ schema:
+ type: string
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ - in: query
+ name: cron_interval
+ schema:
+ type: string
+ responses:
+ 200:
+ description: get datasets list result
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetRef'
+ """
+ if dataset_job_kind is not None:
+ try:
+ dataset_job_kind = DatasetJobKind(dataset_job_kind)
+ except ValueError as err:
+ raise InvalidArgumentException(
+ details=f'failed to find dataset dataset_job_kind {dataset_job_kind}') from err
with db.session_scope() as session:
+ query = DatasetService(session).query_dataset_with_parent_job()
+ if dataset_job_kind:
+ query = query.filter(DatasetJob.kind == dataset_job_kind)
+ if filter_exp is not None:
+ try:
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
try:
- # Create dataset
- dataset = Dataset(
- name=name,
- dataset_type=dataset_type,
- comment=comment,
- path=_get_dataset_path(name),
- project_id=project_id,
- )
- session.add(dataset)
- # TODO: scan cronjob
- session.commit()
- return {'data': dataset.to_dict()}
- except Exception as e:
- session.rollback()
- raise InvalidArgumentException(details=str(e))
+ if sorter_exp is not None:
+ sorter_exp = sorting.parse_expression(sorter_exp)
+ else:
+ sorter_exp = sorting.SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ # TODO(liuhehan): add state_frontend as custom_builder
+ if state_frontend is not None:
+ states = []
+ for state in state_frontend:
+ states.append(ResourceState(state))
+ query = DatasetService.filter_dataset_state(query, states)
+ # filter daily or hourly cron
+ if cron_interval:
+ if cron_interval == CronInterval.HOURS.value:
+ time_range = timedelta(hours=1)
+ else:
+ time_range = timedelta(days=1)
+ query = query.filter(DatasetJob.time_range == time_range)
+ pagination = paginate(query=query, page=page, page_size=page_size)
+ datasets = []
+ for dataset in pagination.get_items():
+ dataset_ref = dataset.to_ref()
+ dataset_ref.total_value = 0
+ datasets.append(dataset_ref)
+ # TODO(liuhehan): this commit is a lazy update of dataset store_format, remove it after release 2.4
+ session.commit()
+ return make_flask_response(data=datasets, page_meta=pagination.get_metadata())
+
+ @input_validator
+ @credentials_required
+ @emits_event()
+ @use_args(DatasetParameter())
+ def post(self, dataset_parameter: dataset_pb2.DatasetParameter):
+ """Create dataset
+ ---
+ tags:
+ - dataset
+ description: Create dataset
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/DatasetParameter'
+ responses:
+ 201:
+ description: Create dataset
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
+ with db.session_scope() as session:
+ # processed dataset must be published
+ if DatasetKindV2(dataset_parameter.kind) == DatasetKindV2.PROCESSED and not dataset_parameter.is_published:
+ raise InvalidArgumentException('is_published must be true if dataset kind is PROCESSED')
+ if DatasetKindV2(dataset_parameter.kind) == DatasetKindV2.PROCESSED and ImportType(
+ dataset_parameter.import_type) != ImportType.COPY:
+ raise InvalidArgumentException('import type must be copy if dataset kind is PROCESSED')
+ if StoreFormat(dataset_parameter.store_format) == StoreFormat.CSV and DatasetKindV2(
+ dataset_parameter.kind) in [DatasetKindV2.RAW, DatasetKindV2.PROCESSED]:
+ raise InvalidArgumentException('csv store_format is not supported if dataset kind is RAW or PROCESSED')
+ dataset_parameter.auth_status = AuthStatus.AUTHORIZED.name
+ dataset = DatasetService(session=session).create_dataset(dataset_parameter=dataset_parameter)
+ session.flush()
+ # create review ticket for processed_dataset
+ if DatasetKindV2(dataset_parameter.kind) == DatasetKindV2.PROCESSED:
+ ticket_helper = get_ticket_helper(session=session)
+ ticket_helper.create_ticket(TicketType.CREATE_PROCESSED_DATASET, TicketDetails(uuid=dataset.uuid))
+ session.commit()
+ return make_flask_response(data=dataset.to_proto(), status=HTTPStatus.CREATED)
+
+
+class ChildrenDatasetsApi(Resource):
+
+ def get(self, dataset_id: int):
+ """Get children datasets list
+ ---
+ tags:
+ - dataset
+ description: Get children datasets list
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: get children datasets list result
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetRef'
+ """
+ with db.session_scope() as session:
+ query = DatasetService(session=session).query_dataset_with_parent_job()
+ query = query.filter(DatasetJob.input_dataset_id == dataset_id)
+ # exported dataset should not be shown in children datasets
+ query = query.filter(Dataset.dataset_kind != DatasetKindV2.EXPORTED)
+ return make_flask_response(data=[dataset.to_ref() for dataset in query.all()])
+
+
+class BatchParameter(Schema):
+ data_source_id = fields.Integer(required=True)
+ comment = fields.Str(required=False)
+
+ @post_load
+ def make_batch_parameter(self, item: Dict[str, Any], **kwargs) -> dataset_pb2.BatchParameter:
+ data_source_id = item.get('data_source_id')
+ comment = item.get('comment')
+
+ with db.session_scope() as session:
+ data_source = session.query(DataSource).get(data_source_id)
+ if data_source is None:
+ raise ValidationError(message=f'failed to find data_source {data_source_id}',
+ field_name='data_source_id')
+
+ return dataset_pb2.BatchParameter(comment=comment, data_source_id=data_source_id)
class BatchesApi(Resource):
- @jwt_required()
+
+ SORTER_FIELDS = ['created_at', 'updated_at']
+
+ def __init__(self):
+ self._sorter_builder = sorting.SorterBuilder(model_class=DataBatch, supported_fields=self.SORTER_FIELDS)
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'page': fields.Integer(required=False, load_default=1),
+ 'page_size': fields.Integer(required=False, load_default=10),
+ 'sorter_exp': fields.String(required=False, load_default=None, data_key='order_by')
+ },
+ location='query')
+ def get(self, dataset_id: int, page: int, page_size: int, sorter_exp: Optional[str]):
+ """List data batches
+ ---
+ tags:
+ - dataset
+ description: List data batches
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of data batches
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DataBatch'
+ """
+ with db.session_scope() as session:
+ query = session.query(DataBatch).filter(DataBatch.dataset_id == dataset_id)
+ try:
+ if sorter_exp is not None:
+ sorter_exp = sorting.parse_expression(sorter_exp)
+ else:
+ # default sort is created_at desc
+ sorter_exp = sorting.SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ pagination = paginate(query=query, page=page, page_size=page_size)
+ return make_flask_response(data=[data_batch.to_proto() for data_batch in pagination.get_items()],
+ page_meta=pagination.get_metadata())
+
+
+class BatchApi(Resource):
+
+ @credentials_required
+ def get(self, dataset_id: int, data_batch_id: int):
+ """Get data batch by id
+ ---
+ tags:
+ - dataset
+ description: Get data batch by id
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ - in: path
+ name: data_batch_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: Get data batch by id
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DataBatch'
+ """
+ with db.session_scope() as session:
+ batch: DataBatch = session.query(DataBatch).filter(DataBatch.dataset_id == dataset_id).filter(
+ DataBatch.id == data_batch_id).first()
+ if batch is None:
+ raise NotFoundException(f'failed to find batch {data_batch_id} in dataset {dataset_id}')
+ return make_flask_response(data=batch.to_proto())
+
+
+class BatchAnalyzeApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'dataset_job_config': fields.Nested(DatasetJobVariablesParameter())})
+ def post(self, dataset_id: int, data_batch_id: int, dataset_job_config: dataset_pb2.DatasetJobConfig):
+ """Analyze data_batch by id
+ ---
+ tags:
+ - dataset
+ description: Analyze data_batch by id
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ - in: path
+ name: data_batch_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ dataset_job_config:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJobConfig'
+ responses:
+ 200:
+ description: analyzer dataset job details
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJob'
+ """
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(dataset_id)
+ if dataset is None:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ dataset_job_config.dataset_uuid = dataset.uuid
+ global_configs = DatasetJobGlobalConfigs()
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ global_configs.global_configs[pure_domain_name].MergeFrom(dataset_job_config)
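+ # reuse the existing ANALYZER job on this dataset if there is one; otherwise create it as coordinator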
+ analyzer_dataset_job: DatasetJob = session.query(DatasetJob).filter(
+ DatasetJob.output_dataset_id == dataset_id).filter(DatasetJob.kind == DatasetJobKind.ANALYZER).first()
+ if analyzer_dataset_job is None:
+ analyzer_dataset_job = DatasetJobService(session).create_as_coordinator(project_id=dataset.project_id,
+ kind=DatasetJobKind.ANALYZER,
+ output_dataset_id=dataset_id,
+ global_configs=global_configs)
+ else:
+ previous_global_configs = analyzer_dataset_job.get_global_configs()
+ for variable in dataset_job_config.variables:
+ set_variable_value_to_job_config(previous_global_configs.global_configs[pure_domain_name], variable)
+ analyzer_dataset_job.set_global_configs(previous_global_configs)
+ session.flush()
+ DatasetJobStageService(session).create_dataset_job_stage_as_coordinator(
+ project_id=dataset.project_id,
+ dataset_job_id=analyzer_dataset_job.id,
+ output_data_batch_id=data_batch_id,
+ global_configs=analyzer_dataset_job.get_global_configs())
+ dataset_job_details = analyzer_dataset_job.to_proto()
+ session.commit()
+
+ return make_flask_response(data=dataset_job_details, status=HTTPStatus.OK)
+
+
+class BatchMetricsApi(Resource):
+
+ @credentials_required
+ @use_kwargs({
+ 'name': fields.Str(required=True),
+ }, location='query')
+ def get(self, dataset_id: int, data_batch_id: int, name: str):
+ """Get data batch metrics info
+ ---
+ tags:
+ - dataset
+ description: get data batch metrics info
+ parameters:
+ - in: path
+ required: true
+ name: dataset_id
+ schema:
+ type: integer
+ - in: path
+ required: true
+ name: data_batch_id
+ schema:
+ type: integer
+ - in: query
+ required: true
+ name: name
+ schema:
+ type: string
+ responses:
+ 200:
+ description: get data batch metrics info
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ name:
+ type: string
+ metrics:
+ type: object
+ properties:
+ count:
+ type: string
+ max:
+ type: string
+ min:
+ type: string
+ mean:
+ type: string
+ stddev:
+ type: string
+ missing_count:
+ type: string
+ hist:
+ type: object
+ properties:
+ x:
+ type: array
+ items:
+ type: number
+ y:
+ type: array
+ items:
+ type: number
+ """
+ # TODO(liuhehan): return dataset metrics in proto
+ with db.session_scope() as session:
+ data = DatasetService(session).feature_metrics(name, dataset_id, data_batch_id)
+ return make_flask_response(data)
+
+
+class BatchRerunApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'dataset_job_parameter': fields.Nested(DatasetJobParameter())})
+ def post(self, dataset_id: int, data_batch_id: int, dataset_job_parameter: dataset_pb2.DatasetJob):
+ """rerun data_batch by id
+ ---
+ tags:
+ - dataset
+ description: Rerun data_batch by id
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ - in: path
+ name: data_batch_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ dataset_job_parameter:
+ $ref: '#/definitions/DatasetJobParameter'
+ responses:
+ 200:
+ description: dataset job stage details
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJobStage'
+ """
+ global_configs = dataset_job_parameter.global_configs
+
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(dataset_id)
+ if dataset is None:
+ raise InvalidArgumentException(f'failed to find dataset: {dataset_id}')
+ data_batch: DataBatch = session.query(DataBatch).filter(DataBatch.dataset_id == dataset_id).filter(
+ DataBatch.id == data_batch_id).first()
+ if data_batch is None:
+ raise InvalidArgumentException(f'failed to find data_batch: {data_batch_id}')
+ dataset_job: DatasetJob = dataset.parent_dataset_job
+ if dataset_job is None:
+ raise InvalidArgumentException(f'dataset_job is missing, output_dataset_id: {dataset_id}')
+ # get current global_configs
+ if dataset_job.is_coordinator():
+ current_global_configs = dataset_job.get_global_configs()
+ else:
+ participant: Participant = session.query(Participant).get(dataset_job.coordinator_id)
+ system_client = SystemServiceClient.from_participant(domain_name=participant.domain_name)
+ flag_resp = system_client.list_flags()
+ if not flag_resp.get(Flag.DATA_BATCH_RERUN_ENABLED.name):
+ raise MethodNotAllowedException(
+ f'participant {participant.pure_domain_name()} does not support rerunning data_batch, ' \
+ 'can only rerun data_batch created as coordinator'
+ )
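+ # fetch the coordinator's current configs over RPC and use them as the base for this rerun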
+ client = RpcClient.from_project_and_participant(dataset_job.project.name, dataset_job.project.token,
+ participant.domain_name)
+ response = client.get_dataset_job(uuid=dataset_job.uuid)
+ current_global_configs = response.dataset_job.global_configs
+ # set global_configs
+ for pure_domain_name in global_configs.global_configs:
+ for variable in global_configs.global_configs[pure_domain_name].variables:
+ set_variable_value_to_job_config(current_global_configs.global_configs[pure_domain_name], variable)
+ # create dataset_job_stage
+ dataset_job_stage = DatasetJobStageService(session).create_dataset_job_stage_as_coordinator(
+ project_id=dataset.project_id,
+ dataset_job_id=dataset_job.id,
+ output_data_batch_id=data_batch_id,
+ global_configs=current_global_configs)
+ session.flush()
+ dataset_job_stage_details = dataset_job_stage.to_proto()
+ session.commit()
+
+ return make_flask_response(data=dataset_job_stage_details, status=HTTPStatus.OK)
+
+
+class DataSourceParameter(Schema):
+ name = fields.Str(required=True)
+ comment = fields.Str(required=False)
+ data_source_url = fields.Str(required=True, validate=_path_authority_validator)
+ is_user_upload = fields.Bool(required=False)
+ dataset_format = fields.Str(required=False,
+ load_default=DatasetFormat.TABULAR.name,
+ validate=validate.OneOf([o.name for o in DatasetFormat]))
+ store_format = fields.Str(required=False,
+ load_default=StoreFormat.UNKNOWN.value,
+ validate=validate.OneOf([o.value for o in StoreFormat]))
+ dataset_type = fields.Str(required=False,
+ load_default=DatasetType.PSI.value,
+ validate=validate.OneOf([o.value for o in DatasetType]))
+
+ @post_load
+ def make_data_source(self, item: Dict[str, str], **kwargs) -> dataset_pb2.DataSource:
+ del kwargs # this variable is not needed for now
+ name = item.get('name')
+ comment = item.get('comment')
+ data_source_url = item.get('data_source_url')
+ is_user_upload = item.get('is_user_upload', False)
+ data_source = _parse_data_source_url(data_source_url)
+ data_source.name = name
+ data_source.dataset_format = item.get('dataset_format')
+ data_source.store_format = item.get('store_format')
+ data_source.dataset_type = item.get('dataset_type')
+ if is_user_upload:
+ data_source.is_user_upload = True
+ if comment:
+ data_source.comment = comment
+ return data_source
+
+
+class DataSourcesApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'data_source': fields.Nested(DataSourceParameter()), 'project_id': fields.Integer(required=True)})
+ def post(self, data_source: dataset_pb2.DataSource, project_id: int):
+ """Create a data source
+ ---
+ tags:
+ - dataset
+ description: create a data source
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ data_source:
+ type: object
+ required: true
+ properties:
+ schema:
+ $ref: '#/definitions/DataSourceParameter'
+ project_id:
+ type: integer
+ required: true
+ responses:
+ 201:
+ description: The data source is created
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DataSource'
+ 409:
+ description: A data source with the same name exists
+ 400:
+ description: |
+ A data source that webconsole cannot connect to
+ Probably the data source does not exist or access to it is unauthorized
+ """
+
+ _validate_data_source(data_source.url, DatasetType(data_source.dataset_type))
+ with db.session_scope() as session:
+ data_source.project_id = project_id
+ data_source = DataSourceService(session=session).create_data_source(data_source)
+ session.commit()
+ return make_flask_response(data=data_source.to_proto(), status=HTTPStatus.CREATED)
+
+ @credentials_required
+ @use_kwargs({'project_id': fields.Integer(required=False, load_default=0, validate=validate.Range(min=0))},
+ location='query')
+ def get(self, project_id: int):
+ """Get a list of data source
+ ---
+ tags:
+ - dataset
+ description: get a list of data source
+ parameters:
+ - in: query
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: list of data source
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DataSource'
+ """
+
+ with db.session_scope() as session:
+ data_sources = DataSourceService(session=session).get_data_sources(project_id)
+ return make_flask_response(data=data_sources)
+
+
+class DataSourceApi(Resource):
+
+ @credentials_required
+ def get(self, data_source_id: int):
+ """Get target data source by id
+ ---
+ tags:
+ - dataset
+ description: get target data source by id
+ parameters:
+ - in: path
+ name: data_source_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: data source
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DataSource'
+ """
+
+ with db.session_scope() as session:
+ data_source: DataSource = session.query(DataSource).get(data_source_id)
+ if not data_source:
+ raise NotFoundException(message=f'cannot find data_source with id: {data_source_id}')
+ return make_flask_response(data=data_source.to_proto())
+
+ @credentials_required
+ def delete(self, data_source_id: int):
+ """Delete a data source
+ ---
+ tags:
+ - dataset
+ description: delete a data source
+ parameters:
+ - in: path
+ name: data_source_id
+ schema:
+ type: integer
+ responses:
+ 204:
+ description: deleted data source result
+ """
+
+ with db.session_scope() as session:
+ DataSourceService(session=session).delete_data_source(data_source_id)
+ session.commit()
+ return make_flask_response(data={}, status=HTTPStatus.NO_CONTENT)
+
+
+class DataSourceTreeApi(Resource):
+
+ @credentials_required
+ def get(self, data_source_id: int):
+ """Get the data source tree
+ ---
+ tags:
+ - dataset
+ description: get the data source tree
+ parameters:
+ - in: path
+ name: data_source_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: the file tree of the data source
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.FileTreeNode'
+ """
+ with db.session_scope() as session:
+ data_source: DataSource = session.query(DataSource).get(data_source_id)
+ # relative path is used in returned file tree
+ file_tree = FileTreeBuilder(data_source.path, relpath=True).build_with_root()
+ return make_flask_response(file_tree)
+
+
+class DataSourceCheckConnectionApi(Resource):
+
+ @credentials_required
+ @use_kwargs({
+ 'data_source_url':
+ fields.Str(required=True, validate=_path_authority_validator),
+ 'file_num':
+ fields.Integer(required=False, load_default=_DEFAULT_DATA_SOURCE_PREVIEW_FILE_NUM),
+ 'dataset_type':
+ fields.Str(required=False,
+ load_default=DatasetType.PSI.value,
+ validate=validate.OneOf([o.value for o in DatasetType]))
+ })
+ def post(self, data_source_url: str, file_num: int, dataset_type: str):
+ """Check data source connection status
+ ---
+ tags:
+ - dataset
+ description: check data source connection status
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ data_source_url:
+ type: string
+ required: true
+ file_num:
+ type: integer
+ required: false
+ dataset_type:
+ type: string
+ required: false
+ responses:
+ 200:
+ description: status details and file_names
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ extra_nums:
+ type: integer
+ file_names:
+ type: array
+ items:
+ type: string
+ """
+
+ data_source_url = _parse_data_source_url(data_source_url).url
+ _validate_data_source(data_source_url, DatasetType(dataset_type))
+ file_names = FileManager().listdir(data_source_url)
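+ # return at most file_num entries and report how many extra files were found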
+ return make_flask_response(data={
+ 'file_names': file_names[:file_num],
+ 'extra_nums': max(len(file_names) - file_num, 0),
+ })
+
+
+class ParticipantDatasetsApi(Resource):
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'kind':
+ fields.Str(required=False, load_default=None),
+ 'uuid':
+ fields.Str(required=False, load_default=None),
+ 'participant_id':
+ fields.Integer(required=False, load_default=None),
+ 'cron_interval':
+ fields.String(
+ required=False, load_default=None, validate=validate.OneOf([o.value for o in CronInterval])),
+ },
+ location='query')
+ def get(
+ self,
+ project_id: int,
+ kind: Optional[str],
+ uuid: Optional[str],
+ participant_id: Optional[int],
+ cron_interval: Optional[str],
+ ):
+ """Get list of participant datasets
+ ---
+ tags:
+ - dataset
+ description: get list of participant datasets
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: query
+ name: kind
+ schema:
+ type: string
+ - in: query
+ name: uuid
+ schema:
+ type: string
+ - in: query
+ name: participant_id
+ schema:
+ type: integer
+ - in: query
+ name: cron_interval
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of participant datasets
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ParticipantDatasetRef'
+ """
+ if kind is not None:
+ try:
+ DatasetKindV2(kind)
+ except ValueError as err:
+ raise InvalidArgumentException(details=f'failed to find dataset kind {kind}') from err
+ time_range = None
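+ # translate the cron_interval query param into a TimeRange filter passed to the list-datasets RPC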
+ if cron_interval:
+ if cron_interval == CronInterval.HOURS.value:
+ time_range = TimeRange(hours=1)
+ else:
+ time_range = TimeRange(days=1)
+
+ with db.session_scope() as session:
+ if participant_id is None:
+ participants = ParticipantService(session).get_platform_participants_by_project(project_id)
+ else:
+ participant = session.query(Participant).get(participant_id)
+ if participant is None:
+ raise NotFoundException(f'participant {participant_id} is not found')
+ participants = [participant]
+ project = session.query(Project).get(project_id)
+ data = []
+ for participant in participants:
+ # check flag
+ system_client = SystemServiceClient.from_participant(domain_name=participant.domain_name)
+ flag_resp = system_client.list_flags()
+ # if participant supports list dataset rpc, use new rpc
+ if flag_resp.get(Flag.LIST_DATASETS_RPC_ENABLED.name):
+ client = ResourceServiceClient.from_project_and_participant(participant.domain_name, project.name)
+ response = client.list_datasets(kind=DatasetKindV2(kind) if kind is not None else None,
+ uuid=uuid,
+ state=ResourceState.SUCCEEDED,
+ time_range=time_range)
+ else:
+ client = RpcClient.from_project_and_participant(project.name, project.token,
+ participant.domain_name)
+ response = client.list_participant_datasets(kind=kind, uuid=uuid)
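+ # the legacy RPC only receives kind and uuid, so state and time_range filters are not applied on this path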
+ datasets = response.participant_datasets
+ if uuid:
+ datasets = [d for d in datasets if uuid and d.uuid == uuid]
+ for dataset in datasets:
+ dataset.participant_id = participant.id
+ dataset.project_id = project_id
+ data.extend(datasets)
+
+ return make_flask_response(data=data)
+
+
+class DatasetPublishApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'value': fields.Int(required=False, load_default=0, validate=[validate.Range(min=100, max=10000)])})
+ def post(self, dataset_id: int, value: int):
+ """Publish the dataset in workspace
+ ---
+ tags:
+ - dataset
+ description: Publish the dataset in workspace
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ value:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: published the dataset in workspace
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
+ with db.session_scope() as session:
+ dataset = DatasetService(session=session).publish_dataset(dataset_id, value)
+ session.commit()
+ return make_flask_response(data=dataset.to_proto())
+
+ @credentials_required
+ def delete(self, dataset_id: int):
+ """Revoke publish dataset ops
+ ---
+ tags:
+ - dataset
+ description: Revoke publish dataset ops
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 204:
+ description: revoked publish dataset successfully
+ """
+ with db.session_scope() as session:
+ DatasetService(session=session).withdraw_dataset(dataset_id)
+ session.commit()
+ return make_flask_response(data=None, status=HTTPStatus.NO_CONTENT)
+
+
+class DatasetAuthorizehApi(Resource):
+
+ @credentials_required
def post(self, dataset_id: int):
- parser = reqparse.RequestParser()
- parser.add_argument('event_time', type=int)
- parser.add_argument('files',
- required=True,
- type=list,
- location='json',
- help=_FORMAT_ERROR_MESSAGE.format('files'))
- parser.add_argument('move', type=bool)
- parser.add_argument('comment', type=str)
- body = parser.parse_args()
- event_time = body.get('event_time')
- files = body.get('files')
- move = body.get('move', False)
- comment = body.get('comment')
+ """Authorize target dataset by id
+ ---
+ tags:
+ - dataset
+ description: authorize target dataset by id
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: authorize target dataset by id
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
with db.session_scope() as session:
- dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+ dataset: Dataset = session.query(Dataset).get(dataset_id)
if dataset is None:
- raise NotFoundException(
- f'Failed to find dataset: {dataset_id}')
- if event_time is None and dataset.type == DatasetType.STREAMING:
- raise InvalidArgumentException(
- details='data_batch.event_time is empty')
- # TODO: PSI dataset should not allow multi batches
-
- # Use current timestamp to fill when type is PSI
- event_time = datetime.fromtimestamp(
- event_time or datetime.utcnow().timestamp(), tz=timezone.utc)
- batch_folder_name = event_time.strftime('%Y%m%d_%H%M%S')
- batch_path = f'{dataset.path}/batch/{batch_folder_name}'
- # Create batch
- batch = DataBatch(dataset_id=dataset.id,
- event_time=event_time,
- comment=comment,
- state=BatchState.NEW,
- move=move,
- path=batch_path)
- batch_details = dataset_pb2.DataBatch()
- for file_path in files:
- file = batch_details.files.add()
- file.source_path = file_path
- file_name = file_path.split('/')[-1]
- file.destination_path = f'{batch_path}/{file_name}'
- batch.set_details(batch_details)
- session.add(batch)
- session.commit()
- session.refresh(batch)
- scheduler.wakeup(data_batch_ids=[batch.id])
- return {'data': batch.to_dict()}
-
-
-class FilesApi(Resource):
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ # update local auth_status
+ dataset.auth_status = AuthStatus.AUTHORIZED
+ if dataset.participants_info is not None:
+ # update local auth_status cache
+ AuthService(session=session, dataset_job=dataset.parent_dataset_job).update_auth_status(
+ domain_name=SettingService.get_system_info().pure_domain_name, auth_status=AuthStatus.AUTHORIZED)
+ # update participants auth_status cache
+ DatasetJobController(session=session).inform_auth_status(dataset_job=dataset.parent_dataset_job,
+ auth_status=AuthStatus.AUTHORIZED)
+ session.commit()
+ return make_flask_response(data=dataset.to_proto())
+
+ @credentials_required
+ def delete(self, dataset_id: int):
+ """Revoke dataset authorization by id
+ ---
+ tags:
+ - dataset
+ description: revoke dataset authorization by id
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: revoke dataset authorization by id successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(dataset_id)
+ if dataset is None:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ # update local auth_status
+ dataset.auth_status = AuthStatus.WITHDRAW
+ if dataset.participants_info is not None:
+ # update local auth_status cache
+ AuthService(session=session, dataset_job=dataset.parent_dataset_job).update_auth_status(
+ domain_name=SettingService.get_system_info().pure_domain_name, auth_status=AuthStatus.WITHDRAW)
+ # update participants auth_status cache
+ DatasetJobController(session=session).inform_auth_status(dataset_job=dataset.parent_dataset_job,
+ auth_status=AuthStatus.WITHDRAW)
+ session.commit()
+ return make_flask_response(data=dataset.to_proto())
+
+
+class DatasetFlushAuthStatusApi(Resource):
+
+ @credentials_required
+ def post(self, dataset_id: int):
+ """flush dataset auth status cache by id
+ ---
+ tags:
+ - dataset
+ description: flush dataset auth status cache by id
+ parameters:
+ - in: path
+ name: dataset_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: flush dataset auth status cache by id successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Dataset'
+ """
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(dataset_id)
+ if dataset is None:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ if dataset.participants_info is not None:
+ DatasetJobController(session=session).update_auth_status_cache(dataset_job=dataset.parent_dataset_job)
+ session.commit()
+ return make_flask_response(data=dataset.to_proto())
+
+
+class TimeRangeParameter(Schema):
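+ # days and hours are each capped at 1: only daily or hourly cron intervals are supported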
+ days = fields.Integer(required=False, load_default=0, validate=[validate.Range(min=0, max=1)])
+ hours = fields.Integer(required=False, load_default=0, validate=[validate.Range(min=0, max=1)])
+
+ @post_load
+ def make_time_range(self, item: Dict[str, Any], **kwargs) -> dataset_pb2.TimeRange:
+ days = item['days']
+ hours = item['hours']
+
+ return dataset_pb2.TimeRange(days=days, hours=hours)
+
+
+class DatasetJobDefinitionApi(Resource):
+
+ @credentials_required
+ def get(self, dataset_job_kind: str):
+ """Get variables of this dataset_job
+ ---
+ tags:
+ - dataset
+ description: Get variables of this dataset_job
+ parameters:
+ - in: path
+ name: dataset_job_kind
+ schema:
+ type: string
+ responses:
+ 200:
+ description: variables of this dataset_job
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ variables:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Variable'
+ is_federated:
+ type: boolean
+ """
+ # webargs doesn't support location=path for now
+ # reference: webargs/core.py:L285
+ try:
+ dataset_job_kind = DatasetJobKind(dataset_job_kind)
+ except ValueError as err:
+ raise InvalidArgumentException(details=f'unknown dataset_job_kind {dataset_job_kind}') from err
+ with db.session_scope() as session:
+ configer = DatasetJobConfiger.from_kind(dataset_job_kind, session)
+ user_variables = configer.user_variables
+ is_federated = not DatasetJobService(session).is_local(dataset_job_kind)
+ return make_flask_response(data={'variables': user_variables, 'is_federated': is_federated})
+
+
+class DatasetJobsApi(Resource):
+ FILTER_FIELDS = {
+ 'name': filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'kind': filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.IN: None}),
+ 'input_dataset_id': filtering.SupportedField(type=filtering.FieldType.NUMBER, ops={FilterOp.EQUAL: None}),
+ 'coordinator_id': filtering.SupportedField(type=filtering.FieldType.NUMBER, ops={FilterOp.IN: None}),
+ 'state': filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.IN: None}),
+ }
+
+ SORTER_FIELDS = ['created_at']
+
+ def __init__(self):
+ self._filter_builder = filtering.FilterBuilder(model_class=DatasetJob, supported_fields=self.FILTER_FIELDS)
+ self._sorter_builder = sorting.SorterBuilder(model_class=DatasetJob, supported_fields=self.SORTER_FIELDS)
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'page': fields.Integer(required=False, load_default=1),
+ 'page_size': fields.Integer(required=False, load_default=10),
+ 'filter_exp': FilterExpField(required=False, load_default=None, data_key='filter'),
+ 'sorter_exp': fields.String(required=False, load_default=None, data_key='order_by'),
+ },
+ location='query')
+ def get(self,
+ project_id: int,
+ page: int,
+ page_size: int,
+ filter_exp: Optional[FilterExpression] = None,
+ sorter_exp: Optional[str] = None):
+ """Get list of this dataset_jobs
+ ---
+ tags:
+ - dataset
+ description: Get list of this dataset_jobs
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: query
+ name: filter
+ schema:
+ type: string
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of this dataset_jobs
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJobRef'
+ """
+ with db.session_scope() as session:
+ query = session.query(DatasetJob).filter(DatasetJob.project_id == project_id)
+ if filter_exp is not None:
+ try:
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ try:
+ if sorter_exp is not None:
+ sorter_exp = sorting.parse_expression(sorter_exp)
+ else:
+ sorter_exp = sorting.SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ pagination = paginate(query=query, page=page, page_size=page_size)
+
+ return make_flask_response(data=[dataset_job.to_ref() for dataset_job in pagination.get_items()],
+ page_meta=pagination.get_metadata())
+
+ @credentials_required
+ @use_kwargs({
+ 'dataset_job_parameter': fields.Nested(DatasetJobParameter()),
+ 'output_dataset_id': fields.Integer(required=False, load_default=None),
+ 'time_range': fields.Nested(TimeRangeParameter(), required=False, load_default=dataset_pb2.TimeRange())
+ })
+ def post(self, project_id: int, dataset_job_parameter: dataset_pb2.DatasetJob, output_dataset_id: Optional[int],
+ time_range: dataset_pb2.TimeRange):
+ """Create new dataset job of the kind
+ ---
+ tags:
+ - dataset
+ description: Create new dataset job of the kind
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ dataset_job_parameter:
+ $ref: '#/definitions/DatasetJobParameter'
+ time_range:
+ $ref: '#/definitions/TimeRangeParameter'
+ responses:
+ 201:
+ description: Create a new dataset job of the given kind
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJob'
+ """
+ dataset_job_kind = DatasetJobKind(dataset_job_parameter.kind)
+ if not Flag.OT_PSI_ENABLED.value and dataset_job_kind == DatasetJobKind.OT_PSI_DATA_JOIN:
+ raise NoAccessException(f'dataset job {dataset_job_parameter.kind} is not enabled')
+ if not Flag.HASH_DATA_JOIN_ENABLED.value and dataset_job_kind == DatasetJobKind.HASH_DATA_JOIN:
+ raise NoAccessException(f'dataset job {dataset_job_parameter.kind} is not enabled')
+
+ global_configs = dataset_job_parameter.global_configs
+
+ with db.session_scope() as session:
+ output_dataset = session.query(Dataset).get(output_dataset_id)
+ if not output_dataset:
+ raise InvalidArgumentException(f'failed to find dataset: {output_dataset_id}')
+ time_delta = None
+ if output_dataset.dataset_type == DatasetType.STREAMING:
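+ # a streaming dataset runs on a cron schedule, so exactly one of days or hours must be
+ # positive; the XOR below rejects time ranges with both or neither set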
+ if not (time_range.days > 0) ^ (time_range.hours > 0):
+ raise InvalidArgumentException('must specify cron by days or hours')
+ time_delta = timedelta(days=time_range.days, hours=time_range.hours)
+ dataset_job = DatasetJobService(session).create_as_coordinator(project_id=project_id,
+ kind=dataset_job_kind,
+ output_dataset_id=output_dataset.id,
+ global_configs=global_configs,
+ time_range=time_delta)
+ session.flush()
+ dataset_job_details = dataset_job.to_proto()
+
+ # we set participants_info in the dataset_job api as we need to get the participants from the dataset_kind
+ participants = DatasetJobService(session=session).get_participants_need_distribute(dataset_job=dataset_job)
+ AuthService(session=session,
+ dataset_job=dataset_job).initialize_participants_info_as_coordinator(participants=participants)
+ # set need_create_stage to True for non-cron dataset_job;
+ # we do not create the stage here, as no stage may be created before all participants have authorized
+ if not dataset_job.is_cron():
+ context = dataset_job.get_context()
+ context.need_create_stage = True
+ dataset_job.set_context(context)
+ session.commit()
+
+ return make_flask_response(dataset_job_details, status=HTTPStatus.CREATED)
+
+
+class DatasetJobApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, dataset_job_id: int):
+ """Get detail of this dataset_job
+ ---
+ tags:
+ - dataset
+ description: Get detail of this dataset_job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: dataset_job_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: detail of this dataset_job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJob'
+ """
+ with db.session_scope() as session:
+ # TODO(wangsen.0914): move these logic into service
+ dataset_job: DatasetJob = session.query(DatasetJob).filter_by(project_id=project_id).filter_by(
+ id=dataset_job_id).first()
+ if dataset_job is None:
+ raise NotFoundException(f'failed to find dataset_job {dataset_job_id}')
+ dataset_job_pb = dataset_job.to_proto()
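+ # for a non-coordinator job, fetch global_configs and scheduler_state from the
+ # coordinator over RPC and merge them into the local proto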
+ if not dataset_job.is_coordinator():
+ participant = session.query(Participant).get(dataset_job.coordinator_id)
+ client = RpcClient.from_project_and_participant(dataset_job.project.name, dataset_job.project.token,
+ participant.domain_name)
+ response = client.get_dataset_job(uuid=dataset_job.uuid)
+ dataset_job_pb.global_configs.MergeFrom(response.dataset_job.global_configs)
+ dataset_job_pb.scheduler_state = response.dataset_job.scheduler_state
+ return make_flask_response(dataset_job_pb)
+
+ @credentials_required
+ def delete(self, project_id: int, dataset_job_id: int):
+ """Delete dataset_job by id
+ ---
+ tags:
+ - dataset
+ description: Delete dataset_job by id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: dataset_job_id
+ schema:
+ type: integer
+ responses:
+ 204:
+ description: delete dataset_job successfully
+ """
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).filter_by(project_id=project_id).filter_by(
+ id=dataset_job_id).first()
+ if dataset_job is None:
+ message = f'Failed to delete dataset_job: {dataset_job_id}; reason: failed to find dataset_job'
+ logging.error(message)
+ raise NotFoundException(message)
+ DatasetJobService(session).delete_dataset_job(dataset_job=dataset_job)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class DatasetJobStopApi(Resource):
+
+ @credentials_required
+ def post(self, project_id: int, dataset_job_id: int):
+ """Stop dataset_job by id
+ ---
+ tags:
+ - dataset
+ description: Stop dataset_job by id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: dataset_job_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: stop dataset_job successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJob'
+ """
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).filter_by(project_id=project_id).filter_by(
+ id=dataset_job_id).first()
+ if dataset_job is None:
+ raise NotFoundException(f'failed to find dataset_job {dataset_job_id}')
+ DatasetJobController(session).stop(uuid=dataset_job.uuid)
+ session.commit()
+ return make_flask_response(data=dataset_job.to_proto())
+
+
+class DatasetJobStopSchedulerApi(Resource):
+
+ @credentials_required
+ def post(self, project_id: int, dataset_job_id: int):
+ """Stop scheduler dataset_job by id
+ ---
+ tags:
+ - dataset
+ description: Stop scheduler dataset_job by id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: dataset_job_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: stop the dataset_job scheduler successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJob'
+ """
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).filter_by(project_id=project_id).filter_by(
+ id=dataset_job_id).first()
+ if dataset_job is None:
+ raise NotFoundException(f'failed to find dataset_job {dataset_job_id}')
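+ # the coordinator stops its own cron scheduler; a participant asks the coordinator to stop it
+ # via JobServiceClient and then re-reads global_configs and scheduler_state over RPC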
+ if dataset_job.is_coordinator():
+ DatasetJobService(session=session).stop_cron_scheduler(dataset_job=dataset_job)
+ dataset_job_pb = dataset_job.to_proto()
+ else:
+ participant = session.query(Participant).get(dataset_job.coordinator_id)
+ client = JobServiceClient.from_project_and_participant(participant.domain_name,
+ dataset_job.project.name)
+ client.update_dataset_job_scheduler_state(uuid=dataset_job.uuid,
+ scheduler_state=DatasetJobSchedulerState.STOPPED)
+ client = RpcClient.from_project_and_participant(dataset_job.project.name, dataset_job.project.token,
+ participant.domain_name)
+ response = client.get_dataset_job(uuid=dataset_job.uuid)
+ dataset_job_pb = dataset_job.to_proto()
+ dataset_job_pb.global_configs.MergeFrom(response.dataset_job.global_configs)
+ dataset_job_pb.scheduler_state = response.dataset_job.scheduler_state
+ session.commit()
+ return make_flask_response(data=dataset_job_pb)
+
+
+class DatasetJobStagesApi(Resource):
+
+ FILTER_FIELDS = {
+ 'state': filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.IN: None}),
+ }
+
+ SORTER_FIELDS = ['created_at']
+
def __init__(self):
- self._file_manager = FileManager()
+ self._filter_builder = filtering.FilterBuilder(model_class=DatasetJobStage, supported_fields=self.FILTER_FIELDS)
+ self._sorter_builder = sorting.SorterBuilder(model_class=DatasetJobStage, supported_fields=self.SORTER_FIELDS)
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'page': fields.Integer(required=False, load_default=1),
+ 'page_size': fields.Integer(required=False, load_default=10),
+ 'filter_exp': FilterExpField(required=False, load_default=None, data_key='filter'),
+ 'sorter_exp': fields.String(required=False, load_default=None, data_key='order_by')
+ },
+ location='query')
+ def get(self, project_id: int, dataset_job_id: int, page: int, page_size: int,
+ filter_exp: Optional[FilterExpression], sorter_exp: Optional[str]):
+ """List dataset job stages
+ ---
+ tags:
+ - dataset
+ description: List dataset job stages
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: dataset_job_id
+ schema:
+ type: integer
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: filter
+ schema:
+ type: string
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of dataset job stages
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJobStageRef'
+ """
+ with db.session_scope() as session:
+ query = session.query(DatasetJobStage).filter(DatasetJobStage.project_id == project_id).filter(
+ DatasetJobStage.dataset_job_id == dataset_job_id)
+ if filter_exp is not None:
+ try:
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ try:
+ if sorter_exp is not None:
+ sorter_exp = sorting.parse_expression(sorter_exp)
+ else:
+ # default sort is created_at desc
+ sorter_exp = sorting.SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ pagination = paginate(query=query, page=page, page_size=page_size)
+ return make_flask_response(
+ data=[dataset_job_stage.to_ref() for dataset_job_stage in pagination.get_items()],
+ page_meta=pagination.get_metadata())
- @jwt_required()
- def get(self):
- # TODO: consider the security factor
- if 'directory' in request.args:
- directory = request.args['directory']
- else:
- directory = os.path.join(current_app.config.get('STORAGE_ROOT'),
- 'upload')
- files = self._file_manager.ls(directory, recursive=True)
- return {'data': [dict(file._asdict()) for file in files]}
+
+class DatasetJobStageApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, dataset_job_id: int, dataset_job_stage_id: int):
+ """Get details of given dataset job stage
+ ---
+ tags:
+ - dataset
+ description: Get details of given dataset job stage
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: dataset_job_id
+ schema:
+ type: integer
+ - in: path
+ name: dataset_job_stage_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: dataset job stage details
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DatasetJobStage'
+ """
+ with db.session_scope() as session:
+ dataset_job_stage: DatasetJobStage = session.query(DatasetJobStage).filter(
+ DatasetJobStage.project_id == project_id).filter(
+ DatasetJobStage.dataset_job_id == dataset_job_id).filter(
+ DatasetJobStage.id == dataset_job_stage_id).first()
+ if not dataset_job_stage:
+ raise NotFoundException(f'Failed to find dataset job stage: {dataset_job_stage_id}')
+ dataset_job_stage_pb = dataset_job_stage.to_proto()
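+ # for a non-coordinator stage, fetch global_configs from the coordinator via JobServiceClient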
+ if not dataset_job_stage.is_coordinator():
+ participant = session.query(Participant).get(dataset_job_stage.coordinator_id)
+ client = JobServiceClient.from_project_and_participant(participant.domain_name,
+ dataset_job_stage.project.name)
+ response = client.get_dataset_job_stage(dataset_job_stage_uuid=dataset_job_stage.uuid)
+ dataset_job_stage_pb.global_configs.MergeFrom(response.dataset_job_stage.global_configs)
+
+ return make_flask_response(dataset_job_stage_pb)
def initialize_dataset_apis(api: Api):
api.add_resource(DatasetsApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<int:dataset_id>')
+ api.add_resource(DatasetPublishApi, '/datasets/<int:dataset_id>:publish')
+ api.add_resource(DatasetAuthorizehApi, '/datasets/<int:dataset_id>:authorize')
+ api.add_resource(DatasetFlushAuthStatusApi, '/datasets/<int:dataset_id>:flush_auth_status')
api.add_resource(BatchesApi, '/datasets/<int:dataset_id>/batches')
+ api.add_resource(BatchApi, '/datasets/<int:dataset_id>/batches/<int:data_batch_id>')
+ api.add_resource(ChildrenDatasetsApi, '/datasets/<int:dataset_id>/children_datasets')
+ api.add_resource(BatchAnalyzeApi, '/datasets/<int:dataset_id>/batches/<int:data_batch_id>:analyze')
+ api.add_resource(BatchMetricsApi, '/datasets/<int:dataset_id>/batches/<int:data_batch_id>/feature_metrics')
+ api.add_resource(BatchRerunApi, '/datasets/<int:dataset_id>/batches/<int:data_batch_id>:rerun')
api.add_resource(DatasetPreviewApi, '/datasets/<int:dataset_id>/preview')
- api.add_resource(DatasetMetricsApi,
- '/datasets/<int:dataset_id>/feature_metrics')
- api.add_resource(FilesApi, '/files')
+ api.add_resource(DatasetLedgerApi, '/datasets/<int:dataset_id>/ledger')
+ api.add_resource(DatasetExportApi, '/datasets/<int:dataset_id>:export')
+ api.add_resource(DatasetStateFixtApi, '/datasets/<int:dataset_id>:state_fix')
+
+ api.add_resource(DataSourcesApi, '/data_sources')
+ api.add_resource(DataSourceApi, '/data_sources/<int:data_source_id>')
+ api.add_resource(DataSourceCheckConnectionApi, '/data_sources:check_connection')
+ api.add_resource(DataSourceTreeApi, '/data_sources/<int:data_source_id>/tree')
+
+ api.add_resource(ParticipantDatasetsApi, '/project/<int:project_id>/participant_datasets')
+
+ api.add_resource(DatasetJobDefinitionApi, '/dataset_job_definitions/<string:dataset_job_kind>')
+ api.add_resource(DatasetJobsApi, '/projects/<int:project_id>/dataset_jobs')
+ api.add_resource(DatasetJobApi, '/projects/<int:project_id>/dataset_jobs/<int:dataset_job_id>')
+ api.add_resource(DatasetJobStopApi, '/projects/<int:project_id>/dataset_jobs/<int:dataset_job_id>:stop')
+ api.add_resource(DatasetJobStopSchedulerApi,
+ '/projects/<int:project_id>/dataset_jobs/<int:dataset_job_id>:stop_scheduler')
+
+ api.add_resource(DatasetJobStagesApi,
+ '/projects/<int:project_id>/dataset_jobs/<int:dataset_job_id>/dataset_job_stages')
+ api.add_resource(
+ DatasetJobStageApi,
+ '/projects/<int:project_id>/dataset_jobs/<int:dataset_job_id>/dataset_job_stages/<int:dataset_job_stage_id>')
+
+ schema_manager.append(DataSourceParameter)
+ schema_manager.append(DatasetJobConfigParameter)
+ schema_manager.append(DatasetParameter)
+ schema_manager.append(BatchParameter)
+ schema_manager.append(DatasetJobParameter)
+ schema_manager.append(TimeRangeParameter)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/apis_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/apis_test.py
new file mode 100644
index 000000000..9c5c84861
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/apis_test.py
@@ -0,0 +1,3175 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import os
+import tempfile
+from datetime import datetime, timedelta
+from http import HTTPStatus
+from pathlib import Path
+import urllib
+import unittest
+from unittest.mock import patch, MagicMock, ANY, PropertyMock
+from google.protobuf.struct_pb2 import Value
+
+from collections import namedtuple
+from marshmallow.exceptions import ValidationError
+from tensorflow.io import gfile
+
+from envs import Envs
+from testing.common import BaseTestCase
+from testing.dataset import FakeDatasetJobConfiger, FakeFederatedDatasetJobConfiger
+
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.apis import _parse_data_source_url, _path_authority_validator
+from fedlearner_webconsole.dataset.models import (Dataset, DatasetJob, DatasetJobKind, DatasetJobSchedulerState,
+ DatasetJobStage, ImportType, DatasetKindV2, DatasetSchemaChecker,
+ DatasetJobState, StoreFormat, DatasetType, ResourceState, DataBatch,
+ DatasetFormat, BatchState, DataSourceType, DataSource)
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.participant.models import ProjectParticipant, Participant
+from fedlearner_webconsole.proto import dataset_pb2, service_pb2, setting_pb2
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo, ParticipantsInfo
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import ListDatasetsResponse
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.algorithm_pb2 import FileTreeNode
+
+FakeFileStatistics = namedtuple('FakeFileStatistics', ['length', 'mtime_nsec', 'is_directory'])
+
+
+def fake_get_items(*args, **kwargs):
+ return {}, []
+
+
+def fake_export_task_result(*args, **kwargs):
+ return {}
+
+
+def fake_isdir(*args, **kwargs):
+ path = kwargs.get('path')
+ return (path in [
+ 'file:///data/test', 'hdfs:///home/', 'hdfs:///home/20220801', 'hdfs:///home/20220802',
+ 'hdfs:///home/20220803-15', 'hdfs:///home/2022080316'
+ ])
+
+
+class DatasetApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ STORAGE_ROOT = tempfile.gettempdir()
+
+ def setUp(self):
+ super().setUp()
+ self._storage_root = Envs.STORAGE_ROOT
+ self._file_manager = FileManager()
+ with db.session_scope() as session:
+ project = Project(name='test-project')
+ session.add(project)
+ session.flush([project])
+ participant = Participant(id=project.id, name='test_participant', domain_name='fake_domain_name')
+ session.add(participant)
+ project_participant = ProjectParticipant(project_id=project.id, participant_id=participant.id)
+ session.add(project_participant)
+ workflow = Workflow(state=WorkflowState.COMPLETED, name='workflow_generate_by_dataset_job')
+ session.add(workflow)
+
+ session.commit()
+
+ with db.session_scope() as session:
+ self.default_dataset1 = Dataset(name='default dataset1',
+ creator_username='test',
+ uuid='default dataset1 uuid',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ created_at=datetime(2012, 1, 14, 12, 0, 5))
+ meta_info = dataset_pb2.DatasetMetaInfo(value=100,
+ schema_checkers=[
+ DatasetSchemaChecker.RAW_ID_CHECKER.value,
+ DatasetSchemaChecker.NUMERIC_COLUMNS_CHECKER.value
+ ])
+ self.default_dataset1.set_meta_info(meta_info)
+ session.add(self.default_dataset1)
+ session.flush()
+ default_dataset_job_1 = DatasetJob(workflow_id=workflow.id,
+ uuid=resource_uuid(),
+ project_id=project.id,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=100,
+ output_dataset_id=self.default_dataset1.id,
+ state=DatasetJobState.FAILED)
+ session.add(default_dataset_job_1)
+ session.commit()
+ with db.session_scope() as session:
+ self.default_dataset2 = Dataset(name='default dataset2',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment2',
+ path=os.path.join(tempfile.gettempdir(), 'dataset/123'),
+ project_id=project.id,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_format=DatasetFormat.TABULAR.value,
+ created_at=datetime(2012, 1, 14, 12, 0, 6))
+ session.add(self.default_dataset2)
+ session.flush([self.default_dataset2])
+ data_batch = DataBatch(event_time=datetime.now(),
+ comment='comment',
+ state=BatchState.NEW,
+ dataset_id=self.default_dataset2.id,
+ path='/data/dataset/123/batch_test_batch')
+ session.add(data_batch)
+ session.flush()
+ default_dataset_job_2 = DatasetJob(workflow_id=workflow.id,
+ uuid=resource_uuid(),
+ project_id=project.id,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ input_dataset_id=100,
+ output_dataset_id=self.default_dataset2.id)
+ session.add(default_dataset_job_2)
+ default_dataset_job_3 = DatasetJob(workflow_id=workflow.id,
+ uuid=resource_uuid(),
+ project_id=project.id,
+ kind=DatasetJobKind.ANALYZER,
+ input_dataset_id=self.default_dataset2.id,
+ output_dataset_id=self.default_dataset2.id)
+ session.add(default_dataset_job_3)
+ session.commit()
+
+ with db.session_scope() as session:
+ workflow = Workflow(id=100, state=WorkflowState.COMPLETED, name='fake_workflow')
+ dataset = Dataset(id=3,
+ name='dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='comment',
+ path=os.path.join(tempfile.gettempdir(), 'dataset/321'),
+ project_id=3,
+ created_at=datetime(2012, 1, 14, 12, 0, 7))
+ session.add(workflow)
+ session.add(dataset)
+ session.flush()
+ default_dataset_job_4 = DatasetJob(workflow_id=workflow.id,
+ uuid=resource_uuid(),
+ project_id=project.id,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ input_dataset_id=100,
+ output_dataset_id=dataset.id)
+ session.add(default_dataset_job_4)
+ session.commit()
+
+ def test_get_dataset(self):
+ get_response = self.get_helper(f'/api/v2/datasets/{self.default_dataset1.id}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ get_response,
+ {
+ 'id': 1,
+ 'name': 'default dataset1',
+ 'creator_username': 'test',
+ 'comment': 'test comment1',
+ 'path': '/data/dataset/123',
+ 'deleted_at': 0,
+ 'dataset_kind': 'RAW',
+ 'is_published': False,
+ 'project_id': 1,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'num_feature': 0,
+ 'num_example': 0,
+ 'state_frontend': ResourceState.FAILED.value,
+ 'file_size': 0,
+ 'parent_dataset_job_id': 1,
+ 'data_source': ANY,
+ 'workflow_id': 1,
+ 'value': 100,
+ 'schema_checkers': ['RAW_ID_CHECKER', 'NUMERIC_COLUMNS_CHECKER'],
+ 'dataset_type': 'STREAMING',
+ 'import_type': 'COPY',
+ 'store_format': 'TFRECORDS',
+ 'analyzer_dataset_job_id': 0,
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ },
+ ignore_fields=[
+ 'created_at',
+ 'updated_at',
+ 'uuid',
+ ],
+ )
+
+ def test_get_internal_processed_dataset(self):
+ with db.session_scope() as session:
+ default_dataset = Dataset(id=10,
+ uuid=resource_uuid(),
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ dataset_kind=DatasetKindV2.INTERNAL_PROCESSED,
+ is_published=False,
+ auth_status=AuthStatus.AUTHORIZED)
+ session.add(default_dataset)
+ session.commit()
+
+ get_response = self.get_helper('/api/v2/datasets/10')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ get_response,
+ {
+ 'id': 10,
+ 'project_id': 1,
+ 'name': 'default dataset',
+ 'path': '/data/dataset/123',
+ 'comment': 'test comment',
+ 'dataset_format': 'TABULAR',
+ 'state_frontend': 'SUCCEEDED',
+ 'dataset_kind': 'INTERNAL_PROCESSED',
+ 'workflow_id': 0,
+ 'data_source': '',
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'deleted_at': 0,
+ 'parent_dataset_job_id': 0,
+ 'analyzer_dataset_job_id': 0,
+ 'is_published': False,
+ 'value': 0,
+ 'schema_checkers': [],
+ 'creator_username': '',
+ 'import_type': 'COPY',
+ 'dataset_type': 'PSI',
+ 'store_format': 'TFRECORDS',
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'AUTHORIZED',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ },
+ ignore_fields=[
+ 'uuid',
+ 'created_at',
+ 'updated_at',
+ ],
+ )
+
+ def test_get_dataset_not_found(self):
+ get_response = self.get_helper('/api/v2/datasets/10086')
+ self.assertEqual(get_response.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_get_datasets(self):
+ with db.session_scope() as session:
+ default_dataset_job_4 = session.query(DatasetJob).get(4)
+ default_dataset_job_4.kind = DatasetJobKind.ANALYZER
+ default_dataset_job_4.input_dataset_id = 3
+
+ default_data_source = DataSource(id=4,
+ name='default data_source',
+ creator_username='test',
+ uuid='default data_source uuid',
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.SOURCE,
+ created_at=datetime(2012, 1, 14, 12, 0, 1))
+ session.add(default_data_source)
+ session.commit()
+ get_response = self.get_helper('/api/v2/datasets')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(datasets, [{
+ 'id': 3,
+ 'project_id': 3,
+ 'comment': 'comment',
+ 'created_at': 1326542407,
+ 'creator_username': 'test',
+ 'data_source': ANY,
+ 'dataset_format': 'TABULAR',
+ 'dataset_kind': 'RAW',
+ 'dataset_type': 'STREAMING',
+ 'file_size': 0,
+ 'import_type': 'COPY',
+ 'is_published': False,
+ 'name': 'dataset',
+ 'num_example': 0,
+ 'path': ANY,
+ 'state_frontend': 'FAILED',
+ 'store_format': 'TFRECORDS',
+ 'total_value': 0,
+ 'uuid': '',
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ }, {
+ 'id': 2,
+ 'project_id': 1,
+ 'name': 'default dataset2',
+ 'creator_username': 'test',
+ 'created_at': 1326542406,
+ 'path': ANY,
+ 'dataset_format': 'TABULAR',
+ 'comment': 'test comment2',
+ 'state_frontend': 'PENDING',
+ 'dataset_kind': 'PROCESSED',
+ 'data_source': ANY,
+ 'file_size': 0,
+ 'is_published': False,
+ 'num_example': 0,
+ 'uuid': '',
+ 'total_value': 0,
+ 'store_format': 'TFRECORDS',
+ 'dataset_type': 'STREAMING',
+ 'import_type': 'COPY',
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ }, {
+ 'id': 1,
+ 'project_id': 1,
+ 'name': 'default dataset1',
+ 'creator_username': 'test',
+ 'created_at': 1326542405,
+ 'path': '/data/dataset/123',
+ 'dataset_format': 'TABULAR',
+ 'comment': 'test comment1',
+ 'state_frontend': 'FAILED',
+ 'uuid': 'default dataset1 uuid',
+ 'dataset_kind': 'RAW',
+ 'file_size': 0,
+ 'is_published': False,
+ 'num_example': 0,
+ 'data_source': ANY,
+ 'total_value': 0,
+ 'store_format': 'TFRECORDS',
+ 'dataset_type': 'STREAMING',
+ 'import_type': 'COPY',
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ }, {
+ 'id': 4,
+ 'project_id': 1,
+ 'name': 'default data_source',
+ 'creator_username': 'test',
+ 'comment': 'test comment1',
+ 'created_at': 1326542401,
+ 'path': '/data/dataset/123',
+ 'data_source': ANY,
+ 'dataset_format': 'TABULAR',
+ 'dataset_kind': 'SOURCE',
+ 'dataset_type': 'PSI',
+ 'file_size': 0,
+ 'import_type': 'COPY',
+ 'is_published': False,
+ 'num_example': 0,
+ 'state_frontend': 'FAILED',
+ 'store_format': 'TFRECORDS',
+ 'total_value': 0,
+ 'uuid': 'default data_source uuid',
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ }])
+ self.assertEqual(
+ json.loads(get_response.data).get('page_meta'), {
+ 'current_page': 1,
+ 'page_size': 10,
+ 'total_pages': 1,
+ 'total_items': 4
+ })
+
+ with db.session_scope() as session:
+ default_dataset_job_4 = session.query(DatasetJob).get(4)
+ default_dataset_job_4.kind = DatasetJobKind.DATA_ALIGNMENT
+ default_dataset_job_4.input_dataset_id = 100
+ session.commit()
+
+ # test sorter
+ sorter_param = urllib.parse.quote('created_at asc')
+ get_response = self.get_helper(f'/api/v2/datasets?order_by={sorter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 4)
+ self.assertEqual([dataset.get('id') for dataset in datasets], [4, 1, 2, 3])
+
+ fake_sorter_param = urllib.parse.quote('fake_time asc')
+ get_response = self.get_helper(f'/api/v2/datasets?order_by={fake_sorter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.BAD_REQUEST)
+
+ # test filter
+ filter_param = urllib.parse.quote('(and(project_id=1)(name~="default"))')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 3)
+ self.assertEqual([dataset.get('name') for dataset in datasets],
+ ['default dataset2', 'default dataset1', 'default data_source'])
+
+ filter_param = urllib.parse.quote('(and(project_id=1)(dataset_format:["TABULAR"])(dataset_kind:["RAW"]))')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('name'), 'default dataset1')
+
+ filter_param = urllib.parse.quote('(dataset_format="IMAGE")')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 0)
+
+ filter_param = urllib.parse.quote('(dataset_format="UNKOWN")')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.BAD_REQUEST)
+
+ filter_param = urllib.parse.quote('(is_published=false)')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 4)
+
+ # test state_frontend
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.state = DatasetJobState.SUCCEEDED
+ session.commit()
+ get_response = self.get_helper('/api/v2/datasets?state_frontend=SUCCEEDED')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('name'), 'default dataset1')
+
+ def test_get_datasets_by_publish_frontend_state(self):
+ with db.session_scope() as session:
+ default_dataset_1 = session.query(Dataset).get(1)
+ default_dataset_1.ticket_status = None
+ default_dataset_1.is_published = False
+ default_dataset_2 = session.query(Dataset).get(2)
+ default_dataset_2.ticket_status = TicketStatus.PENDING
+ default_dataset_2.is_published = True
+ default_dataset_3 = session.query(Dataset).get(3)
+ default_dataset_3.ticket_status = TicketStatus.APPROVED
+ default_dataset_3.is_published = True
+ session.commit()
+ filter_param = urllib.parse.quote('(publish_frontend_state="UNPUBLISHED")')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('id'), 1)
+
+ filter_param = urllib.parse.quote('(publish_frontend_state="TICKET_PENDING")')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('id'), 2)
+
+ filter_param = urllib.parse.quote('(publish_frontend_state="PUBLISHED")')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('id'), 3)
+
+ def test_get_datasets_by_auth_state(self):
+ with db.session_scope() as session:
+ default_dataset_1 = session.query(Dataset).get(1)
+ default_dataset_1.auth_status = AuthStatus.AUTHORIZED
+ default_dataset_2 = session.query(Dataset).get(2)
+ default_dataset_2.auth_status = AuthStatus.PENDING
+ default_dataset_3 = session.query(Dataset).get(3)
+ default_dataset_3.auth_status = AuthStatus.WITHDRAW
+ session.commit()
+ filter_param = urllib.parse.quote('(auth_status:["AUTHORIZED", "WITHDRAW"])')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 2)
+ self.assertEqual([dataset.get('id') for dataset in datasets], [3, 1])
+
+ with db.session_scope() as session:
+ default_dataset_2 = session.query(Dataset).get(2)
+ default_dataset_2.auth_status = None
+ session.commit()
+ filter_param = urllib.parse.quote('(auth_status:["AUTHORIZED"])')
+ get_response = self.get_helper(f'/api/v2/datasets?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 2)
+ self.assertEqual([dataset.get('id') for dataset in datasets], [2, 1])
+
+ def test_get_internal_processed_datasets(self):
+ with db.session_scope() as session:
+ internal_processed_dataset = Dataset(id=10,
+ uuid=resource_uuid(),
+ name='internal_processed dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.INTERNAL_PROCESSED,
+ is_published=False,
+ auth_status=AuthStatus.AUTHORIZED)
+ session.add(internal_processed_dataset)
+
+ dataset_job = session.query(DatasetJob).get(2)
+ dataset_job.state = DatasetJobState.SUCCEEDED
+
+ session.commit()
+
+ get_response = self.get_helper('/api/v2/datasets?state_frontend=SUCCEEDED')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 2)
+ self.assertCountEqual([dataset.get('name') for dataset in datasets],
+ ['default dataset2', 'internal_processed dataset'])
+
+ filter_param = urllib.parse.quote('(dataset_kind:["PROCESSED"])')
+ get_response = self.get_helper(f'/api/v2/datasets?state_frontend=SUCCEEDED&filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('name'), 'default dataset2')
+
+ get_response = self.get_helper('/api/v2/datasets?state_frontend=FAILED')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('name'), 'default dataset1')
+
+ def test_get_datasets_by_time_range(self):
+ with db.session_scope() as session:
+ default_dataset_1 = session.query(Dataset).get(1)
+ default_dataset_1.parent_dataset_job.time_range = timedelta(days=1)
+ default_dataset_2 = session.query(Dataset).get(2)
+ default_dataset_2.parent_dataset_job.time_range = timedelta(hours=1)
+ session.commit()
+ get_response = self.get_helper('/api/v2/datasets?cron_interval=DAYS')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('id'), 1)
+ get_response = self.get_helper('/api/v2/datasets?cron_interval=HOURS')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 1)
+ self.assertEqual(datasets[0].get('id'), 2)
+ get_response = self.get_helper('/api/v2/datasets')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ datasets = self.get_response_data(get_response)
+ self.assertEqual(len(datasets), 3)
+
+ def test_change_dataset_comment(self):
+ get_response = self.patch_helper(f'/api/v2/datasets/{self.default_dataset1.id}', data={'comment': 'test api'})
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(self.default_dataset1.id)
+ self.assertEqual(dataset.comment, 'test api')
+
+ def test_preview_dataset(self):
+ with db.session_scope() as session:
+ tmp_path = tempfile.gettempdir()
+ self.batch_path = os.path.join(tmp_path, 'dataset/20211228_161352_train-ds/batch/20211228_081351')
+ self.default_databatch1 = DataBatch(name='20220101',
+ id=111,
+ event_time=datetime.now(),
+ comment='comment',
+ state=BatchState.NEW,
+ dataset_id=1,
+ path=self.batch_path)
+ session.add(self.default_databatch1)
+ session.commit()
+ with db.session_scope() as session:
+ tmp_path = tempfile.gettempdir()
+ self.batch_path = os.path.join(tmp_path, 'dataset/20211228_161352_train-ds/batch/20211228_081352')
+ self.default_databatch2 = DataBatch(name='20220102',
+ id=222,
+ event_time=datetime.now(),
+ comment='comment',
+ state=BatchState.NEW,
+ dataset_id=2,
+ path=self.batch_path)
+ session.add(self.default_databatch2)
+ session.commit()
+ meta_file = DatasetDirectory(dataset_path=self.default_dataset2.path).batch_meta_file(batch_name='20220101')
+ gfile.makedirs(meta_file.split('/_META')[0])
+ meta_data = {
+ 'dtypes': [{
+ 'key': 'f01',
+ 'value': 'bigint'
+ }],
+ 'sample': [
+ [
+ 1,
+ ],
+ [
+ 0,
+ ],
+ ],
+ 'count': 0,
+ 'features': {
+ 'f01': {
+ 'count': '2',
+ 'mean': '0.0015716767309123998',
+ 'stddev': '0.03961485047808605',
+ 'min': '0',
+ 'max': '1',
+ 'missing_count': '0'
+ },
+ },
+ 'hist': {
+ 'f01': {
+ 'x': [
+ 0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6000000000000001, 0.7000000000000001, 0.8, 0.9,
+ 1
+ ],
+ 'y': [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
+ },
+ },
+ }
+ with gfile.GFile(meta_file, 'w') as f:
+ f.write(json.dumps(meta_data))
+
+ response = self.get_helper('/api/v2/datasets/2/preview?batch_id=111')
+ self.assertEqual(response.status_code, 200)
+ preview_data = self.get_response_data(response)
+ golden_preview = {
+ 'dtypes': [{
+ 'key': 'f01',
+ 'value': 'bigint'
+ }],
+ 'sample': [
+ [1],
+ [0],
+ ],
+ 'num_example': 0,
+ 'metrics': {
+ 'f01': {
+ 'count': '2',
+ 'mean': '0.0015716767309123998',
+ 'stddev': '0.03961485047808605',
+ 'min': '0',
+ 'max': '1',
+ 'missing_count': '0'
+ },
+ },
+ }
+ self.assertEqual(preview_data, golden_preview, 'should have preview data')
+
+ @patch('fedlearner_webconsole.dataset.services.get_dataset_path')
+ def test_post_raw_datasets(self, mock_get_dataset_path: MagicMock):
+ name = 'test_dataset'
+ dataset_path = os.path.join(self._storage_root, 'dataset/20200608_060606_test-post-dataset')
+ mock_get_dataset_path.return_value = dataset_path
+ dataset_type = DatasetType.PSI.value
+ comment = 'test comment'
+ create_response = self.post_helper('/api/v2/datasets',
+ data={
+ 'name': name,
+ 'dataset_type': dataset_type,
+ 'comment': comment,
+ 'project_id': 1,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'kind': DatasetKindV2.RAW.value,
+ 'need_publish': True,
+ 'value': 100,
+ })
+ self.assertEqual(create_response.status_code, HTTPStatus.CREATED)
+
+ self.assertResponseDataEqual(
+ create_response,
+ {
+ 'id': ANY,
+ 'name': 'test_dataset',
+ 'creator_username': 'ada',
+ 'comment': comment,
+ 'path': dataset_path,
+ 'deleted_at': 0,
+ 'data_source': '',
+ 'project_id': 1,
+ 'dataset_kind': DatasetKindV2.RAW.name,
+ 'is_published': False,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'state_frontend': ResourceState.FAILED.value,
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'parent_dataset_job_id': 0,
+ 'workflow_id': 0,
+ 'value': 100,
+ 'schema_checkers': [],
+ 'dataset_type': DatasetType.PSI.value,
+ 'import_type': ImportType.COPY.value,
+ 'store_format': StoreFormat.TFRECORDS.value,
+ 'analyzer_dataset_job_id': 0,
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'AUTHORIZED',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ },
+ ignore_fields=['created_at', 'updated_at', 'uuid'],
+ )
+ with db.session_scope() as session:
+ data_batch: DataBatch = session.query(DataBatch).outerjoin(
+ Dataset, Dataset.id == DataBatch.dataset_id).filter(Dataset.name == name).first()
+ self.assertIsNone(data_batch)
+
+ @patch('fedlearner_webconsole.dataset.services.get_dataset_path')
+ def test_post_processed_datasets(self, mock_get_dataset_path: MagicMock):
+ name = 'test_dataset'
+ dataset_path = os.path.join(self._storage_root, 'dataset/20200608_060606_test-post-dataset')
+ mock_get_dataset_path.return_value = dataset_path
+ dataset_type = DatasetType.PSI.value
+ comment = 'test comment'
+
+ # test bad request
+ create_response = self.post_helper('/api/v2/datasets',
+ data={
+ 'name': name,
+ 'dataset_type': dataset_type,
+ 'comment': comment,
+ 'project_id': 1,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'kind': DatasetKindV2.PROCESSED.value,
+ 'need_publish': True,
+ 'value': 100,
+ 'is_published': False
+ })
+ self.assertEqual(create_response.status_code, HTTPStatus.BAD_REQUEST)
+
+ # test pass
+ create_response = self.post_helper('/api/v2/datasets',
+ data={
+ 'name': name,
+ 'dataset_type': dataset_type,
+ 'comment': comment,
+ 'project_id': 1,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'kind': DatasetKindV2.PROCESSED.value,
+ 'is_published': True,
+ 'value': 100,
+ })
+ self.assertEqual(create_response.status_code, HTTPStatus.CREATED)
+
+ self.assertResponseDataEqual(
+ create_response,
+ {
+ 'id': ANY,
+ 'name': 'test_dataset',
+ 'creator_username': 'ada',
+ 'comment': comment,
+ 'path': dataset_path,
+ 'deleted_at': 0,
+ 'data_source': '',
+ 'project_id': 1,
+ 'dataset_kind': DatasetKindV2.PROCESSED.name,
+ 'is_published': True,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'state_frontend': ResourceState.FAILED.value,
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'parent_dataset_job_id': 0,
+ 'workflow_id': 0,
+ 'value': 100,
+ 'schema_checkers': [],
+ 'dataset_type': DatasetType.PSI.value,
+ 'import_type': ImportType.COPY.value,
+ 'store_format': StoreFormat.TFRECORDS.value,
+ 'analyzer_dataset_job_id': 0,
+ 'publish_frontend_state': 'PUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'AUTHORIZED',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ },
+ ignore_fields=['created_at', 'updated_at', 'uuid'],
+ )
+ with db.session_scope() as session:
+ data_batch: DataBatch = session.query(DataBatch).outerjoin(
+ Dataset, Dataset.id == DataBatch.dataset_id).filter(Dataset.name == name).first()
+ self.assertIsNone(data_batch)
+ dataset = session.query(Dataset).filter(Dataset.name == name).first()
+ self.assertEqual(dataset.ticket_status, TicketStatus.APPROVED)
+
+ @patch('fedlearner_webconsole.dataset.services.get_dataset_path')
+ def test_post_datasets_with_checkers(self, mock_get_dataset_path: MagicMock):
+ dataset_path = os.path.join(self._storage_root, 'dataset/20200608_060606_test-post-dataset')
+ mock_get_dataset_path.return_value = dataset_path
+ create_response = self.post_helper('/api/v2/datasets',
+ data={
+ 'name': 'fake_dataset',
+ 'comment': 'comment',
+ 'project_id': 1,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'kind': DatasetKindV2.RAW.value,
+ 'schema_checkers': ['RAW_ID_CHECKER', 'NUMERIC_COLUMNS_CHECKER']
+ })
+ self.assertEqual(create_response.status_code, HTTPStatus.CREATED)
+
+ self.assertResponseDataEqual(
+ create_response,
+ {
+ 'id': 4,
+ 'name': 'fake_dataset',
+ 'creator_username': 'ada',
+ 'comment': 'comment',
+ 'path': dataset_path,
+ 'deleted_at': 0,
+ 'data_source': '',
+ 'project_id': 1,
+ 'dataset_kind': DatasetKindV2.RAW.name,
+ 'is_published': False,
+ 'dataset_format': DatasetFormat.TABULAR.name,
+ 'state_frontend': ResourceState.FAILED.value,
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'parent_dataset_job_id': 0,
+ 'workflow_id': 0,
+ 'value': 0,
+ 'schema_checkers': ['RAW_ID_CHECKER', 'NUMERIC_COLUMNS_CHECKER'],
+ 'dataset_type': 'PSI',
+ 'import_type': 'COPY',
+ 'store_format': 'TFRECORDS',
+ 'analyzer_dataset_job_id': 0,
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'AUTHORIZED',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ },
+ ignore_fields=['created_at', 'updated_at', 'uuid'],
+ )
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(4)
+ meta_info = dataset.get_meta_info()
+ self.assertEqual(list(meta_info.schema_checkers), ['RAW_ID_CHECKER', 'NUMERIC_COLUMNS_CHECKER'])
+
+ def _fake_schema_check_test_data(self):
+ # schema check test
+ self.dataset_dir = tempfile.mkdtemp()
+ self.dataset_csv = Path(self.dataset_dir).joinpath('test.csv')
+ self.dataset_json = Path(self.dataset_dir).joinpath('validation_jsonschema.json')
+ self.error_dir = Path(self.dataset_dir).joinpath('error')
+ self.error_dir.mkdir()
+ self.error_json = self.error_dir.joinpath('schema_error.json')
+
+ with db.session_scope() as session:
+ self.schema_check_dataset = Dataset(name='schema_check_dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='schema check dataset',
+ path=str(self.dataset_dir),
+ project_id=1)
+ session.add(self.schema_check_dataset)
+ session.flush()
+ self.schema_check_batch = DataBatch(dataset_id=self.schema_check_dataset.id,
+ event_time=datetime(2021, 10, 28, 16, 37, 37),
+ comment='schema check batch')
+ session.add(self.schema_check_batch)
+ session.commit()
+
+ def __del__(self):
+ # delete the dataset path created by the test_post_*_datasets tests
+ dataset_path = os.path.join(self._storage_root, 'dataset/20200608_060606_test-post-dataset')
+ if self._file_manager.isdir(dataset_path):
+ self._file_manager.remove(dataset_path)
+
+
+class DatasetExportApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test_project')
+ session.add(project)
+ session.flush([project])
+
+ dataset = Dataset(id=1,
+ name='test_dataset',
+ dataset_type=DatasetType.PSI,
+ uuid='dataset uuid',
+ comment='comment',
+ path='/data/dataset/321',
+ project_id=project.id,
+ dataset_format=DatasetFormat.NONE_STRUCTURED.value,
+ store_format=StoreFormat.UNKNOWN,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ data_batch = DataBatch(id=1, name='0', comment='comment', dataset_id=1, path='/data/dataset/321/batch/0')
+ session.add_all([dataset, data_batch])
+
+ streaming_dataset = Dataset(id=2,
+ name='test_streaming_dataset',
+ dataset_type=DatasetType.STREAMING,
+ uuid='streaming dataset uuid',
+ comment='comment',
+ path='/data/dataset/streaming_dataset',
+ project_id=project.id,
+ dataset_format=DatasetFormat.TABULAR.value,
+ store_format=StoreFormat.TFRECORDS,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ streaming_data_batch_1 = DataBatch(id=2,
+ name='20220101',
+ comment='comment',
+ dataset_id=2,
+ path='/data/dataset/321/batch/20220101',
+ event_time=datetime(2022, 1, 1))
+ streaming_data_batch_2 = DataBatch(id=3,
+ name='20220102',
+ comment='comment',
+ dataset_id=2,
+ path='/data/dataset/321/batch/20220102',
+ event_time=datetime(2022, 1, 2))
+ session.add_all([streaming_dataset, streaming_data_batch_1, streaming_data_batch_2])
+
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.apis.SettingService.get_system_info',
+ lambda: setting_pb2.SystemInfo(pure_domain_name='test_domain'))
+ @patch('fedlearner_webconsole.dataset.apis.Envs.STORAGE_ROOT', '/data')
+ @patch('fedlearner_webconsole.utils.file_manager.FileManager.isdir', lambda *args: True)
+ @patch('fedlearner_webconsole.dataset.models.DataBatch.is_available', lambda _: True)
+ def test_export_dataset_none_streaming(self):
+ export_path = '/data/user_home/export_dataset'
+ resp = self.post_helper('/api/v2/datasets/1:export', {
+ 'export_path': export_path,
+ 'batch_id': 1,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ resp_data = self.get_response_data(resp)
+ self.assertEqual(resp_data, {'export_dataset_id': 3, 'dataset_job_id': 1})
+ export_dataset_name = 'export-test_dataset-0-0'
+ with db.session_scope() as session:
+ export_dataset: Dataset = session.query(Dataset).filter(Dataset.name == export_dataset_name).first()
+ self.assertEqual(export_dataset.dataset_kind, DatasetKindV2.EXPORTED)
+ self.assertEqual(export_dataset.store_format, StoreFormat.UNKNOWN)
+ self.assertEqual(export_dataset.dataset_type, DatasetType.PSI)
+ self.assertEqual(export_dataset.path, 'file://' + export_path)
+ self.assertFalse(export_dataset.is_published)
+ batch = export_dataset.get_single_batch()
+ self.assertEqual(batch.batch_name, '0')
+ self.assertEqual(batch.path, 'file://' + export_path + '/batch/0')
+ dataset_job: DatasetJob = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.kind, DatasetJobKind.EXPORT)
+ self.assertEqual(
+ dataset_job.get_global_configs(),
+ dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={'test_domain': dataset_pb2.DatasetJobConfig(dataset_uuid='dataset uuid')}))
+ dataset_job_stages = dataset_job.dataset_job_stages
+ self.assertEqual(len(dataset_job_stages), 1)
+ self.assertEqual(dataset_job_stages[0].data_batch_id, batch.id)
+
+ @patch('fedlearner_webconsole.dataset.apis.SettingService.get_system_info',
+ lambda: setting_pb2.SystemInfo(pure_domain_name='test_domain'))
+ @patch('fedlearner_webconsole.dataset.apis.Envs.STORAGE_ROOT', '/data')
+ @patch('fedlearner_webconsole.utils.file_manager.FileManager.isdir', lambda *args: True)
+ @patch('fedlearner_webconsole.dataset.models.DataBatch.is_available', lambda _: True)
+ def test_export_dataset_streaming(self):
+ export_path = '/data/user_home/export_dataset'
+ export_path_with_space = ' ' + export_path + ' '
+ resp = self.post_helper('/api/v2/datasets/2:export', {'export_path': export_path_with_space})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ resp_data = self.get_response_data(resp)
+ self.assertEqual(resp_data, {'export_dataset_id': 3, 'dataset_job_id': 1})
+ export_dataset_name = 'export-test_streaming_dataset-0'
+ with db.session_scope() as session:
+ export_dataset: Dataset = session.query(Dataset).filter(Dataset.name == export_dataset_name).first()
+ self.assertEqual(export_dataset.dataset_kind, DatasetKindV2.EXPORTED)
+ self.assertEqual(export_dataset.store_format, StoreFormat.CSV)
+ self.assertEqual(export_dataset.dataset_type, DatasetType.STREAMING)
+ self.assertEqual(export_dataset.path, 'file://' + export_path)
+ self.assertFalse(export_dataset.is_published)
+ batches = export_dataset.data_batches
+ self.assertEqual(len(batches), 2)
+ self.assertEqual(batches[0].batch_name, '20220102')
+ self.assertEqual(batches[0].path, 'file://' + export_path + '/batch/20220102')
+ self.assertEqual(batches[0].event_time, datetime(2022, 1, 2))
+ self.assertEqual(batches[1].batch_name, '20220101')
+ self.assertEqual(batches[1].path, 'file://' + export_path + '/batch/20220101')
+ self.assertEqual(batches[1].event_time, datetime(2022, 1, 1))
+ dataset_job: DatasetJob = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.kind, DatasetJobKind.EXPORT)
+ self.assertEqual(
+ dataset_job.get_global_configs(),
+ dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={'test_domain': dataset_pb2.DatasetJobConfig(
+ dataset_uuid='streaming dataset uuid')}))
+ dataset_job_stages = dataset_job.dataset_job_stages
+ self.assertEqual(len(dataset_job_stages), 2)
+
+
+class BatchesApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test_project')
+ session.add(project)
+ session.flush([project])
+
+ dataset = Dataset(name='test_dataset',
+ dataset_type=DatasetType.PSI,
+ uuid=resource_uuid(),
+ comment='comment',
+ path='/data/dataset/321',
+ project_id=project.id,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(dataset)
+
+ data_source = DataSource(name='test_datasource',
+ uuid=resource_uuid(),
+ path='/upload/',
+ project_id=project.id,
+ dataset_kind=DatasetKindV2.SOURCE)
+ session.add(data_source)
+
+ session.commit()
+ self._project_id = project.id
+ self._dataset_id = dataset.id
+ self._data_source_id = data_source.id
+ self._data_source_uuid = data_source.uuid
+
+ def test_get_data_batches(self):
+ with db.session_scope() as session:
+ data_batch_1 = DataBatch(dataset_id=self._dataset_id,
+ name='20220101',
+ event_time=datetime(2022, 1, 1),
+ created_at=datetime(2022, 1, 1, 0, 0, 0),
+ comment='batch_1',
+ path='/data/dataset/123/batch/20220101')
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(dataset_id=self._dataset_id,
+ name='20220102',
+ event_time=datetime(2022, 1, 2),
+ created_at=datetime(2022, 1, 2, 0, 0, 0),
+ comment='batch_2',
+ path='/data/dataset/123/batch/20220102')
+ session.add(data_batch_2)
+ session.commit()
+ sorter_param = urllib.parse.quote('created_at asc')
+ response = self.get_helper(
+ f'/api/v2/datasets/{self._dataset_id}/batches?page=1&page_size=5&order_by={sorter_param}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ response,
+ [{
+ 'id': 1,
+ 'dataset_id': 1,
+ 'comment': 'batch_1',
+ 'created_at': to_timestamp(datetime(2022, 1, 1, 0, 0, 0)),
+ 'event_time': to_timestamp(datetime(2022, 1, 1)),
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'name': '20220101',
+ 'path': '/data/dataset/123/batch/20220101',
+ 'state': 'FAILED',
+ 'latest_parent_dataset_job_stage_id': 0,
+ 'latest_analyzer_dataset_job_stage_id': 0,
+ }, {
+ 'id': 2,
+ 'dataset_id': 1,
+ 'comment': 'batch_2',
+ 'created_at': to_timestamp(datetime(2022, 1, 2, 0, 0, 0)),
+ 'event_time': to_timestamp(datetime(2022, 1, 2)),
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'name': '20220102',
+ 'path': '/data/dataset/123/batch/20220102',
+ 'state': 'FAILED',
+ 'latest_parent_dataset_job_stage_id': 0,
+ 'latest_analyzer_dataset_job_stage_id': 0,
+ }],
+ ignore_fields=['updated_at'],
+ )
+ self.assertEqual(
+ json.loads(response.data).get('page_meta'), {
+ 'current_page': 1,
+ 'page_size': 5,
+ 'total_pages': 1,
+ 'total_items': 2
+ })
+
+
+class BatchApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test_project')
+ session.add(project)
+ session.flush([project])
+
+ dataset = Dataset(name='test_dataset',
+ dataset_type=DatasetType.STREAMING,
+ uuid=resource_uuid(),
+ comment='comment',
+ path='/data/dataset/123',
+ project_id=project.id,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(dataset)
+ session.flush()
+
+ data_batch = DataBatch(dataset_id=dataset.id,
+ name='20220101',
+ event_time=datetime(2022, 1, 1),
+ created_at=datetime(2022, 1, 1, 0, 0, 0),
+ comment='batch_1',
+ path='/data/dataset/123/batch/20220101')
+ session.add(data_batch)
+ session.flush()
+
+ session.commit()
+ self._project_id = project.id
+ self._dataset_id = dataset.id
+ self._data_batch_id = data_batch.id
+
+ def test_get_data_batch(self):
+ response = self.get_helper(f'/api/v2/datasets/{self._dataset_id}/batches/{self._data_batch_id}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ response,
+ {
+ 'id': self._data_batch_id,
+ 'dataset_id': self._dataset_id,
+ 'comment': 'batch_1',
+ 'created_at': to_timestamp(datetime(2022, 1, 1, 0, 0, 0)),
+ 'event_time': to_timestamp(datetime(2022, 1, 1)),
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'name': '20220101',
+ 'path': '/data/dataset/123/batch/20220101',
+ 'state': 'FAILED',
+ 'latest_parent_dataset_job_stage_id': 0,
+ 'latest_analyzer_dataset_job_stage_id': 0,
+ },
+ ignore_fields=['updated_at'],
+ )
+
+
+class BatchMetricsApiTest(BaseTestCase):
+
+ def test_get_batch_metrics(self):
+ with db.session_scope() as session:
+ default_dataset = Dataset(id=1,
+ name='dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path=os.path.join(tempfile.gettempdir(), 'dataset/123'),
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_format=DatasetFormat.TABULAR.value,
+ created_at=datetime(2012, 1, 14, 12, 0, 6))
+ session.add(default_dataset)
+ default_databatch = DataBatch(name='20220101',
+ id=111,
+ event_time=datetime(2022, 1, 1),
+ comment='comment',
+ state=BatchState.NEW,
+ dataset_id=1,
+ path='/data/test/batch/20220101')
+ session.add(default_databatch)
+ session.commit()
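+        # Write a fake batch meta file so the feature_metrics endpoint has statistics and a histogram to return.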
+ meta_file = DatasetDirectory(dataset_path=default_dataset.path).batch_meta_file(batch_name='20220101')
+ gfile.makedirs(meta_file.split('/_META')[0])
+ meta_data = {
+ 'dtypes': [{
+ 'key': 'f01',
+ 'value': 'bigint'
+ }],
+ 'sample': [
+ [
+ 1,
+ ],
+ [
+ 0,
+ ],
+ ],
+ 'count': 0,
+ 'features': {
+ 'f01': {
+ 'count': '2',
+ 'mean': '0.0015716767309123998',
+ 'stddev': '0.03961485047808605',
+ 'min': '0',
+ 'max': '1',
+ 'missing_count': '0'
+ },
+ },
+ 'hist': {
+ 'f01': {
+ 'x': [
+ 0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6000000000000001, 0.7000000000000001, 0.8, 0.9,
+ 1
+ ],
+ 'y': [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
+ },
+ },
+ }
+ with gfile.GFile(meta_file, 'w') as f:
+ f.write(json.dumps(meta_data))
+
+ feat_name = 'f01'
+ feature_response = self.get_helper(f'/api/v2/datasets/1/batches/111/feature_metrics?name={feat_name}')
+ self.assertEqual(feature_response.status_code, 200)
+ feature_data = self.get_response_data(feature_response)
+ golden_feature = {
+ 'name': feat_name,
+ 'metrics': {
+ 'count': '2',
+ 'mean': '0.0015716767309123998',
+ 'stddev': '0.03961485047808605',
+ 'min': '0',
+ 'max': '1',
+ 'missing_count': '0'
+ },
+ 'hist': {
+ 'x': [
+ 0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6000000000000001, 0.7000000000000001, 0.8, 0.9, 1
+ ],
+ 'y': [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
+ },
+ }
+        self.assertEqual(feature_data, golden_feature, 'should have feature data')
+
+
+class BatchesAnalyzeApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.dataset.apis.get_pure_domain_name', lambda _: 'test_domain')
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobService.create_as_coordinator')
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobStageService.create_dataset_job_stage_as_coordinator')
+ def test_analyze_data_batch(self, create_dataset_job_stage_as_coordinator: MagicMock,
+ mock_create_as_coordinator: MagicMock):
+ with db.session_scope() as session:
+ project = Project(id=1, name='test-project')
+ session.add(project)
+ dataset = Dataset(id=1,
+ name='default dataset',
+ uuid='dataset_uuid',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment2',
+ path='data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='20220101',
+ comment='comment',
+ event_time=datetime(2022, 1, 1),
+ dataset_id=1,
+ path='/data/dataset/123/batch/20220101')
+ session.add(data_batch)
+ session.commit()
+ mock_dataset_job = DatasetJob(id=1,
+ name='analyzer_dataset_job',
+ uuid='123',
+ project_id=1,
+ output_dataset_id=1,
+ input_dataset_id=1,
+ kind=DatasetJobKind.ANALYZER,
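+        # The backing workflow FAILED while the dataset job is RUNNING, so state_fix should mark the job FAILED.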
+ coordinator_id=0,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2022, 1, 1),
+ updated_at=datetime(2022, 1, 1),
+ creator_username='test user')
+ mock_dataset_job.set_global_configs(global_configs=dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ dataset_pb2.DatasetJobConfig(dataset_uuid='u123',
+ variables=[
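+        # With the workflow now COMPLETED, state_fix moves the job back to RUNNING.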
+ Variable(name='name1', value='value1'),
+ Variable(name='name2', value='value2'),
+ Variable(name='name3', value='value3')
+ ])
+ }))
+ mock_create_as_coordinator.return_value = mock_dataset_job
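+        # Posting :analyze should return the ANALYZER dataset job and create a job stage for this batch as coordinator.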
+ response = self.post_helper(
+ '/api/v2/datasets/1/batches/1:analyze', {
+ 'dataset_job_config': {
+ 'variables': [{
+ 'name': 'name1',
+ 'value': 'value1',
+ }, {
+ 'name': 'name2',
+ 'value': 'value2',
+ }, {
+ 'name': 'name3',
+ 'value': 'value3',
+ }]
+ }
+ })
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.maxDiff = None
+ self.assertResponseDataEqual(
+ response, {
+ 'name': 'analyzer_dataset_job',
+ 'uuid': '123',
+ 'project_id': 1,
+ 'kind': 'ANALYZER',
+ 'state': 'PENDING',
+ 'created_at': to_timestamp(datetime(2022, 1, 1)),
+ 'updated_at': to_timestamp(datetime(2022, 1, 1)),
+ 'result_dataset_uuid': '',
+ 'result_dataset_name': '',
+ 'is_ready': False,
+ 'input_data_batch_num_example': 0,
+ 'output_data_batch_num_example': 0,
+ 'id': 1,
+ 'coordinator_id': 0,
+ 'workflow_id': 0,
+ 'finished_at': 0,
+ 'started_at': 0,
+ 'has_stages': False,
+ 'creator_username': 'test user',
+ 'scheduler_state': '',
+ 'global_configs': ANY,
+ 'time_range': {
+ 'days': 0,
+ 'hours': 0,
+ },
+ 'scheduler_message': '',
+ })
+ create_dataset_job_stage_as_coordinator.assert_called_once_with(
+ project_id=1,
+ dataset_job_id=1,
+ output_data_batch_id=1,
+ global_configs=dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ dataset_pb2.DatasetJobConfig(dataset_uuid='u123',
+ variables=[
+ Variable(name='name1', value='value1'),
+ Variable(name='name2', value='value2'),
+ Variable(name='name3', value='value3')
+ ])
+ }))
+
+
+class BatchRerunApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.dataset.apis.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobStageService.create_dataset_job_stage_as_coordinator')
+ def test_rerun_batch(self, mock_create_dataset_job_stage_as_coordinator: MagicMock, mock_list_flags: MagicMock):
+ with db.session_scope() as session:
+ project = Project(id=1, name='test-project')
+ session.add(project)
+ dataset = Dataset(id=1,
+ name='default dataset',
+ uuid='dataset_uuid',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment2',
+ path='data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='20220101',
+ comment='comment',
+ event_time=datetime(2022, 1, 1),
+ dataset_id=1,
+ path='/data/dataset/123/batch/20220101')
+ session.add(data_batch)
+ dataset_job = DatasetJob(id=1,
+ name='default dataset_job',
+ uuid='u123',
+ project_id=1,
+ output_dataset_id=1,
+ input_dataset_id=1,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN,
+ coordinator_id=0,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2022, 1, 1),
+ updated_at=datetime(2022, 1, 1),
+ creator_username='test user')
+ dataset_job.set_global_configs(global_configs=dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={
+ 'coordinator_domain':
+ dataset_pb2.DatasetJobConfig(dataset_uuid='u123',
+ variables=[
+ Variable(name='name1',
+ typed_value=Value(string_value='value1-1'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name2',
+ typed_value=Value(string_value='value1-2'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name3',
+ typed_value=Value(string_value='value1-3'),
+ value_type=Variable.ValueType.STRING),
+ ]),
+ 'participant_domain':
+ dataset_pb2.DatasetJobConfig(dataset_uuid='u123',
+ variables=[
+ Variable(name='name1',
+ typed_value=Value(string_value='value1-1'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name2',
+ typed_value=Value(string_value='value1-2'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name3',
+ typed_value=Value(string_value='value1-3'),
+ value_type=Variable.ValueType.STRING),
+ ]),
+ }))
+ session.add(dataset_job)
+ participant = Participant(id=1, name='participant_1', domain_name='fl-fake_domain_name_1.com')
+ session.add(participant)
+ session.commit()
+
+ mock_create_dataset_job_stage_as_coordinator.return_value = DatasetJobStage(id=1,
+ name='mock stage',
+ uuid='fake stage uuid',
+ state=DatasetJobState.PENDING,
+ dataset_job_id=1,
+ data_batch_id=1,
+ coordinator_id=1,
+ created_at=datetime(2022, 1, 1),
+ updated_at=datetime(2022, 1, 1))
+ rerun_config = {
+ 'dataset_job_parameter': {
+ 'global_configs': {
+ 'fl-coordinator_domain.com': {
+ 'variables': [{
+ 'name': 'name1',
+ 'typed_value': 'value2-1',
+ 'value_type': 'STRING',
+ }, {
+ 'name': 'name2',
+ 'typed_value': 'value2-2',
+ 'value_type': 'STRING',
+ }, {
+ 'name': 'name3',
+ 'typed_value': 'value2-3',
+ 'value_type': 'STRING',
+ }]
+ },
+ 'fl-participant_domain.com': {
+ 'variables': [{
+ 'name': 'name1',
+ 'typed_value': 'value2-1',
+ 'value_type': 'STRING',
+ }]
+ },
+ },
+ },
+ }
+ response = self.post_helper('/api/v2/datasets/1/batches/1:rerun', rerun_config)
+ self.assertEqual(response.status_code, HTTPStatus.OK)
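+        # Only variables present in the rerun payload are overridden; the participant keeps its original name2/name3 values.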
+ mock_create_dataset_job_stage_as_coordinator.assert_called_once_with(
+ project_id=1,
+ dataset_job_id=1,
+ output_data_batch_id=1,
+ global_configs=dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={
+ 'coordinator_domain':
+ dataset_pb2.DatasetJobConfig(dataset_uuid='u123',
+ variables=[
+ Variable(name='name1',
+ typed_value=Value(string_value='value2-1'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name2',
+ typed_value=Value(string_value='value2-2'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name3',
+ typed_value=Value(string_value='value2-3'),
+ value_type=Variable.ValueType.STRING),
+ ]),
+ 'participant_domain':
+ dataset_pb2.DatasetJobConfig(dataset_uuid='u123',
+ variables=[
+ Variable(name='name1',
+ typed_value=Value(string_value='value2-1'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name2',
+ typed_value=Value(string_value='value1-2'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='name3',
+ typed_value=Value(string_value='value1-3'),
+ value_type=Variable.ValueType.STRING),
+ ]),
+ }))
+
+ # test participant not support rerun
+ mock_list_flags.return_value = {'data_batch_rerun_enabled': False}
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.coordinator_id = 1
+ session.commit()
+ response = self.post_helper('/api/v2/datasets/1/batches/1:rerun', rerun_config)
+ self.assertEqual(response.status_code, HTTPStatus.METHOD_NOT_ALLOWED)
+
+
+class DataSourcesApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ default_project = Project(id=1, name='default_project')
+ datasource_1 = DataSource(id=100,
+ uuid=resource_uuid(),
+ name='datasource_1',
+ creator_username='test',
+ path='hdfs:///data/fake_path_1',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ is_published=False,
+ store_format=StoreFormat.TFRECORDS,
+ dataset_format=DatasetFormat.IMAGE.value,
+ dataset_type=DatasetType.STREAMING)
+ datasource_1.set_meta_info(meta=dataset_pb2.DatasetMetaInfo(datasource_type=DataSourceType.HDFS.value))
+ datasource_2 = DataSource(id=101,
+ uuid=resource_uuid(),
+ name='datasource_2',
+ creator_username='test',
+ path='hdfs:///data/fake_path_2',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 6),
+ is_published=False,
+ store_format=StoreFormat.CSV,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_type=DatasetType.PSI)
+ datasource_2.set_meta_info(meta=dataset_pb2.DatasetMetaInfo(datasource_type=DataSourceType.HDFS.value))
+ session.add(default_project)
+ session.add(datasource_1)
+ session.add(datasource_2)
+ session.commit()
+
+ def test_parse_data_source_url(self):
+ url = 'hdfs:///home/test'
+ data_source = dataset_pb2.DataSource(type=DataSourceType.HDFS.value,
+ url='hdfs:///home/test',
+ is_user_upload=False)
+ self.assertEqual(_parse_data_source_url(url), data_source)
+
+ url = '/data/test'
+ data_source = dataset_pb2.DataSource(type=DataSourceType.FILE.value,
+ url='file:///data/test',
+ is_user_upload=False)
+ self.assertEqual(_parse_data_source_url(url), data_source)
+
+ url = '/data/test'
+ data_source = dataset_pb2.DataSource(type=DataSourceType.FILE.value,
+ url='file:///data/test',
+ is_user_upload=False)
+ self.assertEqual(_parse_data_source_url(url), data_source)
+ url = 'hdfs/'
+ with self.assertRaises(ValidationError):
+ _parse_data_source_url(url)
+
+ @patch('fedlearner_webconsole.utils.file_manager.FileManager.listdir')
+ @patch('fedlearner_webconsole.utils.file_manager.FileManager.isdir', fake_isdir)
+ @patch('fedlearner_webconsole.dataset.apis.Envs.STORAGE_ROOT', new_callable=PropertyMock)
+ def test_data_source_check_connection(self, mock_storage_root: MagicMock, mock_listdir: MagicMock):
+ mock_storage_root.return_value = 'hdfs:///home/'
+ mock_listdir.return_value = ['_SUCCESS', 'test.csv']
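+        # With file_num=1 only the first file name is returned; the remaining file is reported via extra_nums.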
+ resp = self.post_helper('/api/v2/data_sources:check_connection', {
+ 'data_source_url': 'hdfs:///home/',
+ 'file_num': 1,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {
+ 'file_names': ['_SUCCESS',],
+ 'extra_nums': 1,
+ })
+
+ mock_storage_root.reset_mock()
+ mock_storage_root.return_value = 'file:///data'
+ resp = self.post_helper('/api/v2/data_sources:check_connection', {'data_source_url': 'file:///data/test'})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {
+ 'file_names': ['_SUCCESS', 'test.csv'],
+ 'extra_nums': 0,
+ })
+
+ resp = self.post_helper('/api/v2/data_sources:check_connection', {'data_source_url': 'file:///data/fake_path'})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ resp = self.post_helper('/api/v2/data_sources:check_connection', {})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertIn('required', resp.json.get('details').get('json').get('data_source_url')[0])
+
+ resp = self.post_helper('/api/v2/data_sources:check_connection', {'data_source_url': 'hdfs:/home/'})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(resp.json.get('details'), 'invalid data_source_url: hdfs:/home/')
+
+ mock_listdir.reset_mock()
+ mock_listdir.return_value = ['20220801', '20220802']
+ resp = self.post_helper('/api/v2/data_sources:check_connection', {
+ 'data_source_url': 'hdfs:///home/',
+ 'dataset_type': 'STREAMING'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {
+ 'file_names': ['20220801', '20220802'],
+ 'extra_nums': 0,
+ })
+
+ mock_listdir.reset_mock()
+ mock_listdir.return_value = ['20220803-15', '2022080316']
+ resp = self.post_helper('/api/v2/data_sources:check_connection', {
+ 'data_source_url': 'hdfs:///home/',
+ 'dataset_type': 'STREAMING'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(resp.json.get('details'), 'illegal dir format: 2022080316')
+
+ @patch('fedlearner_webconsole.dataset.apis.Envs.STORAGE_ROOT', '2022080316')
+ @patch('fedlearner_webconsole.dataset.apis._validate_data_source', lambda *args: None)
+ def test_post_data_source(self):
+ resp = self.post_helper(
+ '/api/v2/data_sources', {
+ 'data_source': {
+ 'name': 'test',
+ 'comment': 'test comment',
+ 'data_source_url': 'hdfs:///home/fake_path',
+ },
+ 'project_id': 1
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ self.assertResponseDataEqual(resp, {
+ 'type': DataSourceType.HDFS.value,
+ 'url': 'hdfs:///home/fake_path',
+ 'name': 'test',
+ 'creator_username': 'ada',
+ 'project_id': 1,
+ 'is_user_upload': False,
+ 'is_user_export': False,
+ 'dataset_format': 'TABULAR',
+ 'store_format': 'UNKNOWN',
+ 'dataset_type': 'PSI',
+ 'comment': 'test comment',
+ },
+ ignore_fields=['created_at', 'id', 'uuid'])
+
+ resp_upload = self.post_helper('/api/v2/data_sources', {
+ 'data_source': {
+ 'name': 'test',
+ 'data_source_url': '/home/fake_path',
+ 'is_user_upload': True,
+ },
+ 'project_id': 1
+ })
+ self.assertEqual(resp_upload.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(
+ resp_upload.json.get('details'),
+ {'json': {
+ 'data_source': {
+ 'data_source_url': ['no access to unauchority path file:///home/fake_path!']
+ }
+ }})
+
+ resp_upload_hdfs = self.post_helper(
+ '/api/v2/data_sources', {
+ 'data_source': {
+ 'name': 'test',
+ 'data_source_url': 'hdfs:///home/fake_path',
+ 'is_user_upload': True,
+ 'dataset_format': 'TABULAR',
+ 'store_format': 'TFRECORDS',
+ 'dataset_type': 'STREAMING',
+ },
+ 'project_id': 1
+ })
+ self.assertEqual(resp_upload_hdfs.status_code, HTTPStatus.CREATED)
+ self.assertResponseDataEqual(resp_upload_hdfs, {
+ 'type': DataSourceType.HDFS.value,
+ 'url': 'hdfs:///home/fake_path',
+ 'name': 'test',
+ 'creator_username': 'ada',
+ 'project_id': 1,
+ 'is_user_upload': True,
+ 'is_user_export': False,
+ 'dataset_format': 'TABULAR',
+ 'store_format': 'TFRECORDS',
+ 'dataset_type': 'STREAMING',
+ 'comment': '',
+ },
+ ignore_fields=['created_at', 'id', 'uuid'])
+
+ resp_err = self.post_helper('/api/v2/data_sources', {
+ 'data_source': {
+ 'name': 'test',
+ 'data_source_url': 'fake:///home/fake_path',
+ },
+ 'project_id': 1
+ })
+ self.assertEqual(resp_err.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(resp_err.json.get('details'),
+ {'json': {
+ 'data_source': {
+ 'data_source_url': ['fake is not a supported data_source type']
+ }
+ }})
+
+ def test_delete_data_source(self):
+ resp = self.delete_helper('/api/v2/data_sources/100')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ dataset = session.query(DataSource).get(100)
+ self.assertIsNone(dataset)
+
+ def test_get_data_sources(self):
+ expected_result = [{
+ 'id': 101,
+ 'uuid': ANY,
+ 'name': 'datasource_2',
+ 'comment': '',
+ 'creator_username': 'test',
+ 'url': 'hdfs:///data/fake_path_2',
+ 'type': DataSourceType.HDFS.value,
+ 'project_id': 1,
+ 'created_at': to_timestamp(datetime(2012, 1, 14, 12, 0, 6)),
+ 'is_user_upload': False,
+ 'is_user_export': False,
+ 'dataset_format': 'TABULAR',
+ 'store_format': 'CSV',
+ 'dataset_type': 'PSI',
+ }, {
+ 'id': 100,
+ 'uuid': ANY,
+ 'name': 'datasource_1',
+ 'comment': '',
+ 'creator_username': 'test',
+ 'url': 'hdfs:///data/fake_path_1',
+ 'type': DataSourceType.HDFS.value,
+ 'project_id': 1,
+ 'created_at': to_timestamp(datetime(2012, 1, 14, 12, 0, 5)),
+ 'is_user_upload': False,
+ 'is_user_export': False,
+ 'dataset_format': 'IMAGE',
+ 'store_format': 'TFRECORDS',
+ 'dataset_type': 'STREAMING',
+ }]
+ resp = self.get_helper('/api/v2/data_sources')
+ self.assertResponseDataEqual(resp, expected_result)
+
+ resp = self.get_helper('/api/v2/data_sources?project_id=1')
+ self.assertResponseDataEqual(resp, expected_result)
+
+ resp = self.get_helper('/api/v2/data_sources?project_id=10')
+ self.assertResponseDataEqual(resp, [])
+
+ def test_get_data_source(self):
+ resp = self.get_helper('/api/v2/data_sources/100')
+ self.assertEqual(resp.status_code, 200)
+ self.assertResponseDataEqual(
+ resp, {
+ 'id': 100,
+ 'uuid': ANY,
+ 'name': 'datasource_1',
+ 'comment': '',
+ 'creator_username': 'test',
+ 'url': 'hdfs:///data/fake_path_1',
+ 'type': DataSourceType.HDFS.value,
+ 'project_id': 1,
+ 'created_at': to_timestamp(datetime(2012, 1, 14, 12, 0, 5)),
+ 'is_user_upload': False,
+ 'is_user_export': False,
+ 'dataset_format': 'IMAGE',
+ 'store_format': 'TFRECORDS',
+ 'dataset_type': 'STREAMING',
+ })
+ resp = self.get_helper('/api/v2/data_sources/1')
+ self.assertEqual(resp.status_code, 404)
+
+ @patch('fedlearner_webconsole.dataset.apis.Envs.STORAGE_ROOT', '/data')
+ def test_path_authority_validator(self):
+ _path_authority_validator('/data/test')
+ _path_authority_validator('hdfs:///home')
+ _path_authority_validator('file:///data/test')
+ with self.assertRaises(ValidationError):
+ _path_authority_validator('fake')
+ with self.assertRaises(ValidationError):
+ _path_authority_validator('/fake')
+ with self.assertRaises(ValidationError):
+ _path_authority_validator('file:///fake')
+
+
+class DataSourceTreeApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.utils.file_tree.FileTreeBuilder.build_with_root')
+ def test_get_tree(self, mock_build_with_root: MagicMock):
+ mock_build_with_root.return_value = FileTreeNode(filename='20221101',
+ path='20221101',
+ is_directory=True,
+ size=1024,
+ mtime=0,
+ files=[
+ FileTreeNode(filename='test.csv',
+ path='20221101/test.csv',
+ is_directory=False,
+ size=1024,
+ mtime=0),
+ ])
+ with db.session_scope() as session:
+ data_source = DataSource(id=100,
+ uuid=resource_uuid(),
+ name='datasource_1',
+ creator_username='test',
+ path='hdfs:///data/fake_path_1',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ is_published=False,
+ store_format=StoreFormat.TFRECORDS,
+ dataset_format=DatasetFormat.IMAGE.value,
+ dataset_type=DatasetType.STREAMING)
+ session.add(data_source)
+ session.commit()
+ resp = self.get_helper('/api/v2/data_sources/100/tree')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ resp, {
+ 'filename':
+ '20221101',
+ 'path':
+ '20221101',
+ 'is_directory':
+ True,
+ 'size':
+ 1024,
+ 'mtime':
+ 0,
+ 'files': [{
+ 'filename': 'test.csv',
+ 'path': '20221101/test.csv',
+ 'is_directory': False,
+ 'size': 1024,
+ 'mtime': 0,
+ 'files': [],
+ }],
+ })
+
+
+class ParticipantDatasetApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=10, name='test-project')
+ participant_1 = Participant(id=10, name='participant_1', domain_name='fake_domain_name_1')
+ project_participant_1 = ProjectParticipant(project_id=project.id, participant_id=participant_1.id)
+
+ session.add(project)
+ session.add(participant_1)
+ session.add(project_participant_1)
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.apis.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.list_datasets')
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.list_participant_datasets')
+    def test_get_participant(self, mock_list_participant_datasets: MagicMock, mock_list_datasets: MagicMock,
+ mock_list_flags: MagicMock):
+ dataref_1 = dataset_pb2.ParticipantDatasetRef(uuid='1',
+ name='fake_dataset_1',
+ format=DatasetFormat.TABULAR.name,
+ file_size=1000,
+ dataset_kind=DatasetKindV2.RAW.name,
+ dataset_type=DatasetType.PSI.name,
+ auth_status='PENDING')
+ dataref_2 = dataset_pb2.ParticipantDatasetRef(uuid='2',
+ name='fake_dataset_2',
+ format=DatasetFormat.TABULAR.name,
+ file_size=1000,
+ dataset_kind=DatasetKindV2.PROCESSED.name,
+ dataset_type=DatasetType.PSI.name,
+ auth_status='PENDING')
+ mock_return = service_pb2.ListParticipantDatasetsResponse(participant_datasets=[dataref_1, dataref_2])
+ mock_list_participant_datasets.return_value = mock_return
+ mock_list_flags.return_value = {'list_datasets_rpc_enabled': False}
+
+ # test no filter
+ resp = self.get_helper('/api/v2/project/10/participant_datasets')
+ self.assertEqual(resp.status_code, 200)
+ expect_data = [{
+ 'uuid': '1',
+ 'project_id': 10,
+ 'name': 'fake_dataset_1',
+ 'participant_id': 10,
+ 'format': DatasetFormat.TABULAR.name,
+ 'file_size': 1000,
+ 'updated_at': 0,
+ 'value': 0,
+ 'dataset_kind': 'RAW',
+ 'dataset_type': 'PSI',
+ 'auth_status': 'PENDING',
+ }, {
+ 'uuid': '2',
+ 'project_id': 10,
+ 'name': 'fake_dataset_2',
+ 'participant_id': 10,
+ 'format': DatasetFormat.TABULAR.name,
+ 'file_size': 1000,
+ 'updated_at': 0,
+ 'value': 0,
+ 'dataset_kind': 'PROCESSED',
+ 'dataset_type': 'PSI',
+ 'auth_status': 'PENDING',
+ }]
+ resp_data = self.get_response_data(resp)
+ self.assertCountEqual(resp_data, expect_data)
+ mock_list_participant_datasets.assert_called_once_with(kind=None, uuid=None)
+ mock_list_datasets.assert_not_called()
+ mock_list_participant_datasets.reset_mock()
+
+ # test filter uuid
+ mock_list_participant_datasets.return_value = mock_return
+ resp = self.get_helper('/api/v2/project/10/participant_datasets?uuid=1')
+ self.assertEqual(resp.status_code, 200)
+ expect_data = [{
+ 'uuid': '1',
+ 'project_id': 10,
+ 'name': 'fake_dataset_1',
+ 'participant_id': 10,
+ 'format': DatasetFormat.TABULAR.name,
+ 'file_size': 1000,
+ 'updated_at': 0,
+ 'value': 0,
+ 'dataset_kind': 'RAW',
+ 'dataset_type': 'PSI',
+ 'auth_status': 'PENDING',
+ }]
+ self.assertResponseDataEqual(resp, expect_data)
+ mock_list_participant_datasets.assert_called_once_with(kind=None, uuid='1')
+ mock_list_participant_datasets.reset_mock()
+
+ # test illegal kind
+ mock_list_participant_datasets.return_value = service_pb2.ListParticipantDatasetsResponse()
+ resp = self.get_helper('/api/v2/project/10/participant_datasets?kind=unkown')
+ self.assertEqual(resp.status_code, 400)
+ mock_list_participant_datasets.assert_not_called()
+
+ # test filter kind
+ resp = self.get_helper('/api/v2/project/10/participant_datasets?kind=raw')
+ self.assertEqual(resp.status_code, 200)
+ mock_list_participant_datasets.assert_called_once_with(kind='raw', uuid=None)
+
+ # test filter participant_id
+ mock_list_participant_datasets.reset_mock()
+ mock_list_participant_datasets.return_value = mock_return
+ resp = self.get_helper('/api/v2/project/10/participant_datasets?participant_id=10')
+ self.assertEqual(resp.status_code, 200)
+ self.assertEqual(len(self.get_response_data(resp)), 2)
+ mock_list_participant_datasets.assert_called_once_with(kind=None, uuid=None)
+
+ # test filter participant_id not found
+ mock_list_participant_datasets.reset_mock()
+ mock_list_participant_datasets.return_value = mock_return
+ resp = self.get_helper('/api/v2/project/10/participant_datasets?participant_id=100')
+ self.assertEqual(resp.status_code, 404)
+ mock_list_participant_datasets.assert_not_called()
+
+ # test list_datasets_rpc
+ mock_list_participant_datasets.reset_mock()
+ mock_list_datasets.reset_mock()
+ mock_list_flags.reset_mock()
+
+ mock_list_datasets.return_value = mock_return
+ mock_list_flags.return_value = {'list_datasets_rpc_enabled': True}
+ resp = self.get_helper('/api/v2/project/10/participant_datasets?uuid=1&kind=raw')
+ self.assertEqual(resp.status_code, 200)
+ self.assertResponseDataEqual(resp, expect_data)
+ mock_list_datasets.assert_called_once_with(kind=DatasetKindV2.RAW,
+ uuid='1',
+ state=ResourceState.SUCCEEDED,
+ time_range=None)
+ mock_list_participant_datasets.assert_not_called()
+
+ # test filter cron
+ mock_list_datasets.reset_mock()
+ mock_list_datasets.return_value = mock_return
+ resp = self.get_helper('/api/v2/project/10/participant_datasets?cron_interval=DAYS')
+ self.assertEqual(resp.status_code, 200)
+ mock_list_datasets.assert_called_once_with(kind=None,
+ uuid=None,
+ state=ResourceState.SUCCEEDED,
+ time_range=dataset_pb2.TimeRange(days=1))
+
+
+class PublishDatasetApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def test_publish_dataset(self):
+ with db.session_scope() as session:
+ published_dataset = Dataset(id=10,
+ uuid='uuid',
+ name='published_dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=False,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ session.add(published_dataset)
+ dataset_job = DatasetJob(workflow_id=0,
+ uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=1,
+ output_dataset_id=10,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job)
+ session.commit()
+ resp = self.post_helper('/api/v2/datasets/10:publish', {})
+ self.assertEqual(resp.status_code, 200)
+ self.assertResponseDataEqual(
+ resp,
+ {
+ 'id': 10,
+ 'uuid': 'uuid',
+ 'is_published': True,
+ 'name': 'published_dataset',
+ 'creator_username': 'test',
+ 'path': '/data/dataset/123',
+ 'comment': 'test comment',
+ 'project_id': 1,
+ 'dataset_kind': 'RAW',
+ 'dataset_format': 'TABULAR',
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'state_frontend': 'SUCCEEDED',
+ 'parent_dataset_job_id': 1,
+ 'workflow_id': 0,
+ 'value': 0,
+ 'schema_checkers': [],
+ 'dataset_type': 'STREAMING',
+ 'import_type': 'COPY',
+ 'store_format': 'TFRECORDS',
+ 'analyzer_dataset_job_id': 0,
+ 'publish_frontend_state': 'PUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ },
+ ignore_fields=['created_at', 'updated_at', 'deleted_at', 'data_source'],
+ )
+
+ @patch('fedlearner_webconsole.dataset.services.DatasetService.withdraw_dataset')
+ def test_revoke_published_dataset(self, fake_withdraw_dataset: MagicMock):
+ resp = self.delete_helper('/api/v2/datasets/11:publish')
+ self.assertEqual(resp.status_code, 204)
+ fake_withdraw_dataset.assert_called_once_with(11)
+
+
+class DatasetAuthorizeApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+
+ with db.session_scope() as session:
+ dataset = Dataset(id=10,
+ uuid='uuid',
+ name='default dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=False,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ auth_status=AuthStatus.PENDING)
+ dataset.set_participants_info(participants_info=ParticipantsInfo(
+ participants_map={'test_domain': ParticipantInfo(auth_status='PENDING')}))
+ session.add(dataset)
+ dataset_job = DatasetJob(workflow_id=0,
+ uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=1,
+ output_dataset_id=10,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job)
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.apis.SettingService.get_system_info',
+ lambda: setting_pb2.SystemInfo(pure_domain_name='test_domain'))
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobController.inform_auth_status')
+ def test_authorize_dataset(self, mock_inform_auth_status: MagicMock):
+ resp = self.post_helper('/api/v2/datasets/10:authorize')
+ self.assertEqual(resp.status_code, 200)
+ self.assertEqual(self.get_response_data(resp).get('local_auth_status'), 'AUTHORIZED')
+ self.assertEqual(
+ self.get_response_data(resp).get('participants_info'), {
+ 'participants_map': {
+ 'test_domain': {
+ 'auth_status': 'AUTHORIZED',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': '',
+ }
+ }
+ })
+ mock_inform_auth_status.assert_called_once_with(dataset_job=ANY, auth_status=AuthStatus.AUTHORIZED)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ dataset.get_participants_info(),
+ ParticipantsInfo(participants_map={'test_domain': ParticipantInfo(auth_status='AUTHORIZED')}))
+
+ @patch('fedlearner_webconsole.dataset.apis.SettingService.get_system_info',
+ lambda: setting_pb2.SystemInfo(pure_domain_name='test_domain'))
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobController.inform_auth_status')
+ def test_revoke_authorized_dataset(self, mock_inform_auth_status: MagicMock):
+ resp = self.delete_helper('/api/v2/datasets/10:authorize')
+ self.assertEqual(resp.status_code, 200)
+ self.assertEqual(self.get_response_data(resp).get('local_auth_status'), 'WITHDRAW')
+ self.assertEqual(
+ self.get_response_data(resp).get('participants_info'), {
+ 'participants_map': {
+ 'test_domain': {
+ 'auth_status': 'WITHDRAW',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': '',
+ }
+ }
+ })
+ mock_inform_auth_status.assert_called_once_with(dataset_job=ANY, auth_status=AuthStatus.WITHDRAW)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.auth_status, AuthStatus.WITHDRAW)
+ self.assertEqual(
+ dataset.get_participants_info(),
+ ParticipantsInfo(participants_map={'test_domain': ParticipantInfo(auth_status='WITHDRAW')}))
+
+
+class DatasetFlushAuthStatusApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+
+ with db.session_scope() as session:
+ project = Project(id=1, name='test_project')
+ session.add(project)
+ dataset = Dataset(id=10,
+ uuid='uuid',
+ name='default dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=False,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ auth_status=AuthStatus.AUTHORIZED)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'coordinator-domain-name': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'participant-domain-name': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ dataset.set_participants_info(participants_info=participants_info)
+ session.add(dataset)
+ dataset_job = DatasetJob(workflow_id=0,
+ uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=1,
+ output_dataset_id=10,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job)
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.controllers.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.dataset.controllers.ResourceServiceClient.list_datasets')
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobService.get_participants_need_distribute')
+ def test_authorize_dataset(self, mock_get_participants_need_distribute: MagicMock, mock_list_datasets: MagicMock,
+ mock_list_flags: MagicMock):
+ participant = Participant(id=1, name='test_participant', domain_name='fl-participant-domain-name.com')
+ mock_get_participants_need_distribute.return_value = [participant]
+ mock_list_datasets.return_value = ListDatasetsResponse(
+ participant_datasets=[dataset_pb2.ParticipantDatasetRef(auth_status=AuthStatus.AUTHORIZED.name)])
+ mock_list_flags.return_value = {'list_datasets_rpc_enabled': True}
+
+ resp = self.post_helper('/api/v2/datasets/10:flush_auth_status')
+ self.assertEqual(resp.status_code, 200)
+ self.assertEqual(
+ self.get_response_data(resp).get('participants_info'), {
+ 'participants_map': {
+ 'coordinator-domain-name': {
+ 'auth_status': 'AUTHORIZED',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': '',
+ },
+ 'participant-domain-name': {
+ 'auth_status': 'AUTHORIZED',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': '',
+ }
+ }
+ })
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertEqual(
+ dataset.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'coordinator-domain-name': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'participant-domain-name': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ }))
+
+
+class DatasetStateFixApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def test_dataset_state_fix(self):
+ self.signin_as_admin()
+ with db.session_scope() as session:
+ dataset = Dataset(id=10,
+ uuid='uuid',
+ name='dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=False,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ created_at=datetime(2022, 1, 1, 12, 0, 0),
+ updated_at=datetime(2022, 1, 1, 12, 0, 0))
+ session.add(dataset)
+ workflow = Workflow(id=11, state=WorkflowState.FAILED, name='fake_workflow')
+ session.add(workflow)
+ dataset_job = DatasetJob(workflow_id=11,
+ uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=1,
+ output_dataset_id=10,
+ state=DatasetJobState.RUNNING)
+ session.add(dataset_job)
+
+ session.commit()
+ resp = self.post_helper('/api/v2/datasets/10:state_fix', {})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.parent_dataset_job.state, DatasetJobState.FAILED)
+
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(11)
+ workflow.state = WorkflowState.COMPLETED
+ session.commit()
+ resp = self.post_helper('/api/v2/datasets/10:state_fix', {})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.parent_dataset_job.state, DatasetJobState.RUNNING)
+
+ resp = self.post_helper('/api/v2/datasets/10:state_fix', {'force': 'SUCCEEDED'})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.parent_dataset_job.state, DatasetJobState.SUCCEEDED)
+
+
+class DatasetJobsApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+
+ with db.session_scope() as session:
+ project = Project(name='test_project')
+ session.add(project)
+ session.flush([project])
+
+ input_dataset = Dataset(id=1,
+ uuid=resource_uuid(),
+ is_published=False,
+ name='input_dataset',
+ path='/data/dataset/test_123',
+ project_id=project.id,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_type=DatasetType.PSI,
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(input_dataset)
+ streaming_dataset = Dataset(id=2,
+ uuid=resource_uuid(),
+ is_published=False,
+ name='streaming_dataset',
+ path='/data/dataset/test_123',
+ project_id=project.id,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(streaming_dataset)
+
+ session.commit()
+ self._project_id = project.id
+ self._input_dataset_uuid = input_dataset.uuid
+
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobConfiger.from_kind',
+ lambda *args: FakeDatasetJobConfiger(None))
+ @patch('fedlearner_webconsole.utils.domain_name.get_pure_domain_name', lambda _: 'test_domain')
+ @patch('fedlearner_webconsole.dataset.apis.SettingService.get_system_info',
+ lambda: setting_pb2.SystemInfo(pure_domain_name='test_domain', domain_name='test_domain.fedlearner.net'))
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobService.get_participants_need_distribute')
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobService.create_as_coordinator')
+ def test_post_dataset_job(self, mock_create_as_coordinator: MagicMock,
+ mock_get_participants_need_distribute: MagicMock):
+ mock_get_participants_need_distribute.return_value = [
+ Participant(id=1, name='test_participant_1', domain_name='fl-test-domain-name-1.com'),
+ Participant(id=2, name='test_participant_2', domain_name='fl-test-domain-name-2.com')
+ ]
+
+ dataset_job = DatasetJob(uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5),
+ time_range=timedelta(days=1))
+ dataset_job.input_dataset = Dataset(uuid=resource_uuid(),
+ name='test_dataset',
+ dataset_format=DatasetFormat.IMAGE.value)
+ output_dataset = Dataset(id=2,
+ uuid=resource_uuid(),
+ is_published=False,
+ name='streaming_dataset',
+ path='/data/dataset/test_123',
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.RAW,
+ auth_status=AuthStatus.AUTHORIZED)
+ property_mock = PropertyMock(return_value=output_dataset)
+ DatasetJob.output_dataset = property_mock
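+        # Patch the output_dataset relationship at class level so the mocked dataset_job exposes the in-memory output_dataset above.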
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs()
+ global_configs.global_configs['test'].MergeFrom(dataset_pb2.DatasetJobConfig())
+ dataset_job.set_global_configs(global_configs)
+
+ # test with error output_dataset_id
+ mock_create_as_coordinator.reset_mock()
+ resp = self.post_helper(
+ '/api/v2/projects/1/dataset_jobs', {
+ 'dataset_job_parameter': {
+ 'global_configs': {
+ 'test_domain.fedlearner.net': {
+ 'dataset_uuid': self._input_dataset_uuid,
+ 'variables': []
+ },
+ },
+ 'dataset_job_kind': 'RSA_PSI_DATA_JOIN',
+ },
+ 'output_dataset_id': 100,
+ })
+
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ mock_create_as_coordinator.assert_not_called()
+
+ # test cron dataset_job
+ mock_create_as_coordinator.return_value = dataset_job
+
+ resp = self.post_helper(
+ '/api/v2/projects/1/dataset_jobs', {
+ 'dataset_job_parameter': {
+ 'global_configs': {
+ 'test_domain.fedlearner.net': {
+ 'dataset_uuid': self._input_dataset_uuid,
+ 'variables': []
+ },
+ },
+ 'dataset_job_kind': 'RSA_PSI_DATA_JOIN',
+ },
+ 'output_dataset_id': 2,
+ 'time_range': {
+ 'days': 1,
+ }
+ })
+
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ global_config = dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={'test_domain': dataset_pb2.DatasetJobConfig(dataset_uuid=self._input_dataset_uuid)})
+ mock_create_as_coordinator.assert_called_with(project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ output_dataset_id=2,
+ global_configs=global_config,
+ time_range=timedelta(days=1))
+ self.assertFalse(dataset_job.get_context().need_create_stage)
+
+ # test non-cron dataset_job
+ mock_create_as_coordinator.reset_mock()
+ dataset_job.time_range = None
+ mock_create_as_coordinator.return_value = dataset_job
+
+ resp = self.post_helper(
+ '/api/v2/projects/1/dataset_jobs', {
+ 'dataset_job_parameter': {
+ 'global_configs': {
+ 'test_domain.fedlearner.net': {
+ 'dataset_uuid': self._input_dataset_uuid,
+ 'variables': []
+ },
+ },
+ 'dataset_job_kind': 'RSA_PSI_DATA_JOIN',
+ },
+ 'output_dataset_id': 2,
+ 'time_range': {
+ 'hours': 1,
+ }
+ })
+
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ global_config = dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={'test_domain': dataset_pb2.DatasetJobConfig(dataset_uuid=self._input_dataset_uuid)})
+ mock_create_as_coordinator.assert_called_with(project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ output_dataset_id=2,
+ global_configs=global_config,
+ time_range=timedelta(hours=1))
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test_domain': ParticipantInfo(auth_status='AUTHORIZED'),
+ 'test-domain-name-1': ParticipantInfo(auth_status='PENDING'),
+ 'test-domain-name-2': ParticipantInfo(auth_status='PENDING'),
+ }))
+ self.assertTrue(dataset_job.get_context().need_create_stage)
+
+ def test_get_dataset_jobs(self):
+ with db.session_scope() as session:
+ output_dataset_1 = Dataset(id=4,
+ uuid=resource_uuid(),
+ is_published=False,
+ name='output_dataset_1',
+ path='/data/dataset/test_123',
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_type=DatasetType.PSI,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(output_dataset_1)
+ output_dataset_2 = Dataset(id=5,
+ uuid=resource_uuid(),
+ is_published=False,
+ name='output_dataset_2',
+ path='/data/dataset/test_123',
+ project_id=2,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_type=DatasetType.PSI,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(output_dataset_2)
+ dataset_job_1 = DatasetJob(uuid='test-uuid-1',
+ name='test_dataset_job_1',
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=1,
+ output_dataset_id=4,
+ coordinator_id=1,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ creator_username='test user 1')
+ session.add(dataset_job_1)
+
+ dataset_job_2 = DatasetJob(uuid='test-uuid-2',
+ name='test_dataset_job_2',
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=2,
+ output_dataset_id=4,
+ coordinator_id=0,
+ state=DatasetJobState.SUCCEEDED,
+ created_at=datetime(2012, 1, 14, 12, 0, 6),
+ creator_username='test user 2')
+ session.add(dataset_job_2)
+ dataset_job_3 = DatasetJob(uuid='test-uuid-3',
+ name='test_dataset_job_3',
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ project_id=2,
+ workflow_id=1,
+ input_dataset_id=3,
+ output_dataset_id=5,
+ coordinator_id=0,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ creator_username='test user 3')
+ session.add(dataset_job_3)
+ dataset_job_4 = DatasetJob(uuid='test-another-uuid-4',
+ name='test_another_dataset_job_4',
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=3,
+ output_dataset_id=4,
+ coordinator_id=0,
+ state=DatasetJobState.SUCCEEDED,
+ created_at=datetime(2012, 1, 14, 12, 0, 8),
+ creator_username='test user 4')
+ session.add(dataset_job_4)
+ session.commit()
+
+ get_response = self.get_helper('/api/v2/projects/2/dataset_jobs')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ dataset_jobs = self.get_response_data(get_response)
+ self.assertEqual(len(dataset_jobs), 1)
+ self.assertEqual(dataset_jobs, [{
+ 'id': 3,
+ 'name': 'test_dataset_job_3',
+ 'uuid': 'test-uuid-3',
+ 'kind': DatasetJobKind.RSA_PSI_DATA_JOIN.name,
+ 'project_id': 2,
+ 'result_dataset_id': 5,
+ 'result_dataset_name': 'output_dataset_2',
+ 'state': DatasetJobState.PENDING.name,
+ 'coordinator_id': 0,
+ 'created_at': ANY,
+ 'has_stages': False,
+ 'creator_username': 'test user 3',
+ }])
+
+ fake_sorter_param = urllib.parse.quote('fake_time asc')
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs?order_by={fake_sorter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.BAD_REQUEST)
+
+ sorter_param = urllib.parse.quote('created_at desc')
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs?order_by={sorter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ dataset_jobs = self.get_response_data(get_response)
+ self.assertEqual([dataset_job.get('id') for dataset_job in dataset_jobs], [4, 2, 1])
+
+ filter_param = urllib.parse.quote('(and(state:["SUCCEEDED"])(name~="test_dataset"))')
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ dataset_jobs = self.get_response_data(get_response)
+ self.assertEqual([dataset_job.get('id') for dataset_job in dataset_jobs], [2])
+
+ filter_param = urllib.parse.quote('(kind:["DATA_ALIGNMENT", "IMPORT_SOURCE"])')
+ sorter_param = urllib.parse.quote('created_at asc')
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs?filter={filter_param}&order_by={sorter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ dataset_jobs = self.get_response_data(get_response)
+ self.assertEqual([dataset_job.get('id') for dataset_job in dataset_jobs], [1, 2])
+
+ filter_param = urllib.parse.quote('(coordinator_id:[0])')
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ dataset_jobs = self.get_response_data(get_response)
+ self.assertEqual([dataset_job.get('id') for dataset_job in dataset_jobs], [4, 2])
+
+ filter_param = urllib.parse.quote('(input_dataset_id=1)')
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs?filter={filter_param}')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ dataset_jobs = self.get_response_data(get_response)
+ self.assertEqual([dataset_job.get('id') for dataset_job in dataset_jobs], [1])
+
+
+class DatasetJobDefinitionApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def test_get_wrong_dataset_job_definitions(self):
+ resp = self.get_helper('/api/v2/dataset_job_definitions/test')
+ self.assertEqual(resp.status_code, 400)
+
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobKind', lambda _: 'fake_handler')
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobConfiger.from_kind',
+ lambda *args: FakeDatasetJobConfiger(None))
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobService.is_local', lambda *args: True)
+ def test_get_dataset_job_definitions(self):
+ resp = self.get_helper('/api/v2/dataset_job_definitions/fake_handler')
+ self.assertEqual(resp.status_code, 200)
+ self.assertResponseDataEqual(
+ resp, {
+ 'variables': [{
+ 'name': 'hello',
+ 'value_type': 'NUMBER',
+ 'typed_value': 1.0,
+ 'value': '',
+ 'tag': '',
+ 'access_mode': 'UNSPECIFIED',
+ 'widget_schema': ''
+ }, {
+ 'name': 'hello_from_job',
+ 'value_type': 'NUMBER',
+ 'typed_value': 3.0,
+ 'tag': '',
+ 'value': '',
+ 'access_mode': 'UNSPECIFIED',
+ 'widget_schema': ''
+ }],
+ 'is_federated': False,
+ })
+
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobKind', lambda _: 'fake_federated_handler')
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobConfiger.from_kind',
+ lambda *args: FakeFederatedDatasetJobConfiger(None))
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobService.is_local', lambda *args: False)
+ def test_get_federated_dataset_job_definitions(self):
+ resp = self.get_helper('/api/v2/dataset_job_definitions/FAKE_HANDLER')
+ self.assertEqual(resp.status_code, 200)
+ self.assertResponseDataEqual(
+ resp, {
+ 'variables': [{
+ 'name': 'hello',
+ 'value_type': 'NUMBER',
+ 'tag': '',
+ 'typed_value': 1.0,
+ 'value': '',
+ 'access_mode': 'UNSPECIFIED',
+ 'widget_schema': ''
+ }, {
+ 'name': 'hello_from_job',
+ 'value_type': 'NUMBER',
+ 'typed_value': 3.0,
+ 'tag': '',
+ 'value': '',
+ 'access_mode': 'UNSPECIFIED',
+ 'widget_schema': ''
+ }],
+ 'is_federated': True,
+ })
+
+
+class DatasetJobApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.dataset.apis.RpcClient.get_dataset_job')
+    def test_get_dataset_job(self, mock_get_dataset_job: MagicMock):
+ get_response = self.get_helper('/api/v2/projects/1/dataset_jobs/123')
+ self.assertEqual(get_response.status_code, 404)
+
+ with db.session_scope() as session:
+ participant = Participant(name='test', domain_name='test_domain')
+ session.add(participant)
+ project = Project(name='test-project')
+ session.add(project)
+ session.flush([project, participant])
+ project_participant = ProjectParticipant(project_id=project.id, participant_id=participant.id)
+ session.add(project_participant)
+
+ output_dataset = Dataset(uuid='output_uuid', name='output_dataset')
+ session.add(output_dataset)
+ session.flush([output_dataset])
+ coordinator_dataset_job = DatasetJob(uuid='u12345',
+ name='coordinator_dataset_job',
+ project_id=project.id,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=output_dataset.id,
+ creator_username='test user',
+ time_range=timedelta(hours=1))
+ coordinator_dataset_job.set_global_configs(global_configs=dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={'our': dataset_pb2.DatasetJobConfig(dataset_uuid='u123')}))
+ context = dataset_pb2.DatasetJobContext(input_data_batch_num_example=1000,
+ output_data_batch_num_example=500)
+ coordinator_dataset_job.set_context(context)
+ session.add(coordinator_dataset_job)
+ mock_get_dataset_job.return_value = service_pb2.GetDatasetJobResponse(dataset_job=dataset_pb2.DatasetJob(
+ global_configs=coordinator_dataset_job.get_global_configs(),
+ scheduler_state=DatasetJobSchedulerState.STOPPED.name,
+ ))
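+            # For a job created by a participant (non-zero coordinator_id), the API fetches global configs and scheduler state from the coordinator over RPC.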
+
+ participant_dataset_job = DatasetJob(
+ uuid='u54321',
+ name='participant_dataset_job',
+ project_id=project.id,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=output_dataset.id,
+ coordinator_id=participant.id,
+ creator_username='test user',
+ time_range=timedelta(days=1),
+ )
+ session.add(participant_dataset_job)
+ session.commit()
+ coordinator_dataset_job_id = coordinator_dataset_job.id
+ participant_dataset_job_id = participant_dataset_job.id
+
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs/{coordinator_dataset_job_id}')
+ self.assertEqual(get_response.status_code, 200)
+ self.assertResponseDataEqual(
+ get_response,
+ {
+ 'id': 1,
+ 'uuid': 'u12345',
+ 'name': 'coordinator_dataset_job',
+ 'project_id': 1,
+ 'kind': 'RSA_PSI_DATA_JOIN',
+ 'state': 'PENDING',
+ 'result_dataset_uuid': 'output_uuid',
+ 'result_dataset_name': 'output_dataset',
+ 'global_configs': {
+ 'global_configs': {
+ 'our': {
+ 'dataset_uuid': 'u123',
+ 'variables': []
+ }
+ }
+ },
+ 'input_data_batch_num_example': 1000,
+ 'output_data_batch_num_example': 500,
+ 'coordinator_id': 0,
+ 'workflow_id': 0,
+ 'is_ready': False,
+ 'started_at': 0,
+ 'finished_at': 0,
+ 'has_stages': False,
+ 'creator_username': 'test user',
+ 'scheduler_state': 'PENDING',
+ 'time_range': {
+ 'days': 0,
+ 'hours': 1,
+ },
+ 'scheduler_message': '',
+ },
+ ignore_fields=['created_at', 'updated_at'],
+ )
+
+ get_response = self.get_helper(f'/api/v2/projects/1/dataset_jobs/{participant_dataset_job_id}')
+ self.assertEqual(get_response.status_code, 200)
+ self.assertResponseDataEqual(
+ get_response,
+ {
+ 'id': 2,
+ 'uuid': 'u54321',
+ 'name': 'participant_dataset_job',
+ 'project_id': 1,
+ 'kind': 'RSA_PSI_DATA_JOIN',
+ 'state': 'PENDING',
+ 'result_dataset_uuid': 'output_uuid',
+ 'result_dataset_name': 'output_dataset',
+ 'global_configs': {
+ 'global_configs': {
+ 'our': {
+ 'dataset_uuid': 'u123',
+ 'variables': []
+ }
+ }
+ },
+ 'input_data_batch_num_example': 0,
+ 'output_data_batch_num_example': 0,
+ 'coordinator_id': 1,
+ 'workflow_id': 0,
+ 'is_ready': False,
+ 'started_at': 0,
+ 'finished_at': 0,
+ 'has_stages': False,
+ 'creator_username': 'test user',
+ 'scheduler_state': 'STOPPED',
+ 'time_range': {
+ 'days': 1,
+ 'hours': 0,
+ },
+ 'scheduler_message': '',
+ },
+ ignore_fields=['created_at', 'updated_at'],
+ )
+ mock_get_dataset_job.assert_called_once_with(uuid='u54321')
+
+ def test_delete_dataset_job(self):
+ # no dataset
+ response = self.delete_helper('/api/v2/projects/1/dataset_jobs/1')
+ self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+
+ # delete successfully
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='test-uuid',
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.FAILED,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ coordinator_id=0)
+ session.add(dataset_job)
+ session.commit()
+ response = self.delete_helper('/api/v2/projects/1/dataset_jobs/1')
+ self.assertEqual(response.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ dataset = session.query(DatasetJob).execution_options(include_deleted=True).get(1)
+ self.assertIsNotNone(dataset.deleted_at)
+
+
+class DatasetJobStopApiTest(BaseTestCase):
+
+ def test_no_dataset_job(self):
+ response = self.post_helper('/api/v2/projects/1/dataset_jobs/1:stop')
+ self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+
+ @patch('fedlearner_webconsole.dataset.apis.DatasetJobController.stop')
+ def test_stop_dataset_job(self, mock_stop: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='u54321',
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=0,
+ coordinator_id=0,
+ )
+ session.add(dataset_job)
+ session.commit()
+ response = self.post_helper('/api/v2/projects/1/dataset_jobs/1:stop')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_stop.assert_called_once_with(uuid='u54321')
+
+
+class DatasetJobStopSchedulerApiTest(BaseTestCase):
+
+ def test_stop_scheduler_dataset_job(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='u54321',
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=0,
+ coordinator_id=0,
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE,
+ time_range=timedelta(days=1))
+ session.add(dataset_job)
+ session.commit()
+ response = self.post_helper('/api/v2/projects/1/dataset_jobs/1:stop_scheduler')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.STOPPED)
+
+
+class DatasetJobStagesApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='dataset_job uuid',
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=0,
+ coordinator_id=0,
+ )
+ session.add(dataset_job)
+ dataset_job_stage_1 = DatasetJobStage(id=1,
+ uuid='uuid_1',
+ name='default dataset job stage 1',
+ project_id=1,
+ workflow_id=1,
+ created_at=datetime(2022, 1, 1, 0, 0, 0),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 15),
+ state=DatasetJobState.PENDING)
+ session.add(dataset_job_stage_1)
+ dataset_job_stage_2 = DatasetJobStage(id=2,
+ uuid='uuid_2',
+ name='default dataset job stage 2',
+ project_id=1,
+ workflow_id=2,
+ created_at=datetime(2022, 1, 2, 0, 0, 0),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 15),
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job_stage_2)
+ dataset_job_stage_3 = DatasetJobStage(id=3,
+ uuid='uuid_3',
+ name='default dataset job stage 3',
+ project_id=1,
+ workflow_id=1,
+ created_at=datetime(2022, 1, 3, 0, 0, 0),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 15),
+ state=DatasetJobState.STOPPED)
+ session.add(dataset_job_stage_3)
+ session.commit()
+
+ def test_get_dataset_job_stages(self):
+ response = self.get_helper('/api/v2/projects/1/dataset_jobs/1/dataset_job_stages')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(response, [{
+ 'created_at': ANY,
+ 'dataset_job_id': 1,
+ 'id': 3,
+ 'kind': DatasetJobKind.RSA_PSI_DATA_JOIN.name,
+ 'name': 'default dataset job stage 3',
+ 'output_data_batch_id': 1,
+ 'project_id': 1,
+ 'state': DatasetJobState.STOPPED.name
+ }, {
+ 'created_at': ANY,
+ 'dataset_job_id': 1,
+ 'id': 2,
+ 'kind': DatasetJobKind.RSA_PSI_DATA_JOIN.name,
+ 'name': 'default dataset job stage 2',
+ 'output_data_batch_id': 1,
+ 'project_id': 1,
+ 'state': DatasetJobState.SUCCEEDED.name
+ }, {
+ 'created_at': ANY,
+ 'dataset_job_id': 1,
+ 'id': 1,
+ 'kind': DatasetJobKind.RSA_PSI_DATA_JOIN.name,
+ 'name': 'default dataset job stage 1',
+ 'output_data_batch_id': 1,
+ 'project_id': 1,
+ 'state': DatasetJobState.PENDING.name
+ }])
+ filter_param = urllib.parse.quote('(state:["STOPPED", "SUCCEEDED"])')
+ sorter_param = urllib.parse.quote('created_at asc')
+ response = self.get_helper(f'/api/v2/projects/1/dataset_jobs/1/dataset_job_stages'
+ f'?page=1&page_size=5&filter={filter_param}&order_by={sorter_param}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual([job_stage.get('id') for job_stage in data], [2, 3])
+ self.assertEqual(
+ json.loads(response.data).get('page_meta'), {
+ 'current_page': 1,
+ 'page_size': 5,
+ 'total_pages': 1,
+ 'total_items': 2
+ })
+
+
+class DatasetJobStageApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='dataset_job uuid',
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=0,
+ coordinator_id=0,
+ )
+ session.add(dataset_job)
+ dataset_job_stage_1 = DatasetJobStage(id=1,
+ uuid='uuid_1',
+ name='default dataset job stage 1',
+ project_id=1,
+ workflow_id=1,
+ created_at=datetime(2022, 1, 1, 0, 0, 0),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 15),
+ state=DatasetJobState.PENDING,
+ coordinator_id=0)
+ session.add(dataset_job_stage_1)
+ session.commit()
+
+ def test_get_dataset_job_stage(self):
+ response = self.get_helper('/api/v2/projects/1/dataset_jobs/1/dataset_job_stages/1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ response, {
+ 'id': 1,
+ 'uuid': 'uuid_1',
+ 'workflow_id': 1,
+ 'dataset_job_id': 1,
+ 'dataset_job_uuid': 'dataset_job uuid',
+ 'event_time': to_timestamp(datetime(2022, 1, 15)),
+ 'is_ready': False,
+ 'kind': 'RSA_PSI_DATA_JOIN',
+ 'name': 'default dataset job stage 1',
+ 'output_data_batch_id': 1,
+ 'project_id': 1,
+ 'state': 'PENDING',
+ 'created_at': ANY,
+ 'updated_at': ANY,
+ 'started_at': 0,
+ 'finished_at': 0,
+ 'input_data_batch_num_example': 0,
+ 'output_data_batch_num_example': 0,
+ 'scheduler_message': '',
+ })
+ response = self.get_helper('/api/v2/projects/2/dataset_jobs/2/dataset_job_stages/2')
+ self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+
+
+class ChildrenDatasetsApiTest(BaseTestCase):
+
+ def test_get_sub_dataset_api(self):
+ with db.session_scope() as session:
+ parent_dataset = Dataset(id=1,
+ uuid='parent_dataset uuid',
+ name='parent_dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=True,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(parent_dataset)
+ dataset_job = DatasetJob(workflow_id=0,
+ uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job)
+ analyzer_dataset_job = DatasetJob(workflow_id=0,
+ uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.ANALYZER,
+ input_dataset_id=1,
+ output_dataset_id=1,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(analyzer_dataset_job)
+ child_dataset = Dataset(id=2,
+ uuid='child_dataset uuid',
+ name='child_dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=True,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ store_format=StoreFormat.TFRECORDS)
+ session.add(child_dataset)
+ export_dataset_job = DatasetJob(workflow_id=0,
+ uuid=resource_uuid(),
+ project_id=1,
+ kind=DatasetJobKind.EXPORT,
+ input_dataset_id=1,
+ output_dataset_id=3,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(export_dataset_job)
+ export_dataset = Dataset(id=3,
+ uuid='export_dataset uuid',
+ name='export_dataset',
+ creator_username='test',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=False,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_kind=DatasetKindV2.EXPORTED,
+ store_format=StoreFormat.CSV)
+ session.add(export_dataset)
+ session.commit()
+ response = self.get_helper('/api/v2/datasets/1/children_datasets')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(response, [{
+ 'id': 2,
+ 'project_id': 1,
+ 'name': 'child_dataset',
+ 'creator_username': 'test',
+ 'created_at': ANY,
+ 'path': '/data/dataset/123',
+ 'dataset_format': 'TABULAR',
+ 'comment': 'test comment',
+ 'state_frontend': 'SUCCEEDED',
+ 'dataset_kind': 'PROCESSED',
+ 'data_source': ANY,
+ 'file_size': 0,
+ 'is_published': True,
+ 'num_example': 0,
+ 'uuid': 'child_dataset uuid',
+ 'total_value': 0,
+ 'store_format': 'TFRECORDS',
+ 'dataset_type': 'STREAMING',
+ 'import_type': 'COPY',
+ 'publish_frontend_state': 'PUBLISHED',
+ 'auth_frontend_state': 'AUTH_APPROVED',
+ 'local_auth_status': 'PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ },
+ }])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/auth_service.py b/web_console_v2/api/fedlearner_webconsole/dataset/auth_service.py
new file mode 100644
index 000000000..5a4aaeafd
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/auth_service.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+
+
+class AuthService(object):
+
+ def __init__(self, session: Session, dataset_job: DatasetJob):
+ self._session = session
+ self._dataset_job = dataset_job
+ self._output_dataset: Dataset = dataset_job.output_dataset
+
+ def initialize_participants_info_as_coordinator(self, participants: List[Participant]):
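+ # every participant starts as PENDING; the coordinator's own entry mirrors the output dataset's local auth status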
+ participants_info = ParticipantsInfo()
+ for participant in participants:
+ # default auth status is pending
+ participant_info = ParticipantInfo(auth_status=AuthStatus.PENDING.value)
+ participants_info.participants_map[participant.pure_domain_name()].CopyFrom(participant_info)
+
+ coordinator_domain_name = SettingService.get_system_info().pure_domain_name
+ coordinator_info = ParticipantInfo(auth_status=self._output_dataset.auth_status.name)
+ participants_info.participants_map[coordinator_domain_name].CopyFrom(coordinator_info)
+
+ self._output_dataset.set_participants_info(participants_info=participants_info)
+
+ def initialize_participants_info_as_participant(self, participants_info: ParticipantsInfo):
+ self._output_dataset.set_participants_info(participants_info=participants_info)
+
+ def update_auth_status(self, domain_name: str, auth_status: AuthStatus):
+ participants_info = self._output_dataset.get_participants_info()
+ participants_info.participants_map[domain_name].auth_status = auth_status.name
+ self._output_dataset.set_participants_info(participants_info=participants_info)
+
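+ # the two checks below short-circuit to True when the DATASET_AUTH_STATUS_CHECK_ENABLED flag is off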
+ def check_local_authorized(self) -> bool:
+ if not Flag.DATASET_AUTH_STATUS_CHECK_ENABLED.value:
+ return True
+ return self._output_dataset.auth_status == AuthStatus.AUTHORIZED
+
+ def check_participants_authorized(self) -> bool:
+ if not Flag.DATASET_AUTH_STATUS_CHECK_ENABLED.value:
+ return True
+ return self._output_dataset.is_all_participants_authorized()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/auth_service_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/auth_service_test.py
new file mode 100644
index 000000000..cca03bee5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/auth_service_test.py
@@ -0,0 +1,145 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, PropertyMock, patch
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.dataset.models import (Dataset, DatasetKindV2, ImportType, DatasetType, DatasetJob,
+ DatasetJobKind, DatasetJobState)
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.flag.models import _Flag
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class AuthServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ dataset = Dataset(id=1,
+ uuid='dataset uuid',
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ is_published=True,
+ import_type=ImportType.COPY,
+ auth_status=AuthStatus.AUTHORIZED)
+ session.add(dataset)
+ dataset_job = DatasetJob(uuid='dataset_job uuid',
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ project_id=1,
+ workflow_id=0,
+ input_dataset_id=0,
+ output_dataset_id=1,
+ coordinator_id=0)
+ session.add(dataset_job)
+ session.commit()
+
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ def test_initialize_participants_info_as_coordinator(self, mock_system_info: MagicMock):
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test-domain-name-coordinator', name='coordinator')
+ participant_1 = Participant(id=1, name='test_participant_1', domain_name='fl-test-domain-name-1.com')
+ participant_2 = Participant(id=2, name='test_participant_2', domain_name='fl-test-domain-name-2.com')
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ AuthService(session=session, dataset_job=dataset_job).initialize_participants_info_as_coordinator(
+ participants=[participant_1, participant_2])
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test-domain-name-1'].auth_status,
+ AuthStatus.PENDING.value)
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test-domain-name-2'].auth_status,
+ AuthStatus.PENDING.value)
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test-domain-name-coordinator'].
+ auth_status, AuthStatus.AUTHORIZED.value)
+
+ def test_initialize_participants_info_as_participant(self):
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test-domain-name-coordinator': ParticipantInfo(auth_status='AUTHORIZED'),
+ 'test-domain-name-1': ParticipantInfo(auth_status='PENDING')
+ })
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ AuthService(session=session, dataset_job=dataset_job).initialize_participants_info_as_participant(
+ participants_info=participants_info)
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test-domain-name-1'].auth_status,
+ AuthStatus.PENDING.value)
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test-domain-name-coordinator'].
+ auth_status, AuthStatus.AUTHORIZED.value)
+
+ def test_update_auth_status(self):
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test-domain-name-coordinator': ParticipantInfo(auth_status='AUTHORIZED'),
+ 'test-domain-name-1': ParticipantInfo(auth_status='PENDING')
+ })
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.output_dataset.set_participants_info(participants_info)
+ AuthService(session=session, dataset_job=dataset_job).update_auth_status(domain_name='test-domain-name-1',
+ auth_status=AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test-domain-name-1'].auth_status,
+ AuthStatus.AUTHORIZED.value)
+
+ @patch('fedlearner_webconsole.flag.models.Flag.DATASET_AUTH_STATUS_CHECK_ENABLED', new_callable=PropertyMock)
+ def test_check_local_authorized(self, mock_dataset_auth_status_check_enabled: MagicMock):
+ with db.session_scope() as session:
+ mock_dataset_auth_status_check_enabled.return_value = _Flag('dataset_auth_status_check_enabled', True)
+ dataset_job = session.query(DatasetJob).get(1)
+ auth_service = AuthService(session=session, dataset_job=dataset_job)
+ self.assertTrue(auth_service.check_local_authorized())
+ dataset_job.output_dataset.auth_status = AuthStatus.WITHDRAW
+ self.assertFalse(auth_service.check_local_authorized())
+
+ mock_dataset_auth_status_check_enabled.reset_mock()
+ mock_dataset_auth_status_check_enabled.return_value = _Flag('dataset_auth_status_check_enabled', False)
+ self.assertTrue(auth_service.check_local_authorized())
+
+ @patch('fedlearner_webconsole.flag.models.Flag.DATASET_AUTH_STATUS_CHECK_ENABLED', new_callable=PropertyMock)
+ def test_check_participants_authorized(self, mock_dataset_auth_status_check_enabled: MagicMock):
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test-domain-name-coordinator': ParticipantInfo(auth_status='AUTHORIZED'),
+ 'test-domain-name-1': ParticipantInfo(auth_status='PENDING')
+ })
+ with db.session_scope() as session:
+ mock_dataset_auth_status_check_enabled.return_value = _Flag('dataset_auth_status_check_enabled', True)
+ dataset_job = session.query(DatasetJob).get(1)
+ auth_service = AuthService(session=session, dataset_job=dataset_job)
+ self.assertTrue(auth_service.check_participants_authorized())
+ dataset_job.output_dataset.set_participants_info(participants_info)
+ self.assertFalse(auth_service.check_participants_authorized())
+
+ mock_dataset_auth_status_check_enabled.reset_mock()
+ mock_dataset_auth_status_check_enabled.return_value = _Flag('dataset_auth_status_check_enabled', False)
+ self.assertTrue(auth_service.check_participants_authorized())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/batch_stats.py b/web_console_v2/api/fedlearner_webconsole/dataset/batch_stats.py
new file mode 100644
index 000000000..58199e882
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/batch_stats.py
@@ -0,0 +1,108 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+from envs import Envs
+import enum
+from multiprocessing import get_context, Queue
+import queue
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput
+from fedlearner_webconsole.dataset.models import DataBatch, BatchState
+from fedlearner_webconsole.dataset.services import DataReader
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.utils.hooks import pre_start_hook
+
+_BATCH_STATS_LOG = 'batch stats'
+
+
+class BatchStatsItemState(enum.Enum):
+ SUCCESSED = 'SUCCESSED'
+ FAILED = 'FAILED'
+
+
+def batch_stats_sub_process(batch_id: int, q: Queue):
+ # as we need to connect to the db in the sub process, we pre-set the environment in a hook
+ # TODO(wangsen.0914): support starting the process in a unified func
+ pre_start_hook()
+ with db.session_scope() as session:
+ batch: DataBatch = session.query(DataBatch).get(batch_id)
+ batch_path = get_batch_data_path(batch)
+ batch_name = batch.batch_name
+ dataset_path = batch.dataset.path
+ meta = DataReader(dataset_path).metadata(batch_name=batch_name)
+ batch_num_feature = meta.num_feature
+ batch_num_example = meta.num_example
+ batch_file_size = FileOperator().getsize(batch_path)
+ q.put([batch_num_feature, batch_num_example, batch_file_size])
+
+
+class BatchStatsRunner(IRunnerV2):
+
+ def _set_batch_stats(self, batch_id: int):
+ try:
+ context = get_context('spawn')
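+ # 'spawn' starts a clean interpreter for the child, so it does not inherit the parent's db connections; the child re-initializes them via pre_start_hook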
+ internal_queue = context.Queue()
+ # Memory is not released after batch stats, so a separate process is spawned to do the work.
+ batch_stats_process = context.Process(target=batch_stats_sub_process,
+ kwargs={
+ 'batch_id': batch_id,
+ 'q': internal_queue,
+ },
+ daemon=True)
+ batch_stats_process.start()
+ try:
+ # wait up to 10 min, as some customer HDFS systems can take a long time to read
+ batch_num_feature, batch_num_example, batch_file_size = internal_queue.get(timeout=600)
+ except queue.Empty as e:
+ batch_stats_process.terminate()
+ raise RuntimeError('run batch_stats_sub_process failed') from e
+ finally:
+ batch_stats_process.join()
+ batch_stats_process.close()
+ internal_queue.close()
+ with db.session_scope() as session:
+ batch = session.query(DataBatch).get(batch_id)
+ batch.num_feature = batch_num_feature
+ batch.num_example = batch_num_example
+ batch.file_size = batch_file_size
+ logging.info(f'[{_BATCH_STATS_LOG}]: total batch data size is {batch.file_size}')
+ batch.state = BatchState.SUCCESS
+ session.commit()
+ logging.info(f'[{_BATCH_STATS_LOG}]: finish batch stats task')
+ except Exception: # pylint: disable=broad-except
+ with db.session_scope() as session:
+ batch = session.query(DataBatch).get(batch_id)
+ batch.state = BatchState.FAILED
+ session.commit()
+ raise
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ logging.info(f'[{_BATCH_STATS_LOG}]: start batch stats task')
+ try:
+ batch_id = context.input.batch_stats_input.batch_id
+ logging.info(f'[{_BATCH_STATS_LOG}]: collect raw dataset stats info, batch id: {batch_id}')
+ self._set_batch_stats(batch_id)
+ return RunnerStatus.DONE, RunnerOutput()
+ except Exception as e: # pylint: disable=broad-except
+ error_message = str(e)
+ logging.error(f'[{_BATCH_STATS_LOG}] err: {error_message}, envs: {Envs.__dict__}')
+ return RunnerStatus.FAILED, RunnerOutput(error_message=error_message)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/batch_stats_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/batch_stats_test.py
new file mode 100644
index 000000000..7e718fa07
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/batch_stats_test.py
@@ -0,0 +1,99 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime
+from unittest.mock import patch, MagicMock
+from multiprocessing import Queue
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.dataset.batch_stats import BatchStatsRunner, batch_stats_sub_process
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, BatchState, DatasetType
+from fedlearner_webconsole.dataset.services import DataReader
+from fedlearner_webconsole.db import db, turn_db_timezone_to_utc
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, BatchStatsInput
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+def fake_batch_stats_sub_process(batch_id: int, q: Queue):
+ q.put([10, 666, 789123])
+
+
+class BatchStatsRunnerTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.dataset.batch_stats.batch_stats_sub_process', fake_batch_stats_sub_process)
+ def test_run_for_batch(self):
+ with db.session_scope() as session:
+ dataset = Dataset(id=1, name='test_dataset', path='/test_dataset', dataset_type=DatasetType.PSI)
+ session.add(dataset)
+ batch = DataBatch(id=2,
+ name='0',
+ dataset_id=dataset.id,
+ path='/test_dataset/1/batch/0',
+ event_time=datetime(2021, 10, 28, 16, 37, 37))
+ session.add(batch)
+ session.commit()
+
+ runner = BatchStatsRunner()
+
+ runner_input = RunnerInput(batch_stats_input=BatchStatsInput(batch_id=2))
+ context = RunnerContext(index=0, input=runner_input)
+
+ # Succeeded case
+ status, _ = runner.run(context)
+ self.assertEqual(status, RunnerStatus.DONE)
+ with db.session_scope() as session:
+ batch = session.query(DataBatch).get(2)
+ self.assertEqual(batch.state, BatchState.SUCCESS)
+ self.assertEqual(batch.num_feature, 10)
+ self.assertEqual(batch.num_example, 666)
+ self.assertEqual(batch.file_size, 789123)
+
+ @patch('fedlearner_webconsole.dataset.batch_stats.FileOperator.getsize')
+ @patch('fedlearner_webconsole.dataset.batch_stats.DataReader')
+ @patch('fedlearner_webconsole.dataset.services.DataReader.metadata')
+ @patch('fedlearner_webconsole.utils.hooks.get_database_uri')
+ def test_batch_stats_sub_process(self, mock_get_database_uri: MagicMock, mock_metadata: MagicMock,
+ mock_data_reader: MagicMock, mock_getsize: MagicMock):
+ with db.session_scope() as session:
+ dataset = Dataset(id=1, name='test_dataset', path='/test_dataset', dataset_type=DatasetType.PSI)
+ session.add(dataset)
+ batch = DataBatch(id=2,
+ name='0',
+ dataset_id=dataset.id,
+ path='/test_dataset/1/batch/0',
+ event_time=datetime(2021, 10, 28, 16, 37, 37))
+ session.add(batch)
+ session.commit()
+ mock_metadata_res = MagicMock()
+ mock_metadata_res.num_feature = 10
+ mock_metadata_res.num_example = 666
+ mock_getsize.return_value = 789123
+ mock_get_database_uri.return_value = turn_db_timezone_to_utc(self.__class__.Config.SQLALCHEMY_DATABASE_URI)
+
+ mock_data_reader.return_value = DataReader('/test_dataset')
+ mock_metadata.return_value = mock_metadata_res
+
+ queue = Queue()
+ batch_stats_sub_process(batch_id=2, q=queue)
+ batch_num_feature, batch_num_example, batch_file_size = queue.get()
+ self.assertEqual(batch_num_feature, 10)
+ self.assertEqual(batch_num_example, 666)
+ self.assertEqual(batch_file_size, 789123)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/consts.py b/web_console_v2/api/fedlearner_webconsole/dataset/consts.py
new file mode 100644
index 000000000..6c9cd5019
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/consts.py
@@ -0,0 +1,26 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+PLACEHOLDER = 'PLACEHOLDER'
+
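+# The messages below are user-facing and kept in Chinese; rough English glosses:
+# FOLDER_NOT_READY: no folder matching the required format was found under the data source; folders must be named in the PLACEHOLDER format
+# CERTAIN_FOLDER_NOT_READY: check of the PLACEHOLDER folder failed; it must follow the naming format and contain a _SUCCESS file
+# BATCH_NOT_READY: no data batch matching the required format was found; the input dataset needs a batch named in the PLACEHOLDER format
+# CERTAIN_BATCH_NOT_READY: check of data batch PLACEHOLDER failed; verify its naming format and state
+# SUCCEEDED: successfully launched the processing task for the PLACEHOLDER batch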
+CRON_SCHEDULER_FOLDER_NOT_READY_ERROR_MESSAGE = f'数据源下未找到满足格式要求的文件夹,请确认文件夹以{PLACEHOLDER}格式命名'
+CRON_SCHEDULER_CERTAIN_FOLDER_NOT_READY_ERROR_MESSAGE = \
+ f'{PLACEHOLDER}文件夹检查失败,请确认数据源下存在以{PLACEHOLDER}格式命名的文件夹,且文件夹下有_SUCCESS文件'
+CRON_SCHEDULER_BATCH_NOT_READY_ERROR_MESSAGE = f'未找到满足格式要求的数据批次,请确保输入数据集有{PLACEHOLDER}格式命名的数据批次'
+CRON_SCHEDULER_CERTAIN_BATCH_NOT_READY_ERROR_MESSAGE = f'数据批次{PLACEHOLDER}检查失败,请确认该批次命名格式及状态'
+CRON_SCHEDULER_SUCCEEDED_MESSAGE = f'已成功发起{PLACEHOLDER}批次处理任务'
+
+ERROR_BATCH_SIZE = -1
+MANUFACTURER = 'dm9sY2VuZ2luZQ=='
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/controllers.py b/web_console_v2/api/fedlearner_webconsole/dataset/controllers.py
new file mode 100644
index 000000000..253b17ccd
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/controllers.py
@@ -0,0 +1,210 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import logging
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.proto.two_pc_pb2 import LaunchDatasetJobData, LaunchDatasetJobStageData, \
+ StopDatasetJobData, StopDatasetJobStageData, TransactionData, TwoPcType
+from fedlearner_webconsole.rpc.v2.job_service_client import JobServiceClient
+from fedlearner_webconsole.rpc.v2.resource_service_client import ResourceServiceClient
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.workflow import fill_variables
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobStage, DatasetJobState
+from fedlearner_webconsole.dataset.job_configer.dataset_job_configer import DatasetJobConfiger
+from fedlearner_webconsole.dataset.services import DatasetJobService
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.exceptions import InvalidArgumentException, InternalException
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.workflow.workflow_controller import create_ready_workflow
+from fedlearner_webconsole.two_pc.transaction_manager import TransactionManager
+from fedlearner_webconsole.flag.models import Flag
+
+
+class DatasetJobController:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def _transfer_state(self, uuid: str, target_state: DatasetJobState):
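+ # start/stop is coordinated with all participants that need this dataset_job through a 2PC transaction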
+ dataset_job = self._session.query(DatasetJob).filter_by(uuid=uuid).first()
+
+ participants = DatasetJobService(session=self._session).get_participants_need_distribute(dataset_job)
+ if target_state == DatasetJobState.RUNNING:
+ data = LaunchDatasetJobData(dataset_job_uuid=dataset_job.uuid)
+ two_pc_type = TwoPcType.LAUNCH_DATASET_JOB
+ transaction_data = TransactionData(launch_dataset_job_data=data)
+ elif target_state == DatasetJobState.STOPPED:
+ data = StopDatasetJobData(dataset_job_uuid=dataset_job.uuid)
+ two_pc_type = TwoPcType.STOP_DATASET_JOB
+ transaction_data = TransactionData(stop_dataset_job_data=data)
+ else:
+ raise InternalException(f'cannot transfer dataset_job state to {target_state.name} by two_pc')
+
+ tm = TransactionManager(project_name=dataset_job.project.name,
+ project_token=dataset_job.project.token,
+ participants=[participant.domain_name for participant in participants],
+ two_pc_type=two_pc_type)
+ succeeded, message = tm.run(data=transaction_data)
+ if not succeeded:
+ err_msg = f'error when trying to transfer dataset_job state to {target_state.name} by 2PC, ' \
+ f'dataset_job_id: {dataset_job.id}, message: {message}'
+ logging.error(err_msg)
+ raise InternalException(err_msg)
+
+ def start(self, uuid: str):
+ self._transfer_state(uuid=uuid, target_state=DatasetJobState.RUNNING)
+
+ def stop(self, uuid: str):
+ self._transfer_state(uuid=uuid, target_state=DatasetJobState.STOPPED)
+
+ # stop all related dataset_job_stage
+ dataset_job_stage_ids = self._session.query(DatasetJobStage.id).outerjoin(
+ DatasetJob, DatasetJobStage.dataset_job_id == DatasetJob.id).filter(DatasetJob.uuid == uuid).all()
+ for dataset_job_stage_id, *_ in dataset_job_stage_ids:
+ # check each dataset_job_stage and stop it via 2PC if it is not finished.
+ # we don't recheck the job_stage state here, as TransactionManager checks the dataset_job_stage state again in a new session.
+ dataset_job_stage = self._session.query(DatasetJobStage).get(dataset_job_stage_id)
+ if not dataset_job_stage.is_finished():
+ DatasetJobStageController(self._session).stop(uuid=dataset_job_stage.uuid)
+
+ def inform_auth_status(self, dataset_job: DatasetJob, auth_status: AuthStatus):
+ participants = DatasetJobService(self._session).get_participants_need_distribute(dataset_job)
+ for participant in participants:
+ client = ResourceServiceClient.from_project_and_participant(domain_name=participant.domain_name,
+ project_name=dataset_job.project.name)
+ try:
+ client.inform_dataset(dataset_uuid=dataset_job.output_dataset.uuid, auth_status=auth_status)
+ except grpc.RpcError as err:
+ logging.warning(
+ f'[dataset_job_controller]: failed to inform participant {participant.name} of dataset auth_status, '\
+ f'dataset name: {dataset_job.output_dataset.name}, exception: {err}'
+ )
+
+ def update_auth_status_cache(self, dataset_job: DatasetJob):
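+ # refresh the cached auth status of each participant's copy of the output dataset via rpc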
+ participants = DatasetJobService(self._session).get_participants_need_distribute(dataset_job)
+ for participant in participants:
+ try:
+ # check flag
+ client = SystemServiceClient.from_participant(domain_name=participant.domain_name)
+ resp = client.list_flags()
+ # if the participant does not support the list_datasets rpc, just mark it AUTHORIZED
+ if not resp.get(Flag.LIST_DATASETS_RPC_ENABLED.name):
+ AuthService(self._session,
+ dataset_job=dataset_job).update_auth_status(domain_name=participant.pure_domain_name(),
+ auth_status=AuthStatus.AUTHORIZED)
+ continue
+ client = ResourceServiceClient.from_project_and_participant(domain_name=participant.domain_name,
+ project_name=dataset_job.project.name)
+ resp = client.list_datasets(uuid=dataset_job.output_dataset.uuid)
+ if len(resp.participant_datasets) == 0 or not resp.participant_datasets[0].auth_status:
+ logging.warning(
+ '[dataset_job_controller]: update auth_status cache failed as dataset not found, ' \
+ f'or auth_status is None, participant name: {participant.name}, ' \
+ f'dataset name: {dataset_job.output_dataset.name}'
+ )
+ continue
+ participant_auth_status = AuthStatus[resp.participant_datasets[0].auth_status]
+ AuthService(self._session,
+ dataset_job=dataset_job).update_auth_status(domain_name=participant.pure_domain_name(),
+ auth_status=participant_auth_status)
+ except grpc.RpcError as err:
+ logging.warning(
+ '[dataset_job_controller]: failed to update dataset auth_status_cache, ' \
+ f'participant name: {participant.name}, ' \
+ f'dataset name: {dataset_job.output_dataset.name}, exception: {err}'
+ )
+
+
+class DatasetJobStageController:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def create_ready_workflow(self, dataset_job_stage: DatasetJobStage) -> Workflow:
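+ # as a participant, pull the workflow config and global configs from the coordinator via rpc; as the coordinator, build them from the local job configer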
+ dataset_job: DatasetJob = dataset_job_stage.dataset_job
+ if not dataset_job_stage.is_coordinator():
+ coordinator = self._session.query(Participant).get(dataset_job_stage.coordinator_id)
+ if coordinator is None:
+ raise InvalidArgumentException(f'failed to find participant {dataset_job_stage.coordinator_id}')
+ try:
+ client = JobServiceClient.from_project_and_participant(coordinator.domain_name,
+ dataset_job_stage.project.name)
+ pulled_dataset_job_stage = client.get_dataset_job_stage(dataset_job_stage_uuid=dataset_job_stage.uuid)
+ except grpc.RpcError as err:
+ logging.error(f'failed to call GetDatasetJobStage with status code {err.code()}, '
+ f'and details {err.details()}')
+ raise
+ config = pulled_dataset_job_stage.dataset_job_stage.workflow_definition
+ global_configs = pulled_dataset_job_stage.dataset_job_stage.global_configs
+
+ else:
+ # TODO(liuhehan): refactor to use rpc get config
+ config = DatasetJobConfiger.from_kind(dataset_job.kind, self._session).get_config()
+ global_configs = dataset_job_stage.get_global_configs()
+
+ result_dataset = self._session.query(Dataset).get(dataset_job.output_dataset_id)
+ global_configs = DatasetJobConfiger.from_kind(dataset_job.kind, self._session).config_local_variables(
+ global_configs, result_dataset.uuid, dataset_job_stage.event_time)
+
+ domain_name = SettingService.get_system_info().pure_domain_name
+ filled_config = fill_variables(config=config, variables=global_configs.global_configs[domain_name].variables)
+ workflow = create_ready_workflow(
+ session=self._session,
+ name=f'{dataset_job.kind.value}-{dataset_job_stage.uuid}',
+ config=filled_config,
+ project_id=dataset_job_stage.project_id,
+ uuid=dataset_job_stage.uuid,
+ )
+ self._session.flush()
+ dataset_job_stage.workflow_id = workflow.id
+
+ return workflow
+
+ def _transfer_state(self, uuid: str, target_state: DatasetJobState):
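+ # same 2PC flow as DatasetJobController._transfer_state, but scoped to a single dataset_job_stage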
+ dataset_job_stage: DatasetJobStage = self._session.query(DatasetJobStage).filter_by(uuid=uuid).first()
+
+ assert target_state in [DatasetJobState.RUNNING, DatasetJobState.STOPPED]
+ if target_state == DatasetJobState.RUNNING:
+ data = LaunchDatasetJobStageData(dataset_job_stage_uuid=uuid)
+ two_pc_type = TwoPcType.LAUNCH_DATASET_JOB_STAGE
+ transaction_data = TransactionData(launch_dataset_job_stage_data=data)
+ else:
+ data = StopDatasetJobStageData(dataset_job_stage_uuid=uuid)
+ two_pc_type = TwoPcType.STOP_DATASET_JOB_STAGE
+ transaction_data = TransactionData(stop_dataset_job_stage_data=data)
+
+ participants = DatasetJobService(session=self._session).get_participants_need_distribute(
+ dataset_job_stage.dataset_job)
+ tm = TransactionManager(project_name=dataset_job_stage.project.name,
+ project_token=dataset_job_stage.project.token,
+ participants=[participant.domain_name for participant in participants],
+ two_pc_type=two_pc_type)
+ succeeded, message = tm.run(data=transaction_data)
+ if not succeeded:
+ err_msg = f'error when trying to transfer dataset_job_stage state to {target_state.name} by 2PC, ' \
+ f'dataset_job_stage_id: {dataset_job_stage.id}, message: {message}'
+ logging.error(err_msg)
+ raise InternalException(err_msg)
+
+ def start(self, uuid: str):
+ self._transfer_state(uuid=uuid, target_state=DatasetJobState.RUNNING)
+
+ def stop(self, uuid: str):
+ self._transfer_state(uuid=uuid, target_state=DatasetJobState.STOPPED)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/controllers_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/controllers_test.py
new file mode 100644
index 000000000..d60969f5c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/controllers_test.py
@@ -0,0 +1,473 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=protected-access
+from datetime import datetime
+import unittest
+from unittest.mock import MagicMock, patch
+from google.protobuf.struct_pb2 import Value
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.dataset import FakeDatasetJobConfiger
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.utils.const import SYSTEM_WORKFLOW_CREATOR_USERNAME
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.controllers import DatasetJobController, DatasetJobStageController
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetJobStage, DatasetJobState
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto import dataset_pb2, service_pb2
+from fedlearner_webconsole.proto.rpc.v2 import job_service_pb2
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.two_pc_pb2 import LaunchDatasetJobData, LaunchDatasetJobStageData, \
+ StopDatasetJobData, StopDatasetJobStageData, TransactionData, TwoPcType
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo, ParticipantsInfo
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import ListDatasetsResponse
+
+
+def get_dataset_job_pb(*args, **kwargs) -> service_pb2.GetDatasetJobResponse:
+ dataset_job = dataset_pb2.DatasetJob(uuid='u1234')
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs()
+ global_configs.global_configs['test_domain'].MergeFrom(
+ dataset_pb2.DatasetJobConfig(dataset_uuid=resource_uuid(),
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ dataset_job.global_configs.MergeFrom(global_configs)
+ return service_pb2.GetDatasetJobResponse(dataset_job=dataset_job)
+
+
+class DatasetJobControllerTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _PARTICIPANT_ID = 1
+ _OUTPUT_DATASET_ID = 1
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=self._PROJECT_ID, name='test-project')
+ participant = Participant(id=self._PARTICIPANT_ID, name='participant_1', domain_name='fake_domain_name_1')
+ project_participant = ProjectParticipant(project_id=self._PROJECT_ID, participant_id=self._PARTICIPANT_ID)
+ session.add(project)
+ session.add(participant)
+ session.add(project_participant)
+ output_dataset = Dataset(id=self._OUTPUT_DATASET_ID,
+ name='test_output_dataset',
+ uuid=resource_uuid(),
+ path='/data/dataset/test_dataset')
+ session.add(output_dataset)
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobService.need_distribute')
+ @patch('fedlearner_webconsole.dataset.controllers.TransactionManager')
+ def test_transfer_state(self, mock_transaction_manager: MagicMock, mock_need_distribute: MagicMock):
+ dataset_job_id = 10
+ workflow_id = 11
+ with db.session_scope() as session:
+ uuid = resource_uuid()
+ workflow = Workflow(id=workflow_id, uuid=uuid)
+ dataset_job = DatasetJob(id=dataset_job_id,
+ uuid=uuid,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ project_id=self._PROJECT_ID,
+ workflow_id=workflow_id,
+ input_dataset_id=1,
+ output_dataset_id=2)
+ session.add(workflow)
+ session.add(dataset_job)
+ session.commit()
+
+ mock_need_distribute.return_value = False
+ mock_run = MagicMock(return_value=(True, ''))
+ mock_transaction_manager.return_value = MagicMock(run=mock_run)
+
+ # test illegal target state
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(dataset_job_id)
+ with self.assertRaises(InternalException):
+ DatasetJobController(session)._transfer_state(uuid=dataset_job.uuid,
+ target_state=DatasetJobState.SUCCEEDED)
+ mock_transaction_manager.assert_not_called()
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(dataset_job_id)
+ DatasetJobController(session)._transfer_state(uuid=dataset_job.uuid, target_state=DatasetJobState.RUNNING)
+ data = LaunchDatasetJobData(dataset_job_uuid=dataset_job.uuid)
+ mock_run.assert_called_with(data=TransactionData(launch_dataset_job_data=data))
+ mock_transaction_manager.assert_called_with(project_name='test-project',
+ project_token=None,
+ two_pc_type=TwoPcType.LAUNCH_DATASET_JOB,
+ participants=[])
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(dataset_job_id)
+ DatasetJobController(session)._transfer_state(uuid=dataset_job.uuid, target_state=DatasetJobState.STOPPED)
+ data = StopDatasetJobData(dataset_job_uuid=dataset_job.uuid)
+ mock_run.assert_called_with(data=TransactionData(stop_dataset_job_data=data))
+ mock_transaction_manager.assert_called_with(project_name='test-project',
+ project_token=None,
+ two_pc_type=TwoPcType.STOP_DATASET_JOB,
+ participants=[])
+
+ mock_need_distribute.return_value = True
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(dataset_job_id)
+ DatasetJobController(session)._transfer_state(uuid=dataset_job.uuid, target_state=DatasetJobState.RUNNING)
+ data = LaunchDatasetJobData(dataset_job_uuid=dataset_job.uuid)
+ mock_run.assert_called_with(data=TransactionData(launch_dataset_job_data=data))
+ mock_transaction_manager.assert_called_with(project_name='test-project',
+ project_token=None,
+ two_pc_type=TwoPcType.LAUNCH_DATASET_JOB,
+ participants=['fake_domain_name_1'])
+
+ mock_run = MagicMock(return_value=(False, ''))
+ mock_transaction_manager.return_value = MagicMock(run=mock_run)
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(dataset_job_id)
+ with self.assertRaises(InternalException):
+ DatasetJobController(session)._transfer_state(uuid=dataset_job.uuid,
+ target_state=DatasetJobState.RUNNING)
+
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobController._transfer_state')
+ def test_start(self, mock_transfer_state: MagicMock):
+ with db.session_scope() as session:
+ DatasetJobController(session).start(uuid=1)
+ mock_transfer_state.assert_called_once_with(uuid=1, target_state=DatasetJobState.RUNNING)
+
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobController._transfer_state')
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobStageController.stop')
+ def test_stop(self, mock_dataset_job_stage_stop: MagicMock, mock_transfer_state: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='u54321',
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=0,
+ coordinator_id=0,
+ )
+ session.add(dataset_job)
+ dataset_job_stage_1 = DatasetJobStage(id=1,
+ uuid='job_stage uuid_1',
+ name='default dataset job stage 1',
+ project_id=1,
+ workflow_id=1,
+ created_at=datetime(2022, 1, 1, 0, 0, 0),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 1),
+ state=DatasetJobState.PENDING)
+ session.add(dataset_job_stage_1)
+ dataset_job_stage_2 = DatasetJobStage(id=2,
+ uuid='job_stage uuid_2',
+ name='default dataset job stage 2',
+ project_id=1,
+ workflow_id=1,
+ created_at=datetime(2022, 1, 2, 0, 0, 0),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 2),
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job_stage_2)
+ session.commit()
+ with db.session_scope() as session:
+ DatasetJobController(session).stop(uuid='u54321')
+ mock_transfer_state.assert_called_once_with(uuid='u54321', target_state=DatasetJobState.STOPPED)
+ mock_dataset_job_stage_stop.assert_called_once_with(uuid='job_stage uuid_1')
+
+ @patch('fedlearner_webconsole.dataset.controllers.ResourceServiceClient.inform_dataset')
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobService.get_participants_need_distribute')
+ def test_inform_auth_status(self, mock_get_participants_need_distribute: MagicMock, mock_inform_dataset: MagicMock):
+ participant = Participant(id=1, name='test_participant', domain_name='fl-test-domain-name.com')
+ mock_get_participants_need_distribute.return_value = [participant]
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='u54321',
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ coordinator_id=0,
+ )
+ session.add(dataset_job)
+ session.flush()
+ DatasetJobController(session=session).inform_auth_status(dataset_job=dataset_job,
+ auth_status=AuthStatus.AUTHORIZED)
+ mock_inform_dataset.assert_called_once_with(dataset_uuid=dataset_job.output_dataset.uuid,
+ auth_status=AuthStatus.AUTHORIZED)
+
+ @patch('fedlearner_webconsole.dataset.controllers.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.dataset.controllers.ResourceServiceClient.list_datasets')
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobService.get_participants_need_distribute')
+ def test_update_auth_status_cache(self, mock_get_participants_need_distribute: MagicMock,
+ mock_list_datasets: MagicMock, mock_list_flags: MagicMock):
+ participant = Participant(id=1, name='test_participant', domain_name='fl-test-domain-name.com')
+ mock_get_participants_need_distribute.return_value = [participant]
+ mock_list_datasets.return_value = ListDatasetsResponse(
+ participant_datasets=[dataset_pb2.ParticipantDatasetRef(auth_status=AuthStatus.AUTHORIZED.name)])
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='u54321',
+ project_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=123,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ coordinator_id=0,
+ )
+ session.add(dataset_job)
+ dataset: Dataset = session.query(Dataset).get(self._OUTPUT_DATASET_ID)
+ participants_info = ParticipantsInfo(
+ participants_map={'test-domain-name': ParticipantInfo(auth_status=AuthStatus.PENDING.name)})
+ dataset.set_participants_info(participants_info=participants_info)
+ session.flush()
+ mock_list_flags.return_value = {'list_datasets_rpc_enabled': False}
+ DatasetJobController(session=session).update_auth_status_cache(dataset_job=dataset_job)
+ mock_list_datasets.assert_not_called()
+ mock_list_flags.reset_mock()
+ mock_list_flags.return_value = {'list_datasets_rpc_enabled': True}
+ DatasetJobController(session=session).update_auth_status_cache(dataset_job=dataset_job)
+ mock_list_datasets.assert_called_once_with(uuid=dataset_job.output_dataset.uuid)
+ self.assertEqual(dataset.get_participants_info().participants_map['test-domain-name'].auth_status,
+ AuthStatus.AUTHORIZED.name)
+
+
+def get_dataset_job_stage_pb(*args, **kwargs) -> job_service_pb2.GetDatasetJobStageResponse:
+ dataset_job_stage = dataset_pb2.DatasetJobStage(uuid='dataset_job_stage uuid')
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs()
+ global_configs.global_configs['test_domain'].MergeFrom(
+ dataset_pb2.DatasetJobConfig(dataset_uuid=resource_uuid(),
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ dataset_job_stage.global_configs.MergeFrom(global_configs)
+ return job_service_pb2.GetDatasetJobStageResponse(dataset_job_stage=dataset_job_stage)
+
+
+class DatasetJobStageControllerTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _PARTICIPANT_ID = 1
+ _OUTPUT_DATASET_ID = 1
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=self._PROJECT_ID, name='test-project')
+ participant = Participant(id=self._PARTICIPANT_ID, name='participant_1', domain_name='fake_domain_name_1')
+ project_participant = ProjectParticipant(project_id=self._PROJECT_ID, participant_id=self._PARTICIPANT_ID)
+ session.add(project)
+ session.add(participant)
+ session.add(project_participant)
+ output_dataset = Dataset(id=self._OUTPUT_DATASET_ID,
+ name='test_output_dataset',
+ uuid='output_dataset uuid',
+ path='/data/dataset/test_dataset')
+ session.add(output_dataset)
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobConfiger.from_kind',
+ lambda *args: FakeDatasetJobConfiger(None))
+ @patch('fedlearner_webconsole.dataset.controllers.SettingService.get_system_info',
+ lambda: SystemInfo(name='test', domain_name='test_domain.fedlearner.net'))
+ def test_create_ready_workflow_coordinator(self):
+
+ with db.session_scope() as session:
+
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='dataset_job uuid',
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ coordinator_id=0,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=0,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ )
+ session.add(dataset_job)
+ uuid = resource_uuid()
+ dataset_job_stage = DatasetJobStage(
+ id=1,
+ uuid=uuid,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=1,
+ data_batch_id=1,
+ coordinator_id=0,
+ )
+ session.add(dataset_job_stage)
+
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs()
+ global_configs.global_configs['test_domain'].MergeFrom(
+ dataset_pb2.DatasetJobConfig(dataset_uuid=resource_uuid(),
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ dataset_job_stage.set_global_configs(global_configs)
+ session.flush()
+
+ wf = DatasetJobStageController(session).create_ready_workflow(dataset_job_stage)
+ self.assertEqual(wf.uuid, uuid)
+ self.assertEqual(wf.creator, SYSTEM_WORKFLOW_CREATOR_USERNAME)
+
+ @patch('fedlearner_webconsole.dataset.controllers.JobServiceClient.get_dataset_job_stage')
+ @patch('fedlearner_webconsole.dataset.job_configer.import_source_configer.ImportSourceConfiger.'\
+ 'config_local_variables')
+ @patch('fedlearner_webconsole.dataset.controllers.SettingService.get_system_info',
+ lambda: SystemInfo(name='test', domain_name='test_domain.fedlearner.net'))
+ def test_create_ready_workflow_participant(self, mock_config_local_variables: MagicMock,
+ mock_get_dataset_job_stage: MagicMock):
+ get_dataset_job_stage_response = get_dataset_job_stage_pb()
+ mock_get_dataset_job_stage.return_value = get_dataset_job_stage_response
+ mock_config_local_variables.return_value = get_dataset_job_stage_response.dataset_job_stage.global_configs
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(
+ id=1,
+ uuid='dataset_job uuid',
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ coordinator_id=0,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=0,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ )
+ session.add(dataset_job)
+ uuid = resource_uuid()
+ dataset_job_stage = DatasetJobStage(
+ id=1,
+ uuid=uuid,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 1),
+ coordinator_id=self._PARTICIPANT_ID,
+ )
+ session.add(dataset_job_stage)
+ session.flush()
+
+ wf = DatasetJobStageController(session).create_ready_workflow(dataset_job_stage)
+ self.assertEqual(wf.uuid, uuid)
+ self.assertEqual(wf.creator, SYSTEM_WORKFLOW_CREATOR_USERNAME)
+ mock_config_local_variables.assert_called_once_with(
+ get_dataset_job_stage_response.dataset_job_stage.global_configs, 'output_dataset uuid',
+ datetime(2022, 1, 1))
+
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobService.get_participants_need_distribute')
+ @patch('fedlearner_webconsole.dataset.controllers.TransactionManager')
+ def test_transfer_state(self, mock_transaction_manager: MagicMock,
+ mock_get_participants_need_distribute: MagicMock):
+ dataset_job_id = 10
+ dataset_job_stage_id = 11
+ workflow_id = 12
+ with db.session_scope() as session:
+ uuid = resource_uuid()
+ workflow = Workflow(id=workflow_id, uuid=uuid)
+ dataset_job = DatasetJob(id=dataset_job_id,
+ uuid=resource_uuid(),
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2)
+ dataset_job_stage = DatasetJobStage(id=dataset_job_stage_id,
+ name='stage_1',
+ uuid=uuid,
+ dataset_job_id=dataset_job_id,
+ workflow_id=workflow_id,
+ project_id=self._PROJECT_ID,
+ data_batch_id=1,
+ state=DatasetJobState.PENDING)
+ session.add(workflow)
+ session.add(dataset_job)
+ session.add(dataset_job_stage)
+ session.commit()
+
+ mock_get_participants_need_distribute.return_value = []
+ mock_run = MagicMock(return_value=(True, ''))
+ mock_transaction_manager.return_value = MagicMock(run=mock_run)
+
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_id)
+ DatasetJobStageController(session)._transfer_state(uuid=dataset_job_stage.uuid,
+ target_state=DatasetJobState.RUNNING)
+ data = LaunchDatasetJobStageData(dataset_job_stage_uuid=dataset_job_stage.uuid)
+ mock_run.assert_called_with(data=TransactionData(launch_dataset_job_stage_data=data))
+ mock_transaction_manager.assert_called_with(project_name='test-project',
+ project_token=None,
+ two_pc_type=TwoPcType.LAUNCH_DATASET_JOB_STAGE,
+ participants=[])
+
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_id)
+ DatasetJobStageController(session)._transfer_state(uuid=dataset_job_stage.uuid,
+ target_state=DatasetJobState.STOPPED)
+ data = StopDatasetJobStageData(dataset_job_stage_uuid=dataset_job_stage.uuid)
+ mock_run.assert_called_with(data=TransactionData(stop_dataset_job_stage_data=data))
+ mock_transaction_manager.assert_called_with(project_name='test-project',
+ project_token=None,
+ two_pc_type=TwoPcType.STOP_DATASET_JOB_STAGE,
+ participants=[])
+
+ mock_get_participants_need_distribute.return_value = [Participant(domain_name='fake_domain_name_1')]
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_id)
+ DatasetJobStageController(session)._transfer_state(uuid=dataset_job_stage.uuid,
+ target_state=DatasetJobState.RUNNING)
+ data = LaunchDatasetJobStageData(dataset_job_stage_uuid=dataset_job_stage.uuid)
+ mock_run.assert_called_with(data=TransactionData(launch_dataset_job_stage_data=data))
+ mock_transaction_manager.assert_called_with(project_name='test-project',
+ project_token=None,
+ two_pc_type=TwoPcType.LAUNCH_DATASET_JOB_STAGE,
+ participants=['fake_domain_name_1'])
+
+ mock_run = MagicMock(return_value=(False, ''))
+ mock_transaction_manager.return_value = MagicMock(run=mock_run)
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_id)
+ with self.assertRaises(InternalException):
+ DatasetJobStageController(session)._transfer_state(uuid=dataset_job_stage.uuid,
+ target_state=DatasetJobState.RUNNING)
+
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobStageController._transfer_state')
+ def test_start(self, mock_transfer_state: MagicMock):
+ with db.session_scope() as session:
+ DatasetJobStageController(session).start(uuid=1)
+ mock_transfer_state.assert_called_once_with(uuid=1, target_state=DatasetJobState.RUNNING)
+
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobStageController._transfer_state')
+ def test_stop(self, mock_transfer_state: MagicMock):
+ with db.session_scope() as session:
+ DatasetJobStageController(session).stop(uuid=1)
+ mock_transfer_state.assert_called_once_with(uuid=1, target_state=DatasetJobState.STOPPED)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/data_path.py b/web_console_v2/api/fedlearner_webconsole/dataset/data_path.py
new file mode 100644
index 000000000..f7dc0ebba
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/data_path.py
@@ -0,0 +1,26 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from fedlearner_webconsole.dataset.models import DataBatch, ImportType
+from fedlearner_webconsole.utils.file_manager import FileManager
+
+
+# This helper lives outside the DataBatch model because it reads from the file system.
+def get_batch_data_path(data_batch: DataBatch):
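+    """Return the real storage path of a data batch.
+
+    For NO_COPY datasets the batch directory only holds a pointer file whose content
+    is the original source path, so that file is read; otherwise the batch's own path
+    is returned as-is.
+    """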
+ if data_batch.dataset.import_type == ImportType.NO_COPY:
+ source_batch_path = DatasetDirectory(data_batch.dataset.path).source_batch_path_file(data_batch.batch_name)
+ return FileManager().read(source_batch_path)
+ return data_batch.path
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/data_path_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/data_path_test.py
new file mode 100644
index 000000000..954ede9d2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/data_path_test.py
@@ -0,0 +1,77 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+import unittest
+from unittest.mock import MagicMock, patch
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, DatasetKindV2, ImportType, DatasetType, DataBatch
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+
+
+class DataPathTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ dataset = Dataset(id=1,
+ uuid=resource_uuid(),
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=False,
+ import_type=ImportType.NO_COPY)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='20220701',
+ dataset_id=1,
+ path='/data/test/batch/20220701',
+ event_time=datetime.strptime('20220701', '%Y%m%d'),
+ file_size=100,
+ num_example=10,
+ num_feature=3,
+ latest_parent_dataset_job_stage_id=1)
+ session.add(data_batch)
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.data_path.FileManager.read')
+ def test_get_batch_data_path(self, mock_read: MagicMock):
+ source_path = '/data/data_source/batch_1'
+ mock_read.return_value = source_path
+ # test get data_path when import_type is NO_COPY
+ with db.session_scope() as session:
+ data_batch: DataBatch = session.query(DataBatch).get(1)
+ self.assertEqual(get_batch_data_path(data_batch), source_path)
+ mock_read.assert_called_once_with('/data/dataset/123/batch/20220701/source_batch_path')
+ # test get data_path when import_type is COPY
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(1)
+ dataset.import_type = ImportType.COPY
+ session.commit()
+ with db.session_scope() as session:
+ data_batch: DataBatch = session.query(DataBatch).get(1)
+ self.assertEqual(get_batch_data_path(data_batch), data_batch.path)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/data_pipeline.py b/web_console_v2/api/fedlearner_webconsole/dataset/data_pipeline.py
deleted file mode 100644
index 70b6d4588..000000000
--- a/web_console_v2/api/fedlearner_webconsole/dataset/data_pipeline.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import io
-import logging
-import os
-import tarfile
-import traceback
-
-from enum import Enum
-from copy import deepcopy
-from typing import Tuple, Optional, List
-from uuid import uuid4
-
-from envs import Envs
-
-from fedlearner_webconsole.composer.interface import IItem, IRunner, ItemType
-from fedlearner_webconsole.composer.models import Context, RunnerStatus
-from fedlearner_webconsole.sparkapp.service import SparkAppService
-from fedlearner_webconsole.sparkapp.schema import SparkAppConfig
-
-
-class DataPipelineType(Enum):
- ANALYZER = 'analyzer'
- CONVERTER = 'converter'
- TRANSFORMER = 'transformer'
-
-
-class DataPipelineItem(IItem):
- def __init__(self, task_id: int):
- self.id = task_id
-
- def type(self) -> ItemType:
- return ItemType.DATA_PIPELINE
-
- def get_id(self) -> int:
- return self.id
-
-
-class DataPipelineRunner(IRunner):
- TYPE_PARAMS_MAPPER = {
- DataPipelineType.ANALYZER: {
- 'files_dir': 'fedlearner_webconsole/dataset/sparkapp/pipeline',
- 'main_application': 'pipeline/analyzer.py',
- },
- DataPipelineType.CONVERTER: {
- 'files_dir': 'fedlearner_webconsole/dataset/sparkapp/pipeline',
- 'main_application': 'pipeline/converter.py',
- },
- DataPipelineType.TRANSFORMER: {
- 'files_dir': 'fedlearner_webconsole/dataset/sparkapp/pipeline',
- 'main_application': 'pipeline/transformer.py',
- }
- }
-
- SPARKAPP_STATE_TO_RUNNER_STATUS = {
- '': RunnerStatus.RUNNING,
- 'SUBMITTED': RunnerStatus.RUNNING,
- 'PENDING_RERUN': RunnerStatus.RUNNING,
- 'RUNNING': RunnerStatus.RUNNING,
- 'COMPLETED': RunnerStatus.DONE,
- 'SUCCEEDING': RunnerStatus.DONE,
- 'FAILED': RunnerStatus.FAILED,
- 'SUBMISSION_FAILED': RunnerStatus.FAILED,
- 'INVALIDATING': RunnerStatus.FAILED,
- 'FAILING': RunnerStatus.FAILED,
- 'UNKNOWN': RunnerStatus.FAILED
- }
-
- def __init__(self, task_id: int) -> None:
- self.task_id = task_id
- self.task_type = None
- self.files_dir = None
- self.files_path = None
- self.main_application = None
- self.command = []
- self.sparkapp_name = None
- self.args = {}
- self.started = False
- self.error_msg = False
-
- self.spark_service = SparkAppService()
-
- def start(self, context: Context):
- try:
- self.started = True
- self.args = deepcopy(context.data.get(str(self.task_id), {}))
- self.task_type = DataPipelineType(self.args.pop('task_type'))
- name = self.args.pop('sparkapp_name')
- job_id = uuid4().hex
- self.sparkapp_name = f'pipe-{self.task_type.value}-{job_id}-{name}'
-
- params = self.__class__.TYPE_PARAMS_MAPPER[self.task_type]
- self.files_dir = os.path.join(Envs.BASE_DIR, params['files_dir'])
- self.files_path = Envs.SPARKAPP_FILES_PATH
- self.main_application = params['main_application']
- self.command = self.args.pop('input')
-
- files = None
- if self.files_path is None:
- files_obj = io.BytesIO()
- with tarfile.open(fileobj=files_obj, mode='w') as f:
- f.add(self.files_dir)
- files = files_obj.getvalue()
-
- config = {
- 'name': self.sparkapp_name,
- 'files': files,
- 'files_path': self.files_path,
- 'image_url': Envs.SPARKAPP_IMAGE_URL,
- 'volumes': gen_sparkapp_volumes(Envs.SPARKAPP_VOLUMES),
- 'driver_config': {
- 'cores':
- 1,
- 'memory':
- '4g',
- 'volume_mounts':
- gen_sparkapp_volume_mounts(Envs.SPARKAPP_VOLUME_MOUNTS),
- },
- 'executor_config': {
- 'cores':
- 2,
- 'memory':
- '4g',
- 'instances':
- 1,
- 'volume_mounts':
- gen_sparkapp_volume_mounts(Envs.SPARKAPP_VOLUME_MOUNTS),
- },
- 'main_application': f'${{prefix}}/{self.main_application}',
- 'command': self.command,
- }
- config_dict = SparkAppConfig.from_dict(config)
- resp = self.spark_service.submit_sparkapp(config=config_dict)
- logging.info(
- f'created spark app, name: {name}, '
- f'config: {config_dict.__dict__}, resp: {resp.__dict__}')
- except Exception as e: # pylint: disable=broad-except
- self.error_msg = f'[composer] failed to run this item, err: {e}, \
- trace: {traceback.format_exc()}'
-
- def result(self, context: Context) -> Tuple[RunnerStatus, dict]:
- if self.error_msg:
- context.set_data(f'failed_{self.task_id}',
- {'error': self.error_msg})
- return RunnerStatus.FAILED, {}
- if not self.started:
- return RunnerStatus.RUNNING, {}
- resp = self.spark_service.get_sparkapp_info(self.sparkapp_name)
- logging.info(f'sparkapp resp: {resp.__dict__}')
- if not resp.state:
- return RunnerStatus.RUNNING, {}
- return self.__class__.SPARKAPP_STATE_TO_RUNNER_STATUS.get(
- resp.state, RunnerStatus.FAILED), resp.to_dict()
-
-
-def gen_sparkapp_volumes(value: str) -> Optional[List[dict]]:
- if value != 'data':
- return None
- # TODO: better to read from conf
- return [{
- 'name': 'data',
- 'persistentVolumeClaim': {
- 'claimName': 'pvc-fedlearner-default'
- }
- }]
-
-
-def gen_sparkapp_volume_mounts(value: str) -> Optional[List[dict]]:
- if value != 'data':
- return None
- # TODO: better to read from conf
- return [{'name': 'data', 'mountPath': '/data'}]
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/dataset_directory.py b/web_console_v2/api/fedlearner_webconsole/dataset/dataset_directory.py
new file mode 100644
index 000000000..544b2151f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/dataset_directory.py
@@ -0,0 +1,101 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Copied from fedlearner_web_console_v2/web_console_v2/inspection/dataset_directory.py
+
+import os
+
+
+class DatasetDirectory(object):
+ """
+ Dataset struct
+ |
+ |--- batch ---- batch_name_1 --- real data files
+ | |
+ | |- batch_name_2 --- real data files
+ | |
+ | |- batch_name_3 --- real data files
+ |
+ |--- meta --- batch_name_1 --- thumbnails (only for image) --- preview image (.png)
+ | | |
+ | | |- _META
+ | |
+ | |- batch_name_2 --- thumbnails (only for image) --- preview image (.png)
+ | | |
+ | | |- _META
+ | |
+ | |- batch_name_3 --- thumbnails (only for image) --- preview image (.png)
+ | | |
+ | | |- _META
+ |
+ |--- errors --- batch_name_1 --- error message files (.csv)
+ | |
+ | |- batch_name_2 --- error message files (.csv)
+ | |
+ | |- batch_name_3 --- error message files (.csv)
+ |
+    |--- side_output --- batch_name_1 --- intermediate data
+    |                 |
+    |                 |- batch_name_2 --- intermediate data
+    |                 |
+    |                 |- batch_name_3 --- intermediate data
+ |
+    |--- _META (now moved to meta/batch_name, will be deleted in the future)
+ |
+ |--- schema.json
+
+ """
+ _BATCH_DIR = 'batch'
+ _META_DIR = 'meta'
+ _ERRORS_DIR = 'errors'
+ _SIDE_OUTPUT_DIR = 'side_output'
+ _THUMBNAILS_DIR = 'thumbnails'
+ _META_FILE = '_META'
+ _SCHEMA_FILE = 'schema.json'
+ _SOURCE_BATCH_PATH_FILE = 'source_batch_path'
+
+ def __init__(self, dataset_path: str):
+ self._dataset_path = dataset_path
+
+ @property
+ def dataset_path(self) -> str:
+ return self._dataset_path
+
+ def batch_path(self, batch_name: str) -> str:
+ return os.path.join(self._dataset_path, self._BATCH_DIR, batch_name)
+
+ def errors_path(self, batch_name: str) -> str:
+ return os.path.join(self._dataset_path, self._ERRORS_DIR, batch_name)
+
+ def thumbnails_path(self, batch_name: str) -> str:
+ return os.path.join(self._dataset_path, self._META_DIR, batch_name, self._THUMBNAILS_DIR)
+
+ def side_output_path(self, batch_name: str) -> str:
+ return os.path.join(self._dataset_path, self._SIDE_OUTPUT_DIR, batch_name)
+
+ def source_batch_path_file(self, batch_name: str) -> str:
+ return os.path.join(self.batch_path(batch_name), self._SOURCE_BATCH_PATH_FILE)
+
+ def batch_meta_file(self, batch_name) -> str:
+ return os.path.join(self._dataset_path, self._META_DIR, batch_name, self._META_FILE)
+
+ @property
+ def schema_file(self) -> str:
+ return os.path.join(self._dataset_path, self._SCHEMA_FILE)
+
+ # TODO(liuhehan): remove it in future
+ @property
+ def meta_file(self) -> str:
+ return os.path.join(self._dataset_path, self._META_FILE)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/dataset_directory_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/dataset_directory_test.py
new file mode 100644
index 000000000..a1c3a88df
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/dataset_directory_test.py
@@ -0,0 +1,66 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Copied from fedlearner_web_console_v2/web_console_v2/inspection/dataset_directory_test.py
+
+import unittest
+
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+
+
+class UtilTest(unittest.TestCase):
+ _DATASET_PATH = '/fakepath/test_dataset'
+ _BATCH_NAME = 'test_batch_name'
+
+ def setUp(self) -> None:
+ super().setUp()
+ self._dataset_dir = DatasetDirectory(dataset_path=self._DATASET_PATH)
+
+ def test_dataset_path(self):
+ self.assertEqual(self._dataset_dir.dataset_path, self._DATASET_PATH)
+
+ def test_batch_path(self):
+ self.assertEqual(self._dataset_dir.batch_path(self._BATCH_NAME),
+ f'{self._DATASET_PATH}/batch/{self._BATCH_NAME}')
+
+ def test_errors_path(self):
+ self.assertEqual(self._dataset_dir.errors_path(self._BATCH_NAME),
+ f'{self._DATASET_PATH}/errors/{self._BATCH_NAME}')
+
+ def test_thumbnails_path(self):
+ self.assertEqual(self._dataset_dir.thumbnails_path(self._BATCH_NAME),
+ f'{self._DATASET_PATH}/meta/{self._BATCH_NAME}/thumbnails')
+
+ def test_batch_meta_file(self):
+ self.assertEqual(self._dataset_dir.batch_meta_file(self._BATCH_NAME),
+ f'{self._DATASET_PATH}/meta/{self._BATCH_NAME}/_META')
+
+    def test_side_output_path(self):
+ self.assertEqual(self._dataset_dir.side_output_path(self._BATCH_NAME),
+ f'{self._DATASET_PATH}/side_output/{self._BATCH_NAME}')
+
+ def test_schema_file(self):
+ self.assertEqual(self._dataset_dir.schema_file, f'{self._DATASET_PATH}/schema.json')
+
+ def test_meta_file(self):
+ self.assertEqual(self._dataset_dir.meta_file, f'{self._DATASET_PATH}/_META')
+
+ def test_source_batch_path_file(self):
+ self.assertEqual(self._dataset_dir.source_batch_path_file(self._BATCH_NAME),
+ f'{self._DATASET_PATH}/batch/{self._BATCH_NAME}/source_batch_path')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/delete_dependency.py b/web_console_v2/api/fedlearner_webconsole/dataset/delete_dependency.py
new file mode 100644
index 000000000..499f6be64
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/delete_dependency.py
@@ -0,0 +1,72 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from sqlalchemy import or_
+from fedlearner_webconsole.mmgr.models import ModelJob
+from fedlearner_webconsole.workflow.models import WorkflowExternalState
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJobSchedulerState, ResourceState, DatasetJob
+from typing import List, Tuple
+
+
+class DatasetDeleteDependency(object):
+
+ def __init__(self, session) -> None:
+ self._session = session
+ self._check_pipeline = [self._check_model_jobs, self._check_dataset, self._check_dataset_jobs]
+
+ def is_deletable(self, dataset: Dataset) -> Tuple[bool, List[str]]:
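+        """Run every check in the pipeline and aggregate the results.
+
+        Returns a (deletable, messages) tuple; messages collects every blocking
+        dependency rather than stopping at the first one.
+        """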
+ # warning: No lock on modelJob table
+ # TODO(wangzeju): Ensure correct check results when concurrently modifying modelJob
+ is_deletable, msg = True, []
+ for check_func in self._check_pipeline:
+ result = check_func(dataset=dataset)
+ is_deletable, msg = is_deletable & result[0], msg + result[1]
+ return is_deletable, msg
+
+ def _check_model_jobs(self, dataset: Dataset) -> Tuple[bool, List[str]]:
+ dataset_id = dataset.id
+ is_deletable, msg = True, []
+ model_jobs: List[ModelJob] = self._session.query(ModelJob).filter_by(dataset_id=dataset_id).all()
+ for model_job in model_jobs:
+ state = model_job.state
+ if state not in [
+ WorkflowExternalState.COMPLETED, WorkflowExternalState.FAILED, WorkflowExternalState.STOPPED,
+ WorkflowExternalState.INVALID
+ ]:
+ is_deletable, msg = False, msg + [f'The Model Job: {model_job.name} is using this dataset']
+ return is_deletable, msg
+
+ def _check_dataset_jobs(self, dataset: Dataset) -> Tuple[bool, List[str]]:
+ is_deletable, msg = True, []
+ dataset_jobs = self._session.query(DatasetJob).filter(DatasetJob.input_dataset_id == dataset.id).all()
+ for dataset_job in dataset_jobs:
+ if not dataset_job.is_finished():
+ is_deletable, msg = False, msg + [
+ f'dependent dataset_job is not finished, dataset_job_id: {dataset_job.id}'
+ ]
+ cron_dataset_jobs = self._session.query(DatasetJob).filter(
+ or_(DatasetJob.input_dataset_id == dataset.id, DatasetJob.output_dataset_id == dataset.id)).filter(
+ DatasetJob.scheduler_state == DatasetJobSchedulerState.RUNNABLE).all()
+ if cron_dataset_jobs:
+ is_deletable, msg = False, msg + [
+ 'dependent cron dataset_job is still runnable, plz stop scheduler first! ' \
+ f'dataset_jobs_id: {[cron_dataset_job.id for cron_dataset_job in cron_dataset_jobs]}'
+ ]
+ return is_deletable, msg
+
+ def _check_dataset(self, dataset: Dataset) -> Tuple[bool, List[str]]:
+ if not dataset.get_frontend_state() in [ResourceState.SUCCEEDED, ResourceState.FAILED]:
+ return False, [f'The dataset {dataset.name} is being processed']
+ return True, []
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/delete_dependency_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/delete_dependency_test.py
new file mode 100644
index 000000000..808a5cf1e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/delete_dependency_test.py
@@ -0,0 +1,211 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, PropertyMock
+from datetime import datetime
+
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.workflow.models import WorkflowExternalState
+from fedlearner_webconsole.mmgr.models import ModelJob
+from fedlearner_webconsole.dataset.delete_dependency import DatasetDeleteDependency
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetJobState, DatasetType, \
+ DatasetJobSchedulerState
+from fedlearner_webconsole.db import db
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class DatasetDeleteDependencyTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ self.default_dataset1 = Dataset(name='default dataset1',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5))
+ session.add(self.default_dataset1)
+ self.default_dataset2 = Dataset(name='default dataset2',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5))
+ session.add(self.default_dataset2)
+ self.default_dataset3 = Dataset(name='default dataset3',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5))
+ session.add(self.default_dataset3)
+ self.default_dataset4 = Dataset(name='default dataset4',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5))
+ session.add(self.default_dataset4)
+ self.default_dataset5 = Dataset(name='default dataset5',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5))
+ session.add(self.default_dataset5)
+ session.commit()
+ with db.session_scope() as session:
+ parent_dataset_job1 = DatasetJob(id=1,
+ uuid='parent_dataset_job_uuid_1',
+ project_id=1,
+ input_dataset_id=100,
+ output_dataset_id=1,
+ state=DatasetJobState.RUNNING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(parent_dataset_job1)
+ child_dataset_job1 = DatasetJob(id=2,
+ uuid='child_dataset_job_uuid_1',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=100,
+ state=DatasetJobState.FAILED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(child_dataset_job1)
+ parent_dataset_job2 = DatasetJob(id=3,
+ uuid='parent_dataset_job_uuid_2',
+ project_id=1,
+ input_dataset_id=100,
+ output_dataset_id=2,
+ state=DatasetJobState.SUCCEEDED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(parent_dataset_job2)
+ child_dataset_job2 = DatasetJob(id=4,
+ uuid='child_dataset_job_uuid_2',
+ project_id=1,
+ input_dataset_id=2,
+ output_dataset_id=100,
+ state=DatasetJobState.PENDING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(child_dataset_job2)
+ self.default_job1 = Job(name='test-train-job-1',
+ state=JobState.WAITING,
+ job_type=JobType.NN_MODEL_TRANINING,
+ workflow_id=1,
+ project_id=1)
+ session.add(self.default_job1)
+ self.default_model_job1 = ModelJob(id=1,
+ name='test-nn-job-1',
+ job_name=self.default_job1.name,
+ dataset_id=3)
+ session.add(self.default_model_job1)
+ self.default_job2 = Job(name='test-train-job-2',
+ state=JobState.COMPLETED,
+ job_type=JobType.NN_MODEL_TRANINING,
+ workflow_id=1,
+ project_id=1)
+ session.add(self.default_job2)
+ self.default_model_job2 = ModelJob(id=2,
+ name='test-nn-job-2',
+ job_name=self.default_job2.name,
+ dataset_id=4)
+ session.add(self.default_model_job2)
+ parent_dataset_job4 = DatasetJob(id=5,
+ uuid='parent_dataset_job_uuid_4',
+ project_id=1,
+ input_dataset_id=100,
+ output_dataset_id=4,
+ state=DatasetJobState.SUCCEEDED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(parent_dataset_job4)
+ parent_dataset_job5 = DatasetJob(id=6,
+ uuid='parent_dataset_job_uuid_5',
+ project_id=1,
+ input_dataset_id=100,
+ output_dataset_id=5,
+ state=DatasetJobState.SUCCEEDED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(parent_dataset_job5)
+ child_dataset_job5 = DatasetJob(id=7,
+ uuid='child_dataset_job_uuid_5',
+ project_id=1,
+ input_dataset_id=5,
+ output_dataset_id=100,
+ state=DatasetJobState.FAILED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(child_dataset_job5)
+ session.commit()
+
+ def test_is_deletable(self):
+ # TODO(wangzeju): Not covering all branches
+ with db.session_scope() as session:
+ dataset_delete_dependency = DatasetDeleteDependency(session)
+            # test deleting a dataset that is not yet finished
+ dataset1 = session.query(Dataset).get(1)
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset1)
+ self.assertFalse(is_deletable)
+
+            # test a dataset with a running dependent dataset_job
+ dataset2 = session.query(Dataset).get(2)
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset2)
+ self.assertFalse(is_deletable)
+
+ # test dataset with runnable cron dataset_job
+ dataset_job = session.query(DatasetJob).get(4)
+ dataset_job.state = DatasetJobState.SUCCEEDED
+ dataset_job.scheduler_state = DatasetJobSchedulerState.RUNNABLE
+ session.flush()
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset2)
+ self.assertFalse(is_deletable)
+ self.assertEqual(
+ msg[0], 'dependent cron dataset_job is still runnable, plz stop scheduler first! dataset_jobs_id: [4]')
+
+            # test a dataset that is being used by a model job
+ dataset3 = session.query(Dataset).get(3)
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset3)
+ self.assertFalse(is_deletable)
+
+            # test a dataset whose model job is in a terminal state and no longer uses it
+ dataset4 = session.query(Dataset).get(4)
+ with patch('fedlearner_webconsole.mmgr.models.ModelJob.state', new_callable=PropertyMock) as mock_state:
+ mock_state.return_value = WorkflowExternalState.COMPLETED
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset4)
+ self.assertTrue(is_deletable)
+
+ with patch('fedlearner_webconsole.mmgr.models.ModelJob.state', new_callable=PropertyMock) as mock_state:
+ mock_state.return_value = WorkflowExternalState.STOPPED
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset4)
+ self.assertTrue(is_deletable)
+
+ with patch('fedlearner_webconsole.mmgr.models.ModelJob.state', new_callable=PropertyMock) as mock_state:
+ mock_state.return_value = WorkflowExternalState.INVALID
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset4)
+ self.assertTrue(is_deletable)
+
+ with patch('fedlearner_webconsole.mmgr.models.ModelJob.state', new_callable=PropertyMock) as mock_state:
+ mock_state.return_value = WorkflowExternalState.RUNNING
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset4)
+ self.assertFalse(is_deletable)
+
+            # test a deletable dataset
+ dataset5 = session.query(Dataset).get(5)
+ is_deletable, msg = dataset_delete_dependency.is_deletable(dataset5)
+ self.assertTrue(is_deletable)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/filter_funcs.py b/web_console_v2/api/fedlearner_webconsole/dataset/filter_funcs.py
new file mode 100644
index 000000000..4747bf4c7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/filter_funcs.py
@@ -0,0 +1,56 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from sqlalchemy import and_, or_
+
+from fedlearner_webconsole.dataset.models import Dataset, DatasetFormat, PublishFrontendState
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.proto.filtering_pb2 import SimpleExpression
+
+
+def dataset_format_filter_op_in(simple_exp: SimpleExpression):
+ filter_list = []
+ dataset_format_str_list = [dataset_format.name for dataset_format in DatasetFormat]
+ for dataset_format in simple_exp.list_value.string_list:
+ if dataset_format not in dataset_format_str_list:
+            raise ValueError(f'dataset_format does not have type {dataset_format}')
+ filter_list.append(DatasetFormat[dataset_format].value)
+ return Dataset.dataset_format.in_(filter_list)
+
+
+def dataset_format_filter_op_equal(simple_exp: SimpleExpression):
+ for dataset_format in DatasetFormat:
+ if simple_exp.string_value == dataset_format.name:
+ return Dataset.dataset_format == dataset_format.value
+    raise ValueError(f'dataset_format does not have type {simple_exp.string_value}')
+
+
+def dataset_publish_frontend_filter_op_equal(simple_exp: SimpleExpression):
+ if simple_exp.string_value == PublishFrontendState.PUBLISHED.name:
+ return and_(Dataset.is_published.is_(True), Dataset.ticket_status == TicketStatus.APPROVED)
+ if simple_exp.string_value == PublishFrontendState.TICKET_PENDING.name:
+ return and_(Dataset.is_published.is_(True), Dataset.ticket_status == TicketStatus.PENDING)
+ if simple_exp.string_value == PublishFrontendState.TICKET_DECLINED.name:
+ return and_(Dataset.is_published.is_(True), Dataset.ticket_status == TicketStatus.DECLINED)
+ return Dataset.is_published.is_(False)
+
+
+def dataset_auth_status_filter_op_in(simple_exp: SimpleExpression):
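+    """Build an IN filter on auth_status.
+
+    Legacy datasets may carry a NULL auth_status and are treated as AUTHORIZED,
+    so the NULL check is OR-ed in whenever AUTHORIZED is requested.
+    """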
+ filter_list = [AuthStatus[auth_status] for auth_status in simple_exp.list_value.string_list]
+ filter_exp = Dataset.auth_status.in_(filter_list)
+ if AuthStatus.AUTHORIZED in filter_list:
+ filter_exp = or_(Dataset.auth_status.is_(None), filter_exp)
+ return filter_exp
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/filter_funcs_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/filter_funcs_test.py
new file mode 100644
index 000000000..8f7a03e31
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/filter_funcs_test.py
@@ -0,0 +1,94 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.filter_funcs import dataset_format_filter_op_equal, dataset_format_filter_op_in, \
+ dataset_publish_frontend_filter_op_equal, dataset_auth_status_filter_op_in
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.proto.filtering_pb2 import SimpleExpression
+
+
+class FilterFuncsTest(NoWebServerTestCase):
+
+ def test_dataset_format_filter_op_in(self):
+ # test pass
+        expression = SimpleExpression(list_value=SimpleExpression.ListValue(string_list=['TABULAR', 'IMAGE']))
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_format_filter_op_in(expression))
+ self.assertTrue('WHERE datasets_v2.dataset_format IN (0, 1)' in self.generate_mysql_statement(query))
+ # test raise
+ with self.assertRaises(ValueError):
+            expression = SimpleExpression(list_value=SimpleExpression.ListValue(string_list=['FAKE']))
+            dataset_format_filter_op_in(expression)
+
+    def test_dataset_format_filter_op_equal(self):
+ # test pass
+        expression = SimpleExpression(string_value='TABULAR')
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_format_filter_op_equal(expression))
+ self.assertTrue('WHERE datasets_v2.dataset_format = 0' in self.generate_mysql_statement(query))
+ # test raise
+ with self.assertRaises(ValueError):
+            expression = SimpleExpression(list_value=SimpleExpression.ListValue(string_list=['FAKE']))
+            dataset_format_filter_op_equal(expression)
+
+ def test_dataset_publish_frontend_filter_op_equal(self):
+ # test published
+        expression = SimpleExpression(string_value='PUBLISHED')
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_publish_frontend_filter_op_equal(expression))
+ self.assertTrue('WHERE datasets_v2.is_published IS true AND datasets_v2.ticket_status = \'APPROVED\'' in
+ self.generate_mysql_statement(query))
+
+ # test unpublished
+        expression = SimpleExpression(string_value='UNPUBLISHED')
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_publish_frontend_filter_op_equal(expression))
+ self.assertTrue('WHERE datasets_v2.is_published IS false' in self.generate_mysql_statement(query))
+
+ # test ticket_pending
+        expression = SimpleExpression(string_value='TICKET_PENDING')
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_publish_frontend_filter_op_equal(expression))
+ self.assertTrue('WHERE datasets_v2.is_published IS true AND datasets_v2.ticket_status = \'PENDING\'' in
+ self.generate_mysql_statement(query))
+
+ # test ticket declined
+        expression = SimpleExpression(string_value='TICKET_DECLINED')
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_publish_frontend_filter_op_equal(expression))
+ self.assertTrue('WHERE datasets_v2.is_published IS true AND datasets_v2.ticket_status = \'DECLINED\'' in
+ self.generate_mysql_statement(query))
+
+ def test_dataset_auth_status_filter_op_in(self):
+ # test authorized
+        expression = SimpleExpression(list_value=SimpleExpression.ListValue(string_list=['AUTHORIZED']))
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_auth_status_filter_op_in(expression))
+ self.assertTrue('WHERE datasets_v2.auth_status IS NULL OR datasets_v2.auth_status IN (\'AUTHORIZED\')' in
+ self.generate_mysql_statement(query))
+ # test others
+        expression = SimpleExpression(list_value=SimpleExpression.ListValue(string_list=['PENDING', 'WITHDRAW']))
+ with db.session_scope() as session:
+            query = session.query(Dataset).filter(dataset_auth_status_filter_op_in(expression))
+ self.assertTrue(
+ 'WHERE datasets_v2.auth_status IN (\'PENDING\', \'WITHDRAW\')' in self.generate_mysql_statement(query))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/import_handler.py b/web_console_v2/api/fedlearner_webconsole/dataset/import_handler.py
deleted file mode 100644
index 38a800c07..000000000
--- a/web_console_v2/api/fedlearner_webconsole/dataset/import_handler.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import logging
-import threading
-import os
-from concurrent.futures.thread import ThreadPoolExecutor
-from datetime import timedelta, datetime
-
-from fedlearner_webconsole.dataset.models import DataBatch, BatchState
-from fedlearner_webconsole.db import db
-from fedlearner_webconsole.utils.file_manager import FileManager
-from fedlearner_webconsole.proto import dataset_pb2
-
-
-class ImportHandler(object):
- def __init__(self):
- self._executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 3)
- self._file_manager = FileManager()
- self._pending_imports = set()
- self._running_imports = set()
- self._import_lock = threading.Lock()
- self._app = None
-
- def __del__(self):
- self._executor.shutdown()
-
- def init(self, app):
- self._app = app
-
- def schedule_to_handle(self, dataset_batch_ids):
- if isinstance(dataset_batch_ids, int):
- dataset_batch_ids = [dataset_batch_ids]
- self._pending_imports.update(dataset_batch_ids)
-
- def _copy_file(self,
- source_path,
- destination_path,
- move=False,
- num_retry=3):
- logging.info('%s from %s to %s', 'moving' if move else 'copying',
- source_path, destination_path)
- # Creates parent folders if needed
- parent_folder = os.path.dirname(destination_path)
- self._file_manager.mkdir(parent_folder)
- success = False
- error_message = ''
- for _ in range(num_retry):
- try:
- if move:
- self._file_manager.move(source_path, destination_path)
- else:
- self._file_manager.copy(source_path, destination_path)
- success = True
- break
- except Exception as e: # pylint: disable=broad-except
- logging.error(
- 'Error occurred when importing file from %s to %s',
- source_path, destination_path)
- error_message = str(e)
- file = dataset_pb2.File(source_path=source_path,
- destination_path=destination_path)
- if not success:
- file.error_message = error_message
- file.state = dataset_pb2.File.State.FAILED
- else:
- file.size = self._file_manager.ls(destination_path)[0].size
- file.state = dataset_pb2.File.State.COMPLETED
- return file
-
- def _import_batch(self, batch_id):
- self._import_lock.acquire()
- if batch_id in self._running_imports:
- return
- self._running_imports.add(batch_id)
- self._import_lock.release()
-
- # Pushes app context to make db session work
- self._app.app_context().push()
-
- logging.info('Importing batch %d', batch_id)
- batch = DataBatch.query.get(batch_id)
- batch.state = BatchState.IMPORTING
- db.session.commit()
- db.session.refresh(batch)
- details = batch.get_details()
-
- for file in details.files:
- if file.state == dataset_pb2.File.State.UNSPECIFIED:
- # Recovers the state
- try:
- destination_existed = len(
- self._file_manager.ls(file.destination_path)) > 0
- except Exception: # pylint: disable=broad-except
- destination_existed = False
- if destination_existed:
- file.state = dataset_pb2.File.State.COMPLETED
- continue
- # Moves/Copies
- file.MergeFrom(
- self._copy_file(source_path=file.source_path,
- destination_path=file.destination_path,
- move=batch.move))
-
- batch.set_details(details)
- db.session.commit()
-
- self._import_lock.acquire()
- self._running_imports.remove(batch_id)
- self._import_lock.release()
-
- def handle(self, pull=False):
- """Handles all the batches in the queue or all batches which
- should be imported."""
- batches_to_run = self._pending_imports
- self._pending_imports = set()
- if pull:
- # TODO: should separate pull logic to a cron job,
- # otherwise there will be a race condition that two handlers
- # are trying to move the same batch
- one_hour_ago = datetime.utcnow() - timedelta(hours=1)
- pulled_batches = db.session.query(DataBatch.id).filter(
- (DataBatch.state == BatchState.NEW) |
- (DataBatch.state == BatchState.IMPORTING))\
- .filter(DataBatch.updated_at < one_hour_ago)\
- .all()
- pulled_ids = [bid for bid, in pulled_batches]
- batches_to_run.update(pulled_ids)
-
- for batch in batches_to_run:
- self._executor.submit(self._import_batch, batch)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/BUILD.bazel
new file mode 100644
index 000000000..aef156e80
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/BUILD.bazel
@@ -0,0 +1,107 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "job_configer",
+ srcs = [
+ "analyzer_configer.py",
+ "base_configer.py",
+ "data_alignment_configer.py",
+ "dataset_job_configer.py",
+ "export_configer.py",
+ "hash_data_join_configer.py",
+ "import_source_configer.py",
+ "light_client_ot_psi_data_join_configer.py",
+ "light_client_rsa_psi_data_join_configer.py",
+ "ot_psi_data_join_configer.py",
+ "rsa_psi_data_join_configer.py",
+ ],
+ data = [
+ "//web_console_v2/api/fedlearner_webconsole/sys_preset_templates",
+ ],
+ imports = ["../../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:data_path_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:schema_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_rsa//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "dataset_job_configer_test",
+ size = "medium",
+ srcs = [
+ "dataset_job_configer_test.py",
+ ],
+ imports = ["../../.."],
+ main = "dataset_job_configer_test.py",
+ deps = [
+ ":job_configer",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "light_client_ot_psi_data_join_configer_test",
+ size = "small",
+ srcs = [
+ "light_client_ot_psi_data_join_configer_test.py",
+ ],
+ imports = ["../../.."],
+ main = "light_client_ot_psi_data_join_configer_test.py",
+ deps = [
+ ":job_configer",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "export_configer_test",
+ size = "small",
+ srcs = [
+ "export_configer_test.py",
+ ],
+ imports = ["../../.."],
+ main = "export_configer_test.py",
+ deps = [
+ ":job_configer",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/analyzer_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/analyzer_configer.py
new file mode 100644
index 000000000..ff5a83d83
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/analyzer_configer.py
@@ -0,0 +1,79 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+from typing import List, Optional
+
+from fedlearner_webconsole.dataset.models import Dataset, DatasetFormat
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config, filter_user_variables
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class AnalyzerConfiger(BaseConfiger):
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(self._session).get_workflow_template(name='sys-preset-analyzer')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+        # return variables whose tag is RESOURCE_ALLOCATION or INPUT_PARAM
+ return filter_user_variables(list(zip_workflow_variables(self.get_config())))
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+ input_dataset: Dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ input_batch_path_variable = make_variable(name='input_batch_path', typed_value=input_dataset.path)
+ set_variable_value_to_job_config(job_config, input_batch_path_variable)
+ data_type_variable = make_variable(name='data_type',
+ typed_value=DatasetFormat(input_dataset.dataset_format).name.lower())
+ set_variable_value_to_job_config(job_config, data_type_variable)
+
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ dataset: Dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ data_batch = self._get_data_batch(dataset, event_time)
+ batch_name = data_batch.name or data_batch.batch_name
+ dataset_path = dataset.path
+ thumbnail_path = DatasetDirectory(dataset_path).thumbnails_path(batch_name)
+ thumbnail_path_variable = make_variable(name='thumbnail_path', typed_value=thumbnail_path)
+ set_variable_value_to_job_config(job_config, thumbnail_path_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ return global_configs
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/base_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/base_configer.py
new file mode 100644
index 000000000..f4e763bcc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/base_configer.py
@@ -0,0 +1,107 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+from datetime import datetime
+
+from typing import List, Optional
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.dataset.models import Dataset, DataBatch, DatasetType
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+
+
+def get_my_pure_domain_name() -> str:
+ """Get pure domain name of our side
+
+ Returns:
+ str: pure domain name
+ """
+ return SettingService.get_system_info().pure_domain_name
+
+
+def set_variable_value_to_job_config(job_config: dataset_pb2.DatasetJobConfig, target_variable: common_pb2.Variable):
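+    """Overwrite the matching variable (same name and value_type) in job_config,
+    or append target_variable if no match exists."""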
+ for variable in job_config.variables:
+ if variable.name == target_variable.name and variable.value_type == target_variable.value_type:
+ variable.typed_value.CopyFrom(target_variable.typed_value)
+ break
+ else:
+ job_config.variables.append(target_variable)
+
+
+def filter_user_variables(variables: List[common_pb2.Variable]) -> List[common_pb2.Variable]:
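+    """Keep only the variables tagged RESOURCE_ALLOCATION or INPUT_PARAM, i.e. the ones users may edit."""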
+ user_variables = []
+ for variable in variables:
+ if variable.tag in ['RESOURCE_ALLOCATION', 'INPUT_PARAM']:
+ user_variables.append(variable)
+ return user_variables
+
+
+class BaseConfiger(metaclass=abc.ABCMeta):
+ """This is base interface aimed to config dataset_job global_configs for different job kind
+ Routines:
+ user_variables:
+ Usage: Get a list of variables that one can configure itself.
+ When: [Coordinator] API user gets the dataset_job definitions.
+ auto_config_variables:
+ Usage: Auto config some variables that are needed real job without letting users know.
+ When: [Coordinator] API Layer that creates the dataset_job resource.
+ config_local_vairables:
+ Usage: Config some local variables that're sensitive to each participants.
+ When: [Participant] DatasetJob Scheduler of each participants
+ """
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ @abc.abstractmethod
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ """Get workflow_definition of this dataset_job_kind
+
+ Returns:
+ workflow_definition_pb2.WorkflowDefinition: workflow definition according to given kind
+ """
+
+ @property
+ @abc.abstractmethod
+ def user_variables(self) -> List[common_pb2.Variable]:
+ pass
+
+ @abc.abstractmethod
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ pass
+
+ @abc.abstractmethod
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ pass
+
+ def _get_data_batch(self, dataset: Dataset, event_time: Optional[datetime] = None) -> DataBatch:
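+        """Resolve the data batch to configure.
+
+        PSI datasets have exactly one batch; for other dataset types the batch is
+        looked up by event_time.
+        """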
+ if dataset.dataset_type == DatasetType.PSI:
+ return dataset.get_single_batch()
+ data_batch: DataBatch = self._session.query(DataBatch).filter(DataBatch.dataset_id == dataset.id).filter(
+ DataBatch.event_time == event_time).first()
+ if data_batch is None:
+ raise InvalidArgumentException(
+                details=f'failed to find data_batch, event_time: {to_timestamp(event_time)}, '
+                f'dataset id: {dataset.id}')
+ return data_batch
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/data_alignment_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/data_alignment_configer.py
new file mode 100644
index 000000000..277162225
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/data_alignment_configer.py
@@ -0,0 +1,121 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+import json
+import os
+
+from typing import List, Optional
+
+from fedlearner_webconsole.dataset.models import Dataset, DatasetFormat
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.schema import spark_schema_to_json_schema
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class DataAlignmentConfiger(BaseConfiger):
+ USER_VARIABLES_NAME_SET = {
+ 'driver_cores',
+ 'driver_mem',
+ 'executor_cores',
+ 'executor_mem',
+ }
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(self._session).get_workflow_template(name='sys-preset-alignment-task')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+ real_user_variables = []
+ for variable in zip_workflow_variables(self.get_config()):
+ if variable.name in self.USER_VARIABLES_NAME_SET:
+ real_user_variables.append(variable)
+
+ return real_user_variables
+
+ def auto_config_variables(
+ self, global_configs: dataset_pb2.DatasetJobGlobalConfigs) -> dataset_pb2.DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+ dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ dataset_path = dataset.path
+
+ spark_schema = FileManager().read(os.path.join(dataset_path, 'schema.json'))
+ json_schema_str = json.dumps(spark_schema_to_json_schema(json.loads(spark_schema)))
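+        # Propagate the same schema and data type to every participant's job config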
+ for job_config in global_configs.global_configs.values():
+ json_schema_variable = make_variable(name='json_schema', typed_value=json_schema_str)
+ set_variable_value_to_job_config(job_config, json_schema_variable)
+ data_type_variable = make_variable(name='data_type',
+ typed_value=DatasetFormat(dataset.dataset_format).name.lower())
+ set_variable_value_to_job_config(job_config, data_type_variable)
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ output_dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ input_batch = self._get_data_batch(input_dataset, event_time)
+ output_batch = self._get_data_batch(output_dataset, event_time)
+ input_batch_path = get_batch_data_path(input_batch)
+ output_dataset_path = output_dataset.path
+ output_batch_path = output_batch.path
+ output_batch_name = output_batch.batch_name
+ thumbnail_path = DatasetDirectory(dataset_path=output_dataset_path).thumbnails_path(
+ batch_name=output_batch_name)
+
+ input_dataset_path_variable = make_variable(name='input_dataset_path', typed_value=input_dataset.path)
+ set_variable_value_to_job_config(job_config, input_dataset_path_variable)
+
+ input_batch_path_variable = make_variable(name='input_batch_path', typed_value=input_batch_path)
+ set_variable_value_to_job_config(job_config, input_batch_path_variable)
+
+ output_dataset_path_variable = make_variable(name='output_dataset_path', typed_value=output_dataset_path)
+ set_variable_value_to_job_config(job_config, output_dataset_path_variable)
+
+ output_batch_path_variable = make_variable(name='output_batch_path', typed_value=output_batch_path)
+ set_variable_value_to_job_config(job_config, output_batch_path_variable)
+
+ thumbnail_path_variable = make_variable(name='thumbnail_path', typed_value=thumbnail_path)
+ set_variable_value_to_job_config(job_config, thumbnail_path_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=output_batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ return global_configs
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/dataset_job_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/dataset_job_configer.py
new file mode 100644
index 000000000..e43082e56
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/dataset_job_configer.py
@@ -0,0 +1,54 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+from typing import Union
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.dataset.models import DatasetJobKind
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger
+from fedlearner_webconsole.dataset.job_configer.import_source_configer import ImportSourceConfiger
+from fedlearner_webconsole.dataset.job_configer.data_alignment_configer import DataAlignmentConfiger
+from fedlearner_webconsole.dataset.job_configer.rsa_psi_data_join_configer import RsaPsiDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.export_configer import ExportConfiger
+from fedlearner_webconsole.dataset.job_configer.light_client_rsa_psi_data_join_configer import \
+ LightClientRsaPsiDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.ot_psi_data_join_configer import OtPsiDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.light_client_ot_psi_data_join_configer import \
+ LightClientOtPsiDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.hash_data_join_configer import HashDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.analyzer_configer import AnalyzerConfiger
+
+
+class DatasetJobConfiger(metaclass=abc.ABCMeta):
+
+ @classmethod
+ def from_kind(cls, kind: Union[DatasetJobKind, str], session: Session) -> BaseConfiger:
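+        """Returns the configer instance that handles the given dataset job kind."""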
+        handlers_mapper = {
+ DatasetJobKind.IMPORT_SOURCE: ImportSourceConfiger,
+ DatasetJobKind.DATA_ALIGNMENT: DataAlignmentConfiger,
+ DatasetJobKind.RSA_PSI_DATA_JOIN: RsaPsiDataJoinConfiger,
+ DatasetJobKind.EXPORT: ExportConfiger,
+ DatasetJobKind.LIGHT_CLIENT_RSA_PSI_DATA_JOIN: LightClientRsaPsiDataJoinConfiger,
+ DatasetJobKind.OT_PSI_DATA_JOIN: OtPsiDataJoinConfiger,
+ DatasetJobKind.LIGHT_CLIENT_OT_PSI_DATA_JOIN: LightClientOtPsiDataJoinConfiger,
+ DatasetJobKind.HASH_DATA_JOIN: HashDataJoinConfiger,
+ DatasetJobKind.ANALYZER: AnalyzerConfiger,
+ }
+
+ if isinstance(kind, str):
+ kind = DatasetJobKind(kind)
+
+        return handlers_mapper[kind](session)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/dataset_job_configer_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/dataset_job_configer_test.py
new file mode 100644
index 000000000..2911590d9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/dataset_job_configer_test.py
@@ -0,0 +1,864 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# TODO(liuhehan): split this UT into multiple files
+# pylint: disable=protected-access
+from datetime import datetime, timedelta
+import json
+import os
+import unittest
+from unittest.mock import patch
+
+from google.protobuf.struct_pb2 import Value
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobConfig, DatasetJobGlobalConfigs
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.dataset.models import DataBatch, DataSource, Dataset, DatasetJob, DatasetJobKind, \
+ DatasetJobState, DatasetKindV2, DatasetMetaInfo, ImportType, DatasetType
+from fedlearner_webconsole.dataset.job_configer.import_source_configer import ImportSourceConfiger
+from fedlearner_webconsole.dataset.job_configer.data_alignment_configer import DataAlignmentConfiger
+from fedlearner_webconsole.dataset.job_configer.rsa_psi_data_join_configer import RsaPsiDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.light_client_rsa_psi_data_join_configer import \
+ LightClientRsaPsiDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.ot_psi_data_join_configer import OtPsiDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.hash_data_join_configer import HashDataJoinConfiger
+from fedlearner_webconsole.dataset.job_configer.analyzer_configer import AnalyzerConfiger
+
+
+def fake_spark_schema(*args) -> str:
+ del args
+
+ return json.dumps({
+ 'type':
+ 'struct',
+ 'fields': [{
+ 'name': 'raw_id',
+ 'type': 'integer',
+ 'nullable': True,
+ 'metadata': {}
+ }, {
+ 'name': 'f01',
+ 'type': 'float',
+ 'nullable': True,
+ 'metadata': {}
+ }, {
+ 'name': 'image',
+ 'type': 'binary',
+ 'nullable': True,
+ 'metadata': {}
+ }]
+ })
+
+
+class DatasetJobConfigersTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.maxDiff = None
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+
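+            # Fixtures: one data source, a PSI input/output dataset pair, and a streaming
+            # input/output dataset pair whose batches carry event_time 2022-01-01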
+ test_project = Project(name='test_project')
+ session.add(test_project)
+ session.flush([test_project])
+
+ data_source = DataSource(id=1,
+ name='test_data_source',
+ uuid=resource_uuid(),
+ path='/data/some_data_source/')
+ session.add(data_source)
+
+ test_input_dataset = Dataset(id=2,
+ name='test_input_dataset',
+ uuid=resource_uuid(),
+ is_published=False,
+ project_id=test_project.id,
+ path='/data/dataset/test_input_dataset',
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(test_input_dataset)
+ session.flush([test_input_dataset])
+
+ test_input_data_batch = DataBatch(dataset_id=test_input_dataset.id,
+ path=os.path.join(test_input_dataset.path, 'batch/test_input_data_batch'))
+ session.add(test_input_data_batch)
+
+ test_output_dataset = Dataset(id=3,
+ name='test_output_dataset',
+ uuid=resource_uuid(),
+ is_published=True,
+ project_id=test_project.id,
+ path='/data/dataset/test_output_dataset',
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(test_output_dataset)
+ session.flush([test_output_dataset])
+
+ test_output_data_batch = DataBatch(dataset_id=test_output_dataset.id,
+ path=os.path.join(test_output_dataset.path,
+ 'batch/test_output_data_batch'))
+ session.add(test_output_data_batch)
+
+ test_input_streaming_dataset = Dataset(id=4,
+ name='test_input_dataset',
+ uuid=resource_uuid(),
+ is_published=False,
+ project_id=test_project.id,
+ path='/data/dataset/test_input_dataset',
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(test_input_streaming_dataset)
+ session.flush()
+
+ test_input_streaming_data_batch = DataBatch(dataset_id=test_input_streaming_dataset.id,
+ event_time=datetime(2022, 1, 1),
+ path=os.path.join(test_input_dataset.path, 'batch/20220101'))
+ session.add(test_input_streaming_data_batch)
+
+ test_output_streaming_dataset = Dataset(id=5,
+ name='test_output_dataset',
+ uuid=resource_uuid(),
+ is_published=True,
+ project_id=test_project.id,
+ path='/data/dataset/test_output_dataset',
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(test_output_streaming_dataset)
+ session.flush()
+
+ test_output_streaming_data_batch = DataBatch(dataset_id=test_output_streaming_dataset.id,
+ event_time=datetime(2022, 1, 1),
+ path=os.path.join(test_output_dataset.path, 'batch/20220101'))
+ session.add(test_output_streaming_data_batch)
+
+ self._data_source_uuid = data_source.uuid
+ self._input_dataset_uuid = test_input_dataset.uuid
+ self._output_dataset_uuid = test_output_dataset.uuid
+ self._input_streaming_dataset_uuid = test_input_streaming_dataset.uuid
+ self._output_streaming_dataset_uuid = test_output_streaming_dataset.uuid
+
+ session.commit()
+
+ def test_get_data_batch(self):
+ # test PSI dataset
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).filter(Dataset.uuid == self._output_dataset_uuid).first()
+ data_batch = ImportSourceConfiger(session)._get_data_batch(dataset=dataset)
+ self.assertEqual(data_batch.path, '/data/dataset/test_output_dataset/batch/test_output_data_batch')
+
+ # test STREAMING dataset
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).filter(Dataset.uuid == self._output_streaming_dataset_uuid).first()
+ data_batch = ImportSourceConfiger(session)._get_data_batch(dataset=dataset, event_time=datetime(2022, 1, 1))
+ self.assertEqual(data_batch.path, '/data/dataset/test_output_dataset/batch/20220101')
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_import_source(self):
+ with db.session_scope() as session:
+            # This check flags any change to the preset template's variables
+ config = ImportSourceConfiger(session).get_config()
+ self.assertEqual(len(config.variables), 22)
+
+ with db.session_scope() as session:
+ global_configs = ImportSourceConfiger(session).auto_config_variables(global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._data_source_uuid)}))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='file_format',
+ value='tfrecords',
+ typed_value=Value(string_value='tfrecords'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='data_type',
+ value='tabular',
+ typed_value=Value(string_value='tabular'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ with db.session_scope() as session:
+ global_configs = ImportSourceConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._data_source_uuid)}),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_batch_path',
+ value='/data/some_data_source/',
+ typed_value=Value(string_value='/data/some_data_source/'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(
+ name='batch_path',
+ value='/data/dataset/test_output_dataset/batch/test_output_data_batch',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='thumbnail_path',
+ value='/data/dataset/test_output_dataset/meta/test_output_data_batch/thumbnails',
+ typed_value=Value(
+ string_value='/data/dataset/test_output_dataset/meta/test_output_data_batch/thumbnails'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='checkers',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='import_type',
+ value='COPY',
+ typed_value=Value(string_value='COPY'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ with db.session_scope() as session:
+ output_dataset: Dataset = session.query(Dataset).filter(Dataset.uuid == self._output_dataset_uuid).first()
+ output_dataset.set_meta_info(DatasetMetaInfo(schema_checkers=['RAW_ID_CHECKER', 'NUMERIC_COLUMNS_CHECKER']))
+ output_dataset.import_type = ImportType.NO_COPY
+ session.commit()
+
+ with db.session_scope() as session:
+ global_configs = ImportSourceConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._data_source_uuid)}),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_batch_path',
+ value='/data/some_data_source/',
+ typed_value=Value(string_value='/data/some_data_source/'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(
+ name='batch_path',
+ value='/data/dataset/test_output_dataset/batch/test_output_data_batch',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='thumbnail_path',
+ value='/data/dataset/test_output_dataset/meta/test_output_data_batch/thumbnails',
+ typed_value=Value(
+ string_value='/data/dataset/test_output_dataset/meta/test_output_data_batch/thumbnails'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='checkers',
+ value='RAW_ID_CHECKER,NUMERIC_COLUMNS_CHECKER',
+ typed_value=Value(string_value='RAW_ID_CHECKER,NUMERIC_COLUMNS_CHECKER'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='import_type',
+ value='NO_COPY',
+ typed_value=Value(string_value='NO_COPY'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='skip_analyzer',
+ value='true',
+ typed_value=Value(string_value='true'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ with db.session_scope() as session:
+ dataset_job_streaming = DatasetJob(id=1,
+ uuid='dataset_job streaming',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=5,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ state=DatasetJobState.PENDING,
+ time_range=timedelta(days=1))
+ session.add(dataset_job_streaming)
+ session.commit()
+
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = ImportSourceConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._data_source_uuid)}),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_batch_path',
+ value='/data/some_data_source/20220101',
+ typed_value=Value(string_value='/data/some_data_source/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='batch_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='thumbnail_path',
+ value='/data/dataset/test_output_dataset/meta/20220101/thumbnails',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/meta/20220101/thumbnails'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='checkers',
+ value='',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='import_type',
+ value='COPY',
+ typed_value=Value(string_value='COPY'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ @patch('fedlearner_webconsole.dataset.job_configer.data_alignment_configer.FileManager.read', fake_spark_schema)
+ def test_data_alignment(self):
+ with db.session_scope() as session:
+            # This check flags any change to the preset template's variables
+ config = DataAlignmentConfiger(session).get_config()
+ variables = zip_workflow_variables(config)
+ self.assertEqual(len(list(variables)), 18)
+
+ with db.session_scope() as session:
+ self.assertEqual(len(DataAlignmentConfiger(session).user_variables), 4)
+
+ with db.session_scope() as session:
+ global_configs = DataAlignmentConfiger(session).auto_config_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }))
+ for job_config in global_configs.global_configs.values():
+ variables = job_config.variables
+ input_batch_path_variable = [v for v in variables if v.name == 'json_schema'][0]
+ self.assertEqual(input_batch_path_variable.value_type, common_pb2.Variable.ValueType.STRING)
+ data_type_variable = [v for v in variables if v.name == 'data_type'][0]
+ self.assertEqual(data_type_variable.typed_value.string_value, 'tabular')
+
+ with db.session_scope() as session:
+ global_configs = DataAlignmentConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_dataset_uuid)
+ variables = global_configs.global_configs['test_domain'].variables
+ input_batch_path_variable = [v for v in variables if v.name == 'input_batch_path'][0]
+ self.assertEqual(input_batch_path_variable.typed_value.string_value,
+ '/data/dataset/test_input_dataset/batch/test_input_data_batch')
+ thumbnail_path_variable = [v for v in variables if v.name == 'thumbnail_path'][0]
+ self.assertEqual(thumbnail_path_variable.typed_value.string_value,
+ '/data/dataset/test_output_dataset/meta/test_output_data_batch/thumbnails')
+
+ variables = global_configs.global_configs['test_domain_2'].variables
+ self.assertListEqual([v for v in variables if v.name == 'input_batch_path'], [])
+ self.assertListEqual([v for v in variables if v.name == 'thumbnail_path'], [])
+
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = DataAlignmentConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_dataset_path',
+ value='/data/dataset/test_input_dataset',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='input_batch_path',
+ value='/data/dataset/test_input_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='thumbnail_path',
+ value='/data/dataset/test_output_dataset/meta/20220101/thumbnails',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/meta/20220101/thumbnails'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_rsa_psi_data_join(self):
+ with db.session_scope() as session:
+            # This check flags any change to the preset template's variables
+ config = RsaPsiDataJoinConfiger(session).get_config()
+ variables = zip_workflow_variables(config)
+ self.assertEqual(len(list(variables)), 20)
+
+ with db.session_scope() as session:
+ self.assertEqual(len(RsaPsiDataJoinConfiger(session).user_variables), 9)
+
+ with db.session_scope() as session:
+ global_configs = RsaPsiDataJoinConfiger(session).auto_config_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }))
+
+ for pure_domain_name, job_config in global_configs.global_configs.items():
+ if pure_domain_name == 'test_domain':
+ for var in job_config.variables:
+ if var.name == 'role':
+ self.assertEqual(var.typed_value.string_value, 'Leader')
+ elif var.name == 'rsa_key_pem':
+ self.assertIn('-----BEGIN RSA PRIVATE KEY-----', var.typed_value.string_value)
+ else:
+ for var in job_config.variables:
+ if var.name == 'role':
+ self.assertEqual(var.typed_value.string_value, 'Follower')
+ elif var.name == 'rsa_key_pem':
+ self.assertIn('-----BEGIN RSA PUBLIC KEY-----', var.typed_value.string_value)
+
+ for var in job_config.variables:
+ if var.name == 'rsa_key_path':
+ self.assertEqual(var.typed_value.string_value, '')
+
+ with db.session_scope() as session:
+ global_configs = RsaPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_dataset_uuid)
+
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='dataset',
+ value='/data/dataset/test_input_dataset/batch/test_input_data_batch',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/test_input_data_batch'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=Variable.ValueType.STRING),
+ Variable(
+ name='output_batch_path',
+ value='/data/dataset/test_output_dataset/batch/test_output_data_batch',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/test_output_data_batch'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=Variable.ValueType.STRING),
+ ])
+
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = RsaPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='dataset',
+ value='/data/dataset/test_input_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_light_client_rsa_psi_data_join(self):
+ with db.session_scope() as session:
+ global_configs = LightClientRsaPsiDataJoinConfiger(session).auto_config_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid)}))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [])
+
+ with db.session_scope() as session:
+ global_configs = LightClientRsaPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ DatasetJobConfig(dataset_uuid=self._input_dataset_uuid,
+ variables=[
+ common_pb2.Variable(name='input_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='input_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ }),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ common_pb2.Variable(name='input_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(
+ name='input_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/test_input_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(
+ name='output_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = LightClientRsaPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_dataset_path',
+ value='/data/dataset/test_input_dataset',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='input_batch_path',
+ value='/data/dataset/test_input_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_ot_psi_data_join_configer(self):
+ with db.session_scope() as session:
+ global_configs = OtPsiDataJoinConfiger(session).auto_config_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ DatasetJobConfig(dataset_uuid=self._input_dataset_uuid,
+ variables=[
+ common_pb2.Variable(name='role',
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ]),
+ 'test_domain_2':
+ DatasetJobConfig(dataset_uuid='u12345',
+ variables=[
+ common_pb2.Variable(name='role',
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ }))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ common_pb2.Variable(name='role',
+ typed_value=Value(string_value='server'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ self.assertEqual(list(global_configs.global_configs['test_domain_2'].variables), [
+ common_pb2.Variable(name='role',
+ typed_value=Value(string_value='client'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ with db.session_scope() as session:
+ global_configs = OtPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ DatasetJobConfig(dataset_uuid=self._input_dataset_uuid,
+ variables=[
+ common_pb2.Variable(name='input_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='input_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ }),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ common_pb2.Variable(name='input_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(
+ name='input_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/test_input_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(
+ name='output_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = OtPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_dataset_path',
+ value='/data/dataset/test_input_dataset',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='input_batch_path',
+ value='/data/dataset/test_input_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_hash_data_join_configer(self):
+ with db.session_scope() as session:
+ global_configs = HashDataJoinConfiger(session).auto_config_variables(global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ DatasetJobConfig(dataset_uuid=self._input_dataset_uuid,
+ variables=[
+ common_pb2.Variable(name='role',
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ]),
+ 'test_domain_2':
+ DatasetJobConfig(dataset_uuid='u12345',
+ variables=[
+ common_pb2.Variable(name='role',
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ }))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ common_pb2.Variable(name='role',
+ typed_value=Value(string_value='server'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ self.assertEqual(list(global_configs.global_configs['test_domain_2'].variables), [
+ common_pb2.Variable(name='role',
+ typed_value=Value(string_value='client'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ with db.session_scope() as session:
+ global_configs = HashDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ DatasetJobConfig(dataset_uuid=self._input_dataset_uuid,
+ variables=[
+ common_pb2.Variable(name='input_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='input_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ }),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ common_pb2.Variable(name='input_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(
+ name='input_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/test_input_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(
+ name='output_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = HashDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_dataset_path',
+ value='/data/dataset/test_input_dataset',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='input_batch_path',
+ value='/data/dataset/test_input_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_analyzer_configer(self):
+ with db.session_scope() as session:
+            # This check flags any change to the preset template's variables
+ config = AnalyzerConfiger(session).get_config()
+ self.assertEqual(len(config.variables), 14)
+
+ with db.session_scope() as session:
+ global_configs = AnalyzerConfiger(session).auto_config_variables(global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._data_source_uuid)}))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ common_pb2.Variable(name='input_batch_path',
+ value='/data/some_data_source/',
+ typed_value=Value(string_value='/data/some_data_source/'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ common_pb2.Variable(name='data_type',
+ value='tabular',
+ typed_value=Value(string_value='tabular'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ with db.session_scope() as session:
+ global_configs = AnalyzerConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._data_source_uuid)}),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='thumbnail_path',
+ value='/data/dataset/test_output_dataset/meta/test_output_data_batch/thumbnails',
+ typed_value=Value(
+ string_value='/data/dataset/test_output_dataset/meta/test_output_data_batch/thumbnails'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = AnalyzerConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid)}),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='thumbnail_path',
+ value='/data/dataset/test_output_dataset/meta/20220101/thumbnails',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/meta/20220101/thumbnails'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=common_pb2.Variable.ValueType.STRING),
+ ])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/export_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/export_configer.py
new file mode 100644
index 000000000..95900715c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/export_configer.py
@@ -0,0 +1,75 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+from typing import List, Optional
+
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, filter_user_variables, \
+ get_my_pure_domain_name, set_variable_value_to_job_config
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class ExportConfiger(BaseConfiger):
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(self._session).get_workflow_template(name='sys-preset-export-dataset')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+        # Return variables whose tag is RESOURCE_ALLOCATION or INPUT_PARAM
+ return filter_user_variables(list(zip_workflow_variables(self.get_config())))
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset: Dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ output_dataset: Dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ input_batch = self._get_data_batch(input_dataset, event_time)
+ output_batch = self._get_data_batch(output_dataset, event_time)
+
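+        # The export job reads the input batch in its stored file format and writes to the output batch path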
+ dataset_path_variable = make_variable(name='dataset_path', typed_value=input_dataset.path)
+ set_variable_value_to_job_config(job_config, dataset_path_variable)
+
+ file_format_variable = make_variable(name='file_format', typed_value=input_dataset.store_format.name.lower())
+ set_variable_value_to_job_config(job_config, file_format_variable)
+
+ batch_name_variable = make_variable(name='batch_name', typed_value=input_batch.batch_name)
+ set_variable_value_to_job_config(job_config, batch_name_variable)
+
+ export_path_variable = make_variable(name='export_path', typed_value=output_batch.path)
+ set_variable_value_to_job_config(job_config, export_path_variable)
+
+ return global_configs
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/export_configer_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/export_configer_test.py
new file mode 100644
index 000000000..99727fbc9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/export_configer_test.py
@@ -0,0 +1,183 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+import os
+import unittest
+from unittest.mock import patch
+from google.protobuf.struct_pb2 import Value
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, DatasetKindV2, DatasetType
+from fedlearner_webconsole.dataset.job_configer.export_configer import ExportConfiger
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobConfig, DatasetJobGlobalConfigs
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+
+
+class ExportConfigersTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+
+ test_project = Project(name='test_project')
+ session.add(test_project)
+ session.flush([test_project])
+
+ test_input_dataset = Dataset(id=2,
+ name='test_input_dataset',
+ uuid=resource_uuid(),
+ is_published=False,
+ project_id=test_project.id,
+ path='/data/dataset/test_input_dataset',
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(test_input_dataset)
+ session.flush([test_input_dataset])
+
+ test_input_data_batch = DataBatch(dataset_id=test_input_dataset.id,
+ path=os.path.join(test_input_dataset.path, 'batch/0'))
+ session.add(test_input_data_batch)
+
+ test_output_dataset = Dataset(id=3,
+ name='test_output_dataset',
+ uuid=resource_uuid(),
+ is_published=True,
+ project_id=test_project.id,
+ path='/data/dataset/test_output_dataset',
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(test_output_dataset)
+ session.flush([test_output_dataset])
+
+ test_output_data_batch = DataBatch(dataset_id=test_output_dataset.id,
+ path=os.path.join(test_output_dataset.path, 'batch/0'))
+ session.add(test_output_data_batch)
+
+ test_input_streaming_dataset = Dataset(id=4,
+ name='test_input_dataset',
+ uuid=resource_uuid(),
+ is_published=False,
+ project_id=test_project.id,
+ path='/data/dataset/test_input_dataset',
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(test_input_streaming_dataset)
+ session.flush()
+
+ test_input_streaming_data_batch = DataBatch(dataset_id=test_input_streaming_dataset.id,
+ event_time=datetime(2022, 1, 1),
+ path=os.path.join(test_input_dataset.path, 'batch/20220101'))
+ session.add(test_input_streaming_data_batch)
+
+ test_output_streaming_dataset = Dataset(id=5,
+ name='test_output_dataset',
+ uuid=resource_uuid(),
+ is_published=True,
+ project_id=test_project.id,
+ path='/data/dataset/test_output_dataset',
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(test_output_streaming_dataset)
+ session.flush()
+
+ test_output_streaming_data_batch = DataBatch(dataset_id=test_output_streaming_dataset.id,
+ event_time=datetime(2022, 1, 1),
+ path=os.path.join(test_output_dataset.path, 'batch/20220101'))
+ session.add(test_output_streaming_data_batch)
+
+ self._input_dataset_uuid = test_input_dataset.uuid
+ self._output_dataset_uuid = test_output_dataset.uuid
+ self._input_streaming_dataset_uuid = test_input_streaming_dataset.uuid
+ self._output_streaming_dataset_uuid = test_output_streaming_dataset.uuid
+
+ session.commit()
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_export(self):
+ with db.session_scope() as session:
+            # This check flags any change to the preset template's variables
+ config = ExportConfiger(session).get_config()
+ variables = zip_workflow_variables(config)
+ self.assertEqual(len(list(variables)), 13)
+
+ with db.session_scope() as session:
+ self.assertEqual(len(ExportConfiger(session).user_variables), 8)
+
+ with db.session_scope() as session:
+ resp = ExportConfiger(session).auto_config_variables(global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid)}))
+ self.assertEqual(list(resp.global_configs['test_domain'].variables), [])
+
+ # test none_streaming dataset
+ with db.session_scope() as session:
+ global_configs = ExportConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid)}),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='dataset_path',
+ value='/data/dataset/test_input_dataset',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='file_format',
+ value='tfrecords',
+ typed_value=Value(string_value='tfrecords'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='batch_name',
+ value='0',
+ typed_value=Value(string_value='0'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='export_path',
+ value='/data/dataset/test_output_dataset/batch/0',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/0'),
+ value_type=Variable.ValueType.STRING),
+ ])
+
+ # test streaming dataset
+ with db.session_scope() as session:
+ global_configs = ExportConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid)}),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='dataset_path',
+ value='/data/dataset/test_input_dataset',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='file_format',
+ value='tfrecords',
+ typed_value=Value(string_value='tfrecords'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='export_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=Variable.ValueType.STRING),
+ ])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/hash_data_join_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/hash_data_join_configer.py
new file mode 100644
index 000000000..90d40684e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/hash_data_join_configer.py
@@ -0,0 +1,95 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+from typing import List, Optional
+
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class HashDataJoinConfiger(BaseConfiger):
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(
+ self._session).get_workflow_template(name='sys-preset-hash-data-join-analyzer')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+ # return all variables and frontend will filter them by tag
+        # return all variables; the frontend will filter them by tag
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+ dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ for pure_domain_name, job_config in global_configs.global_configs.items():
+ if pure_domain_name == my_domain_name:
+ role_variable = make_variable(name='role', typed_value='server')
+ else:
+ role_variable = make_variable(name='role', typed_value='client')
+ set_variable_value_to_job_config(job_config, role_variable)
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ output_dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ input_batch = self._get_data_batch(input_dataset, event_time)
+ output_batch = self._get_data_batch(output_dataset, event_time)
+
+ input_batch_path = get_batch_data_path(input_batch)
+ output_batch_path = output_batch.path
+
+ input_dataset_path_variable = make_variable(name='input_dataset_path', typed_value=input_dataset.path)
+ set_variable_value_to_job_config(job_config, input_dataset_path_variable)
+
+ input_batch_path_variable = make_variable(name='input_batch_path', typed_value=input_batch_path)
+ set_variable_value_to_job_config(job_config, input_batch_path_variable)
+
+ output_dataset_path_variable = make_variable(name='output_dataset_path', typed_value=output_dataset.path)
+ set_variable_value_to_job_config(job_config, output_dataset_path_variable)
+
+ output_batch_path_variable = make_variable(name='output_batch_path', typed_value=output_batch_path)
+ set_variable_value_to_job_config(job_config, output_batch_path_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=output_batch.batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ return global_configs
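The configers in this diff all follow the same call pattern: `auto_config_variables` fills participant-facing variables (such as `role`), and `config_local_variables` fills path and batch variables for the local party. The sketch below is illustrative only, not part of this change; the domain key and uuids are hypothetical placeholders, and the key of `global_configs` must be this party's pure domain name for the lookup inside the configer to succeed.

```python
# Illustrative sketch (not part of this change): driving HashDataJoinConfiger end to end.
# 'test_domain', input_uuid and output_uuid are hypothetical placeholders.
from fedlearner_webconsole.db import db
from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobConfig, DatasetJobGlobalConfigs
from fedlearner_webconsole.dataset.job_configer.hash_data_join_configer import HashDataJoinConfiger


def build_hash_join_configs(input_uuid: str, output_uuid: str) -> DatasetJobGlobalConfigs:
    configs = DatasetJobGlobalConfigs(
        global_configs={'test_domain': DatasetJobConfig(dataset_uuid=input_uuid)})
    with db.session_scope() as session:
        configer = HashDataJoinConfiger(session)
        # assign 'role' = server/client for every participant
        configs = configer.auto_config_variables(configs)
        # fill dataset/batch path variables for the local party
        configs = configer.config_local_variables(configs, result_dataset_uuid=output_uuid)
    return configs
```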
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/import_source_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/import_source_configer.py
new file mode 100644
index 000000000..78a3e9610
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/import_source_configer.py
@@ -0,0 +1,122 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+import os
+
+from typing import List, Optional
+
+from fedlearner_webconsole.dataset.models import (Dataset, DatasetFormat, DatasetKindV2, ImportType, StoreFormat,
+ DatasetType)
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config, filter_user_variables
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from fedlearner_webconsole.dataset.util import parse_event_time_to_daily_folder_name, \
+ parse_event_time_to_hourly_folder_name
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class ImportSourceConfiger(BaseConfiger):
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(self._session).get_workflow_template(name='sys-preset-converter-analyzer')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+        # return variables whose tag is RESOURCE_ALLOCATION or INPUT_PARAM
+ return filter_user_variables(list(zip_workflow_variables(self.get_config())))
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+ input_dataset: Dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ if input_dataset.store_format is None:
+ if input_dataset.dataset_kind == DatasetKindV2.SOURCE:
+                raise InvalidArgumentException(details=f'data_source {input_dataset.name} is too old and has no '
+                                               'store_format, please create a new data_source')
+ input_dataset.store_format = StoreFormat.TFRECORDS
+ file_format_variable = make_variable(name='file_format', typed_value=input_dataset.store_format.name.lower())
+ set_variable_value_to_job_config(job_config, file_format_variable)
+ data_type_variable = make_variable(name='data_type',
+ typed_value=DatasetFormat(input_dataset.dataset_format).name.lower())
+ set_variable_value_to_job_config(job_config, data_type_variable)
+
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+ input_dataset: Dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+ output_dataset: Dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ if output_dataset.dataset_type == DatasetType.PSI:
+ input_batch_path = input_dataset.path
+ else:
+ if output_dataset.parent_dataset_job.is_hourly_cron():
+ folder_name = parse_event_time_to_hourly_folder_name(event_time)
+ else:
+ folder_name = parse_event_time_to_daily_folder_name(event_time)
+ input_batch_path = os.path.join(input_dataset.path, folder_name)
+ output_data_batch = self._get_data_batch(output_dataset, event_time)
+ output_batch_path = output_data_batch.path
+ output_batch_name = output_data_batch.batch_name
+ output_dataset_path = output_dataset.path
+ thumbnail_path = DatasetDirectory(dataset_path=output_dataset_path).thumbnails_path(
+ batch_name=output_batch_name)
+ schema_checkers = list(output_dataset.get_meta_info().schema_checkers)
+
+        # Note: the following variables' names must match those in the template `sys-preset-converter-analyzer`
+ input_batch_path_variable = make_variable(name='input_batch_path', typed_value=input_batch_path)
+ set_variable_value_to_job_config(job_config, input_batch_path_variable)
+
+ dataset_path_variable = make_variable(name='dataset_path', typed_value=output_dataset_path)
+ set_variable_value_to_job_config(job_config, dataset_path_variable)
+
+ batch_path_variable = make_variable(name='batch_path', typed_value=output_batch_path)
+ set_variable_value_to_job_config(job_config, batch_path_variable)
+
+ thumbnail_path_variable = make_variable(name='thumbnail_path', typed_value=thumbnail_path)
+ set_variable_value_to_job_config(job_config, thumbnail_path_variable)
+
+ schema_checkers_variable = make_variable(name='checkers', typed_value=','.join(schema_checkers))
+ set_variable_value_to_job_config(job_config, schema_checkers_variable)
+
+ import_type_variable = make_variable(name='import_type', typed_value=output_dataset.import_type.name)
+ set_variable_value_to_job_config(job_config, import_type_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=output_batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ if output_dataset.import_type == ImportType.NO_COPY:
+ skip_analyzer_variable = make_variable(name='skip_analyzer', typed_value='true')
+ set_variable_value_to_job_config(job_config, skip_analyzer_variable)
+
+ return global_configs
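For streaming outputs, `config_local_variables` derives the input batch folder from the event time via the daily or hourly folder-name helpers. The stand-in below only mirrors the daily case and is an assumption based on the `batch/20220101` paths exercised by the tests in this diff; the real helpers live in `fedlearner_webconsole.dataset.util`, and the hourly format is not shown here.

```python
# Hypothetical stand-in for the daily folder naming (assumed '%Y%m%d'); illustrative only.
import os
from datetime import datetime


def daily_input_batch_path(dataset_path: str, event_time: datetime) -> str:
    # e.g. event_time 2022-01-01 -> '<dataset_path>/20220101'
    return os.path.join(dataset_path, event_time.strftime('%Y%m%d'))


assert daily_input_batch_path('/data/dataset/raw', datetime(2022, 1, 1)) == '/data/dataset/raw/20220101'
```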
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_ot_psi_data_join_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_ot_psi_data_join_configer.py
new file mode 100644
index 000000000..e1c8715a4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_ot_psi_data_join_configer.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+from typing import List, Optional
+
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class LightClientOtPsiDataJoinConfiger(BaseConfiger):
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(self._session).get_workflow_template(name='sys-preset-light-ot-data-join')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+ # return all variables and frontend will filter them by tag
+        # return all variables; the frontend will filter them by tag
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ output_dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ input_batch = self._get_data_batch(input_dataset, event_time)
+ output_batch = self._get_data_batch(output_dataset, event_time)
+
+ input_batch_path = get_batch_data_path(input_batch)
+ output_batch_path = output_batch.path
+
+ input_dataset_path_variable = make_variable(name='input_dataset_path', typed_value=input_dataset.path)
+ set_variable_value_to_job_config(job_config, input_dataset_path_variable)
+
+ input_batch_path_variable = make_variable(name='input_batch_path', typed_value=input_batch_path)
+ set_variable_value_to_job_config(job_config, input_batch_path_variable)
+
+ output_dataset_path_variable = make_variable(name='output_dataset_path', typed_value=output_dataset.path)
+ set_variable_value_to_job_config(job_config, output_dataset_path_variable)
+
+ output_batch_path_variable = make_variable(name='output_batch_path', typed_value=output_batch_path)
+ set_variable_value_to_job_config(job_config, output_batch_path_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=output_batch.batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ return global_configs
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_ot_psi_data_join_configer_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_ot_psi_data_join_configer_test.py
new file mode 100644
index 000000000..24c622f54
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_ot_psi_data_join_configer_test.py
@@ -0,0 +1,201 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+import os
+import unittest
+from unittest.mock import patch
+
+from google.protobuf.struct_pb2 import Value
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobConfig, DatasetJobGlobalConfigs
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, DatasetKindV2, DatasetType
+from fedlearner_webconsole.dataset.job_configer.light_client_ot_psi_data_join_configer import \
+ LightClientOtPsiDataJoinConfiger
+
+
+class LightClientOtPsiDataJoinConfigerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+
+ test_project = Project(name='test_project')
+ session.add(test_project)
+ session.flush([test_project])
+
+ test_input_dataset = Dataset(id=2,
+ name='test_input_dataset',
+ uuid=resource_uuid(),
+ is_published=False,
+ project_id=test_project.id,
+ path='/data/dataset/test_input_dataset',
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(test_input_dataset)
+ session.flush([test_input_dataset])
+
+ test_input_data_batch = DataBatch(dataset_id=test_input_dataset.id,
+ path=os.path.join(test_input_dataset.path, 'batch/test_input_data_batch'))
+ session.add(test_input_data_batch)
+
+ test_output_dataset = Dataset(id=3,
+ name='test_output_dataset',
+ uuid=resource_uuid(),
+ is_published=True,
+ project_id=test_project.id,
+ path='/data/dataset/test_output_dataset',
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(test_output_dataset)
+ session.flush([test_output_dataset])
+
+ test_output_data_batch = DataBatch(dataset_id=test_output_dataset.id,
+ path=os.path.join(test_output_dataset.path,
+ 'batch/test_output_data_batch'))
+ session.add(test_output_data_batch)
+
+ test_input_streaming_dataset = Dataset(id=4,
+ name='test_input_dataset',
+ uuid=resource_uuid(),
+ is_published=False,
+ project_id=test_project.id,
+ path='/data/dataset/test_input_dataset',
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(test_input_streaming_dataset)
+ session.flush()
+
+ test_input_streaming_data_batch = DataBatch(dataset_id=test_input_streaming_dataset.id,
+ event_time=datetime(2022, 1, 1),
+ path=os.path.join(test_input_dataset.path, 'batch/20220101'))
+ session.add(test_input_streaming_data_batch)
+
+ test_output_streaming_dataset = Dataset(id=5,
+ name='test_output_dataset',
+ uuid=resource_uuid(),
+ is_published=True,
+ project_id=test_project.id,
+ path='/data/dataset/test_output_dataset',
+ dataset_type=DatasetType.STREAMING,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(test_output_streaming_dataset)
+ session.flush()
+
+ test_output_streaming_data_batch = DataBatch(dataset_id=test_output_streaming_dataset.id,
+ event_time=datetime(2022, 1, 1),
+ path=os.path.join(test_output_dataset.path, 'batch/20220101'))
+ session.add(test_output_streaming_data_batch)
+
+ self._input_dataset_uuid = test_input_dataset.uuid
+ self._output_dataset_uuid = test_output_dataset.uuid
+ self._input_streaming_dataset_uuid = test_input_streaming_dataset.uuid
+ self._output_streaming_dataset_uuid = test_output_streaming_dataset.uuid
+
+ session.commit()
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ def test_light_client_ot_psi_data_join(self):
+
+ with db.session_scope() as session:
+ global_configs = LightClientOtPsiDataJoinConfiger(session).auto_config_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid=self._input_dataset_uuid)}))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [])
+
+ with db.session_scope() as session:
+ global_configs = LightClientOtPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain':
+ DatasetJobConfig(dataset_uuid=self._input_dataset_uuid,
+ variables=[
+ Variable(name='input_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='input_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ typed_value=Value(string_value=''),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_batch_path',
+ typed_value=Value(string_value=''),
+ value_type=Variable.ValueType.STRING),
+ ])
+ }),
+ result_dataset_uuid=self._output_dataset_uuid)
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='input_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/test_input_data_batch'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=Variable.ValueType.STRING),
+ Variable(
+ name='output_batch_path',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/test_output_data_batch'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='test_output_data_batch',
+ typed_value=Value(string_value='test_output_data_batch'),
+ value_type=Variable.ValueType.STRING),
+ ])
+ # test with event_time
+ with db.session_scope() as session:
+ global_configs = LightClientOtPsiDataJoinConfiger(session).config_local_variables(
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={
+ 'test_domain': DatasetJobConfig(dataset_uuid=self._input_streaming_dataset_uuid),
+ 'test_domain_2': DatasetJobConfig(dataset_uuid='u12345')
+ }),
+ result_dataset_uuid=self._output_streaming_dataset_uuid,
+ event_time=datetime(2022, 1, 1))
+ self.assertEqual(list(global_configs.global_configs['test_domain'].variables), [
+ Variable(name='input_dataset_path',
+ value='/data/dataset/test_input_dataset',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='input_batch_path',
+ value='/data/dataset/test_input_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_input_dataset/batch/20220101'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_dataset_path',
+ value='/data/dataset/test_output_dataset',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_batch_path',
+ value='/data/dataset/test_output_dataset/batch/20220101',
+ typed_value=Value(string_value='/data/dataset/test_output_dataset/batch/20220101'),
+ value_type=Variable.ValueType.STRING),
+ Variable(name='output_batch_name',
+ value='20220101',
+ typed_value=Value(string_value='20220101'),
+ value_type=Variable.ValueType.STRING),
+ ])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_rsa_psi_data_join_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_rsa_psi_data_join_configer.py
new file mode 100644
index 000000000..75f623505
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/light_client_rsa_psi_data_join_configer.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+from typing import List, Optional
+
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class LightClientRsaPsiDataJoinConfiger(BaseConfiger):
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(self._session).get_workflow_template(name='sys-preset-light-psi-data-join')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+ # return all variables and frontend will filter them by tag
+        # return all variables; the frontend will filter them by tag
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ output_dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ input_batch = self._get_data_batch(input_dataset, event_time)
+ output_batch = self._get_data_batch(output_dataset, event_time)
+
+ input_batch_path = get_batch_data_path(input_batch)
+ output_batch_path = output_batch.path
+
+ input_dataset_path_variable = make_variable(name='input_dataset_path', typed_value=input_dataset.path)
+ set_variable_value_to_job_config(job_config, input_dataset_path_variable)
+
+ input_batch_path_variable = make_variable(name='input_batch_path', typed_value=input_batch_path)
+ set_variable_value_to_job_config(job_config, input_batch_path_variable)
+
+ output_dataset_path_variable = make_variable(name='output_dataset_path', typed_value=output_dataset.path)
+ set_variable_value_to_job_config(job_config, output_dataset_path_variable)
+
+ output_batch_path_variable = make_variable(name='output_batch_path', typed_value=output_batch_path)
+ set_variable_value_to_job_config(job_config, output_batch_path_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=output_batch.batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ return global_configs
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/ot_psi_data_join_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/ot_psi_data_join_configer.py
new file mode 100644
index 000000000..eef1dfdc4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/ot_psi_data_join_configer.py
@@ -0,0 +1,94 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+from typing import List, Optional
+from google.protobuf.struct_pb2 import Value
+
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class OtPsiDataJoinConfiger(BaseConfiger):
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(self._session).get_workflow_template(name='sys-preset-ot-psi-analyzer')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+        # return all variables; the frontend will filter them by tag
+ return list(zip_workflow_variables(self.get_config()))
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+ dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ for pure_domain_name, job_config in global_configs.global_configs.items():
+ role_variable = make_variable(name='role', typed_value='server')
+ if pure_domain_name != my_domain_name:
+ role_variable.typed_value.CopyFrom(Value(string_value='client'))
+ set_variable_value_to_job_config(job_config, role_variable)
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ output_dataset = self._session.query(Dataset).filter(Dataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ input_batch = self._get_data_batch(input_dataset, event_time)
+ output_batch = self._get_data_batch(output_dataset, event_time)
+
+ input_batch_path = get_batch_data_path(input_batch)
+ output_batch_path = output_batch.path
+
+ input_dataset_path_variable = make_variable(name='input_dataset_path', typed_value=input_dataset.path)
+ set_variable_value_to_job_config(job_config, input_dataset_path_variable)
+
+ input_batch_path_variable = make_variable(name='input_batch_path', typed_value=input_batch_path)
+ set_variable_value_to_job_config(job_config, input_batch_path_variable)
+
+ output_dataset_path_variable = make_variable(name='output_dataset_path', typed_value=output_dataset.path)
+ set_variable_value_to_job_config(job_config, output_dataset_path_variable)
+
+ output_batch_path_variable = make_variable(name='output_batch_path', typed_value=output_batch_path)
+ set_variable_value_to_job_config(job_config, output_batch_path_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=output_batch.batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ return global_configs
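`auto_config_variables` here builds the role variable with a 'server' default and then overwrites its `typed_value` in place for the other participants. The snippet below isolates that protobuf overwrite pattern using the `Variable` message seen throughout this diff; it is a minimal sketch, not code from this change.

```python
# Minimal sketch of the in-place overwrite used above.
from google.protobuf.struct_pb2 import Value
from fedlearner_webconsole.proto.common_pb2 import Variable

role = Variable(name='role',
                typed_value=Value(string_value='server'),
                value_type=Variable.ValueType.STRING)
# a non-coordinator participant gets the 'client' role instead
role.typed_value.CopyFrom(Value(string_value='client'))
assert role.typed_value.string_value == 'client'
```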
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/rsa_psi_data_join_configer.py b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/rsa_psi_data_join_configer.py
new file mode 100644
index 000000000..c694b780e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/job_configer/rsa_psi_data_join_configer.py
@@ -0,0 +1,134 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+from typing import Tuple, List, Optional
+import rsa
+
+from fedlearner_webconsole.dataset.models import Dataset, ProcessedDataset
+from fedlearner_webconsole.dataset.job_configer.base_configer import BaseConfiger, get_my_pure_domain_name, \
+ set_variable_value_to_job_config
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobGlobalConfigs
+from fedlearner_webconsole.utils.workflow import zip_workflow_variables
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateService
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto import common_pb2, workflow_definition_pb2
+
+
+class RsaPsiDataJoinConfiger(BaseConfiger):
+ USER_VARIABLES_NAME_SET = {
+ 'fedlearner_image_version',
+ 'num_partitions',
+ 'raw_worker_cpu',
+ 'raw_worker_mem',
+ 'batch_size',
+ 'psi_worker_cpu',
+ 'psi_worker_mem',
+ 'master_cpu',
+ 'master_mem',
+ }
+
+ @staticmethod
+ def _generate_rsa_key_pair(length: int = 1024) -> Tuple[str, str]:
+ """generate rsa key pair in pem format
+
+ Args:
+            length (int, optional): bit length of the generated key. Defaults to 1024.
+
+ Returns:
+ Tuple[str, str]: PublicKey PEM, PrivateKey PEM
+ """
+        # Note: we generate RSA keys in the current thread, which slows down the API response.
+        # DO NOT use the poolsize argument!
+ # TODO(wangsen.0914): optimize this using async job_scheduler to call this func.
+ public_key, private_key = rsa.newkeys(length)
+ return public_key.save_pkcs1(format='PEM').decode(), private_key.save_pkcs1(format='PEM').decode()
+
+ def get_config(self) -> workflow_definition_pb2.WorkflowDefinition:
+ template = WorkflowTemplateService(
+ self._session).get_workflow_template(name='sys-preset-psi-data-join-analyzer')
+ return template.get_config()
+
+ @property
+ def user_variables(self) -> List[common_pb2.Variable]:
+ real_user_variables = []
+ for variable in zip_workflow_variables(self.get_config()):
+ if variable.name in self.USER_VARIABLES_NAME_SET:
+ real_user_variables.append(variable)
+
+ return real_user_variables
+
+ def auto_config_variables(self, global_configs: DatasetJobGlobalConfigs) -> DatasetJobGlobalConfigs:
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+ dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+
+ public_key_pem, private_key_pem = self._generate_rsa_key_pair()
+ for pure_domain_name, job_config in global_configs.global_configs.items():
+ if pure_domain_name == my_domain_name:
+ role_variable = make_variable(name='role', typed_value='Leader')
+ set_variable_value_to_job_config(job_config, role_variable)
+ rsa_key_pem_variable = make_variable(name='rsa_key_pem', typed_value=private_key_pem)
+ set_variable_value_to_job_config(job_config, rsa_key_pem_variable)
+
+ else:
+ role_variable = make_variable(name='role', typed_value='Follower')
+ set_variable_value_to_job_config(job_config, role_variable)
+ rsa_key_pem_variable = make_variable(name='rsa_key_pem', typed_value=public_key_pem)
+ set_variable_value_to_job_config(job_config, rsa_key_pem_variable)
+
+ rsa_key_path_variable = make_variable(name='rsa_key_path', typed_value='')
+ set_variable_value_to_job_config(job_config, rsa_key_path_variable)
+ return global_configs
+
+ def config_local_variables(self,
+ global_configs: DatasetJobGlobalConfigs,
+ result_dataset_uuid: str,
+ event_time: Optional[datetime] = None) -> DatasetJobGlobalConfigs:
+
+ my_domain_name = get_my_pure_domain_name()
+ job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {job_config.dataset_uuid}')
+ output_dataset = self._session.query(ProcessedDataset).filter(
+ ProcessedDataset.uuid == result_dataset_uuid).first()
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {result_dataset_uuid}')
+
+ input_batch = self._get_data_batch(input_dataset, event_time)
+ output_batch = self._get_data_batch(output_dataset, event_time)
+
+ input_batch_path = get_batch_data_path(input_batch)
+
+ dataset_variable = make_variable(name='dataset', typed_value=input_batch_path)
+ set_variable_value_to_job_config(job_config, dataset_variable)
+
+ output_dataset_path_variable = make_variable(name='output_dataset_path', typed_value=output_dataset.path)
+ set_variable_value_to_job_config(job_config, output_dataset_path_variable)
+
+ output_batch_path_variable = make_variable(name='output_batch_path', typed_value=output_batch.path)
+ set_variable_value_to_job_config(job_config, output_batch_path_variable)
+
+ output_batch_name_variable = make_variable(name='output_batch_name', typed_value=output_batch.batch_name)
+ set_variable_value_to_job_config(job_config, output_batch_name_variable)
+
+ return global_configs
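As the comments in `_generate_rsa_key_pair` note, RSA key generation is CPU-bound and currently runs in the request thread. The short sketch below mirrors that helper with the same `rsa` library calls, just to show the shape of the PEM strings that end up in the `rsa_key_pem` variable.

```python
# Usage sketch mirroring _generate_rsa_key_pair above (python-rsa library).
import rsa

public_key, private_key = rsa.newkeys(1024)  # CPU-bound; avoid calling this per request
public_pem = public_key.save_pkcs1(format='PEM').decode()
private_pem = private_key.save_pkcs1(format='PEM').decode()
assert public_pem.startswith('-----BEGIN RSA PUBLIC KEY-----')
assert private_pem.startswith('-----BEGIN RSA PRIVATE KEY-----')
```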
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/local_controllers.py b/web_console_v2/api/fedlearner_webconsole/dataset/local_controllers.py
new file mode 100644
index 000000000..6e755db49
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/local_controllers.py
@@ -0,0 +1,251 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+import logging
+from typing import Optional
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.dataset.models import DataBatch, DatasetJob, DatasetJobStage, DatasetJobState, \
+ DatasetType
+from fedlearner_webconsole.dataset.services import BatchService, DatasetJobStageService, DatasetService
+from fedlearner_webconsole.proto.dataset_pb2 import BatchParameter, DatasetJobGlobalConfigs, CronType
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.workflow.workflow_controller import start_workflow_locally, stop_workflow_locally
+
+
+class DatasetJobStageLocalController(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+ self._dataset_job_stage_service = DatasetJobStageService(session)
+
+ def start(self, dataset_job_stage: DatasetJobStage):
+ """start dataset job stage task locally
+
+ 1. start related workflow locally
+ 2. set dataset job stage's state to RUNNING
+ """
+ start_workflow_locally(self._session, dataset_job_stage.workflow)
+ self._dataset_job_stage_service.start_dataset_job_stage(dataset_job_stage=dataset_job_stage)
+ logging.info('[dataset_job_stage_local_controller]: start successfully')
+
+ def stop(self, dataset_job_stage: DatasetJobStage):
+ """stop dataset job stage task locally
+
+ 1. stop related workflow locally
+ 2. set dataset job stage's state to STOPPED
+ """
+ if dataset_job_stage.workflow is not None:
+ stop_workflow_locally(self._session, dataset_job_stage.workflow)
+ else:
+ logging.info(f'workflow not found, just skip! workflow id: {dataset_job_stage.workflow_id}')
+ self._dataset_job_stage_service.finish_dataset_job_stage(dataset_job_stage=dataset_job_stage,
+ finish_state=DatasetJobState.STOPPED)
+ logging.info('[dataset_job_stage_local_controller]: stop successfully')
+
+ # TODO(liuhehan): delete in the near future after we use as_coordinator func
+ def create_data_batch_and_job_stage(self,
+ dataset_job_id: int,
+ event_time: Optional[datetime] = None,
+ uuid: Optional[str] = None,
+ name: Optional[str] = None) -> Optional[DatasetJobStage]:
+ """create data_batch and job_stage locally
+
+ UseCase 1: create new data_batch and new job_stage with given uuid and name:
+            only called in the participant role; uuid and name are given by the coordinator.
+            Creates both the data_batch and the job_stage.
+
+ Parameters:
+ dataset_job_id(int): dataset_job id
+ event_time(datetime): optional; only works in STREAMING dataset_job,
+ event_time of current data_batch and job_stage
+ uuid(str): uuid of dataset_job_stage
+ name(str): name of dataset_job_stage
+
+ Returns:
+            dataset_job_stage (DatasetJobStage): the dataset_job_stage created by this func
+
+ UseCase 2: create new data_batch and new job_stage for PSI/STREAMING dataset_job:
+            only called in the coordinator role.
+            Creates both the data_batch and the job_stage.
+
+ Parameters:
+ dataset_job_id(int): dataset_job id
+ event_time(datetime): optional; only works in STREAMING dataset_job,
+ event_time of current data_batch and job_stage
+
+ Returns:
+            dataset_job_stage (DatasetJobStage): the dataset_job_stage created by this func
+
+ UseCase 3: rerun data_batch:
+            called to create a new job_stage when the data_batch has failed.
+            Only creates the dataset_job_stage if the target data_batch is found.
+
+ Parameters:
+ dataset_job_id(int): dataset_job id
+ event_time(datetime): optional; only works in STREAMING dataset_job,
+ event_time of current data_batch and job_stage
+
+ Returns:
+            dataset_job_stage (DatasetJobStage): the dataset_job_stage created by this func
+ """
+ dataset_job: DatasetJob = self._session.query(DatasetJob).get(dataset_job_id)
+ if dataset_job.output_dataset.dataset_type == DatasetType.STREAMING:
+ data_batch = self._session.query(DataBatch).filter(
+ DataBatch.dataset_id == dataset_job.output_dataset_id).filter(
+ DataBatch.event_time == event_time).first()
+ else:
+ data_batch = self._session.query(DataBatch).filter(
+ DataBatch.dataset_id == dataset_job.output_dataset_id).first()
+        # create the data_batch if it does not exist:
+ if data_batch is None:
+ batch_parameter = BatchParameter(dataset_id=dataset_job.output_dataset_id)
+ if event_time:
+ batch_parameter.event_time = to_timestamp(event_time)
+ data_batch = BatchService(self._session).create_batch(batch_parameter=batch_parameter)
+ self._session.flush()
+ dataset_job_stage = None
+ if uuid:
+ dataset_job_stage = self._session.query(DatasetJobStage).filter(DatasetJobStage.uuid == uuid).first()
+        # for idempotency, skip creation if the dataset_job_stage already exists:
+ if dataset_job_stage is None:
+ dataset_job_stage = self._dataset_job_stage_service.create_dataset_job_stage(
+ project_id=dataset_job.project_id,
+ dataset_job_id=dataset_job_id,
+ output_data_batch_id=data_batch.id,
+ uuid=uuid,
+ name=name)
+ return dataset_job_stage
+
+ def create_data_batch_and_job_stage_as_coordinator(
+ self,
+ dataset_job_id: int,
+ global_configs: DatasetJobGlobalConfigs,
+ event_time: Optional[datetime] = None) -> Optional[DatasetJobStage]:
+ """create data_batch and job_stage locally as coordinator
+
+ UseCase 1: create new data_batch and new job_stage for PSI/STREAMING dataset:
+            only called in the coordinator role.
+            Creates both the data_batch and the job_stage.
+
+ Parameters:
+ dataset_job_id(int): dataset_job id
+ global_configs(global_configs): configs of all participants for this dataset_job_stage
+ event_time(datetime): optional; only works in STREAMING dataset,
+ event_time of current data_batch and job_stage
+
+ Returns:
+            dataset_job_stage (DatasetJobStage): the dataset_job_stage created by this func
+
+ UseCase 2: rerun data_batch:
+            called to create a new job_stage when the data_batch has failed;
+            only creates the dataset_job_stage if the target data_batch is found.
+
+ Parameters:
+ dataset_job_id(int): dataset_job id
+ global_configs(global_configs): configs of all participants for this dataset_job_stage
+ event_time(datetime): optional; only works in STREAMING dataset,
+ event_time of current data_batch and job_stage
+
+ Returns:
+            dataset_job_stage (DatasetJobStage): the dataset_job_stage created by this func
+ """
+ dataset_job: DatasetJob = self._session.query(DatasetJob).get(dataset_job_id)
+ data_batch = DatasetService(session=self._session).get_data_batch(dataset=dataset_job.output_dataset,
+ event_time=event_time)
+        # create the data_batch if it does not exist:
+ if data_batch is None:
+ batch_parameter = BatchParameter(dataset_id=dataset_job.output_dataset_id)
+ if event_time:
+ batch_parameter.event_time = to_timestamp(event_time)
+ if dataset_job.is_daily_cron():
+ batch_parameter.cron_type = CronType.DAILY
+ elif dataset_job.is_hourly_cron():
+ batch_parameter.cron_type = CronType.HOURLY
+ data_batch = BatchService(self._session).create_batch(batch_parameter=batch_parameter)
+ self._session.flush()
+ dataset_job_stage = self._dataset_job_stage_service.create_dataset_job_stage_as_coordinator(
+ project_id=dataset_job.project_id,
+ dataset_job_id=dataset_job_id,
+ output_data_batch_id=data_batch.id,
+ global_configs=global_configs)
+ return dataset_job_stage
+
+ def create_data_batch_and_job_stage_as_participant(self,
+ dataset_job_id: int,
+ coordinator_id: int,
+ uuid: str,
+ name: str,
+ event_time: Optional[datetime] = None
+ ) -> Optional[DatasetJobStage]:
+ """create data_batch and job_stage locally as participant
+
+ UseCase 1: create new data_batch and new job_stage with given uuid and name:
+            only called in the participant role; uuid and name are given by the coordinator.
+            Creates both the data_batch and the job_stage.
+
+ Parameters:
+ dataset_job_id(int): dataset_job id
+ coordinator_id(int): id of coordinator
+ uuid(str): uuid of dataset_job_stage
+ name(str): name of dataset_job_stage
+ event_time(datetime): optional; only works in STREAMING dataset,
+ event_time of current data_batch and job_stage
+
+ Returns:
+            dataset_job_stage (DatasetJobStage): the dataset_job_stage created by this func
+
+ UseCase 2: rerun data_batch:
+            only called in the participant role; uuid and name are given by the coordinator.
+            Aims to create a new job_stage when the data_batch has failed;
+            only creates the dataset_job_stage if the target data_batch is found.
+
+ Parameters:
+ dataset_job_id(int): dataset_job id
+ coordinator_id(int): id of coordinator
+ uuid(str): uuid of dataset_job_stage
+ name(str): name of dataset_job_stage
+ event_time(datetime): optional; only works in STREAMING dataset,
+ event_time of current data_batch and job_stage
+
+ Returns:
+            dataset_job_stage (DatasetJobStage): the dataset_job_stage created by this func
+ """
+ dataset_job: DatasetJob = self._session.query(DatasetJob).get(dataset_job_id)
+ data_batch = DatasetService(session=self._session).get_data_batch(dataset=dataset_job.output_dataset,
+ event_time=event_time)
+        # create the data_batch if it does not exist:
+ if data_batch is None:
+ batch_parameter = BatchParameter(dataset_id=dataset_job.output_dataset_id)
+ if event_time:
+ batch_parameter.event_time = to_timestamp(event_time)
+ if dataset_job.is_daily_cron():
+ batch_parameter.cron_type = CronType.DAILY
+ elif dataset_job.is_hourly_cron():
+ batch_parameter.cron_type = CronType.HOURLY
+ data_batch = BatchService(self._session).create_batch(batch_parameter=batch_parameter)
+ self._session.flush()
+ dataset_job_stage = self._session.query(DatasetJobStage).filter(DatasetJobStage.uuid == uuid).first()
+        # for idempotency, skip creation if the dataset_job_stage already exists:
+ if dataset_job_stage is None:
+ dataset_job_stage = self._dataset_job_stage_service.create_dataset_job_stage_as_participant(
+ project_id=dataset_job.project_id,
+ dataset_job_id=dataset_job_id,
+ output_data_batch_id=data_batch.id,
+ uuid=uuid,
+ name=name,
+ coordinator_id=coordinator_id)
+ return dataset_job_stage
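The controller above is intended to be driven inside a single session scope: create (or reuse) the output batch and its job stage, then start the stage once its workflow exists. The sketch below shows the coordinator path; the dataset_job id, domain key, and uuid are hypothetical placeholders, not values from this change.

```python
# Illustrative sketch of the coordinator flow; identifiers are placeholders.
from datetime import datetime

from fedlearner_webconsole.db import db
from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobConfig, DatasetJobGlobalConfigs
from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController

with db.session_scope() as session:
    controller = DatasetJobStageLocalController(session=session)
    stage = controller.create_data_batch_and_job_stage_as_coordinator(
        dataset_job_id=1,
        global_configs=DatasetJobGlobalConfigs(
            global_configs={'test_domain': DatasetJobConfig(dataset_uuid='u1234')}),
        event_time=datetime(2022, 1, 1))  # only meaningful for STREAMING outputs
    # start() assumes the stage's workflow has already been created elsewhere
    controller.start(stage)
    session.commit()
```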
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/local_controllers_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/local_controllers_test.py
new file mode 100644
index 000000000..b88782542
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/local_controllers_test.py
@@ -0,0 +1,526 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime, timedelta, timezone
+import unittest
+from unittest.mock import patch, MagicMock
+
+from testing.common import NoWebServerTestCase
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, DatasetJob, DatasetJobKind, DatasetJobStage, \
+ DatasetJobState, DatasetKindV2, DatasetType
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+
+
+class DatasetJobStageLocalControllerTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.dataset.local_controllers.start_workflow_locally')
+ @patch('fedlearner_webconsole.dataset.local_controllers.DatasetJobStageService.start_dataset_job_stage')
+ def test_start(self, mock_start_dataset_job_stage: MagicMock, mock_start_workflow_locally: MagicMock):
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStage(uuid=resource_uuid(),
+ project_id=1,
+ workflow_id=1,
+ dataset_job_id=1,
+ data_batch_id=1)
+ DatasetJobStageLocalController(session=session).start(dataset_job_stage=dataset_job_stage)
+ mock_start_workflow_locally.assert_called_once()
+ mock_start_dataset_job_stage.assert_called_once_with(dataset_job_stage=dataset_job_stage)
+
+ @patch('fedlearner_webconsole.dataset.local_controllers.stop_workflow_locally')
+ @patch('fedlearner_webconsole.dataset.local_controllers.DatasetJobStageService.finish_dataset_job_stage')
+ def test_stop(self, mock_finish_dataset_job_stage: MagicMock, mock_stop_workflow_locally: MagicMock):
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStage(uuid=resource_uuid(),
+ project_id=1,
+ workflow_id=1,
+ dataset_job_id=1,
+ data_batch_id=1)
+ session.add(dataset_job_stage)
+
+ dataset_job_stage_local_controller = DatasetJobStageLocalController(session=session)
+            # test without workflow
+ dataset_job_stage_local_controller.stop(dataset_job_stage=dataset_job_stage)
+ mock_stop_workflow_locally.assert_not_called()
+ mock_finish_dataset_job_stage.assert_called_once_with(dataset_job_stage=dataset_job_stage,
+ finish_state=DatasetJobState.STOPPED)
+
+ # test has workflow
+ mock_stop_workflow_locally.reset_mock()
+ mock_finish_dataset_job_stage.reset_mock()
+ workflow = Workflow(id=1)
+ session.add(workflow)
+ session.flush()
+ dataset_job_stage_local_controller.stop(dataset_job_stage=dataset_job_stage)
+ mock_stop_workflow_locally.assert_called_once()
+ mock_finish_dataset_job_stage.assert_called_once_with(dataset_job_stage=dataset_job_stage,
+ finish_state=DatasetJobState.STOPPED)
+
+ def test_create_data_batch_and_job_stage(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING)
+ dataset_job.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job)
+ output_dataset = Dataset(id=2,
+ uuid='output_dataset uuid',
+ name='output_dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(output_dataset)
+ session.commit()
+
+ # test PSI
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage(
+ dataset_job_id=1)
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertIsNone(dataset_job_stage.event_time)
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/0')
+
+ # test PSI has batch
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1, name='test_data_batch', dataset_id=2, path='/data/test/batch/0')
+ session.add(data_batch)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage(
+ dataset_job_id=1)
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertIsNone(dataset_job_stage.event_time)
+ self.assertEqual(dataset_job_stage.data_batch.name, 'test_data_batch')
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/test/batch/0')
+
+ # test PSI has batch and stage
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1, name='test_data_batch', dataset_id=2, path='/data/test/batch/0')
+ session.add(data_batch)
+ dataset_job_stage = DatasetJobStage(id=100,
+ name='test_dataset_job',
+ uuid='test_dataset_job uuid',
+ project_id=1,
+ workflow_id=0,
+ dataset_job_id=1,
+ data_batch_id=1)
+ session.add(dataset_job_stage)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage(
+ dataset_job_id=1, uuid='test_dataset_job uuid', name='test_dataset_job')
+ self.assertEqual(dataset_job_stage.id, 100)
+
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(2)
+ dataset.dataset_type = DatasetType.STREAMING
+ session.commit()
+
+ # test STREAMING
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage(
+ dataset_job_id=1, event_time=datetime(2022, 1, 1))
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1).replace(tzinfo=timezone.utc))
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/20220101')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1))
+
+ # test STREAMING has batch
+ with db.session_scope() as session:
+ data_batch_1 = DataBatch(id=1,
+ name='test_data_batch 1',
+ dataset_id=2,
+ path='/data/test/batch/20220101',
+ event_time=datetime(2022, 1, 1))
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(id=2,
+ name='test_data_batch 2',
+ dataset_id=2,
+ path='/data/test/batch/20220102',
+ event_time=datetime(2022, 1, 2))
+ session.add(data_batch_2)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage(
+ dataset_job_id=1, event_time=datetime(2022, 1, 1))
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1))
+ self.assertEqual(dataset_job_stage.data_batch.name, 'test_data_batch 1')
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/test/batch/20220101')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1))
+
+ # test STREAMING has batch and stage
+ with db.session_scope() as session:
+ data_batch_1 = DataBatch(id=1,
+ name='test_data_batch 1',
+ dataset_id=2,
+ path='/data/test/batch/20220101',
+ event_time=datetime(2022, 1, 1))
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(id=2,
+ name='test_data_batch 2',
+ dataset_id=2,
+ path='/data/test/batch/20220102',
+ event_time=datetime(2022, 1, 2))
+ session.add(data_batch_2)
+ dataset_job_stage = DatasetJobStage(id=100,
+ name='test_dataset_job',
+ uuid='test_dataset_job uuid',
+ project_id=1,
+ workflow_id=0,
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 2))
+ session.add(dataset_job_stage)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage(
+ dataset_job_id=1,
+ event_time=datetime(2022, 1, 1),
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job')
+ self.assertEqual(dataset_job_stage.id, 100)
+
+ def test_create_data_batch_and_job_stage_as_coordinator(self):
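+        # Same scenarios as above, but through the coordinator entry point: the created stage
+        # additionally stores global_configs and reports is_coordinator() as True.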
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING)
+ session.add(dataset_job)
+ output_dataset = Dataset(id=2,
+ uuid='output_dataset uuid',
+ name='output_dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(output_dataset)
+ session.commit()
+
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()})
+ # test PSI
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_coordinator(dataset_job_id=1,
+ global_configs=global_configs)
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertIsNone(dataset_job_stage.event_time)
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/0')
+ self.assertEqual(dataset_job_stage.get_global_configs(), global_configs)
+ self.assertTrue(dataset_job_stage.is_coordinator())
+
+ # test PSI has batch
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1, name='test_data_batch', dataset_id=2, path='/data/test/batch/0')
+ session.add(data_batch)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_coordinator(dataset_job_id=1,
+ global_configs=global_configs)
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertIsNone(dataset_job_stage.event_time)
+ self.assertEqual(dataset_job_stage.data_batch.name, 'test_data_batch')
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/test/batch/0')
+ self.assertEqual(dataset_job_stage.get_global_configs(), global_configs)
+ self.assertTrue(dataset_job_stage.is_coordinator())
+
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(2)
+ dataset.dataset_type = DatasetType.STREAMING
+ session.commit()
+
+ # test STREAMING
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_coordinator(dataset_job_id=1,
+ global_configs=global_configs,
+ event_time=datetime(2022, 1, 1))
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1).replace(tzinfo=timezone.utc))
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/20220101')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1))
+ self.assertEqual(dataset_job_stage.get_global_configs(), global_configs)
+ self.assertTrue(dataset_job_stage.is_coordinator())
+
+ # test STREAMING has batch
+ with db.session_scope() as session:
+ data_batch_1 = DataBatch(id=1,
+ name='test_data_batch 1',
+ dataset_id=2,
+ path='/data/test/batch/20220101',
+ event_time=datetime(2022, 1, 1))
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(id=2,
+ name='test_data_batch 2',
+ dataset_id=2,
+ path='/data/test/batch/20220102',
+ event_time=datetime(2022, 1, 2))
+ session.add(data_batch_2)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_coordinator(dataset_job_id=1,
+ global_configs=global_configs,
+ event_time=datetime(2022, 1, 1))
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1))
+ self.assertEqual(dataset_job_stage.data_batch.name, 'test_data_batch 1')
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/test/batch/20220101')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1))
+ self.assertEqual(dataset_job_stage.get_global_configs(), global_configs)
+ self.assertTrue(dataset_job_stage.is_coordinator())
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.time_range = timedelta(hours=1)
+ session.commit()
+
+ # test STREAMING in hourly level
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_coordinator(dataset_job_id=1,
+ global_configs=global_configs,
+ event_time=datetime(2022, 1, 1, 8))
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1, 8).replace(tzinfo=timezone.utc))
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/20220101-08')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1, 8))
+ self.assertEqual(dataset_job_stage.get_global_configs(), global_configs)
+ self.assertTrue(dataset_job_stage.is_coordinator())
+
+ def test_create_data_batch_and_job_stage_as_participant(self):
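+        # Participant variant: the created stage carries no global_configs and records the
+        # coordinator_id passed in by the caller.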
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING)
+ session.add(dataset_job)
+ output_dataset = Dataset(id=2,
+ uuid='output_dataset uuid',
+ name='output_dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(output_dataset)
+ session.commit()
+
+ # test PSI
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_participant(dataset_job_id=1,
+ coordinator_id=1,
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job')
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertIsNone(dataset_job_stage.event_time)
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/0')
+ self.assertIsNone(dataset_job_stage.global_configs)
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+
+ # test PSI has batch
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1, name='test_data_batch', dataset_id=2, path='/data/test/batch/0')
+ session.add(data_batch)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_participant(dataset_job_id=1,
+ coordinator_id=1,
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job')
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertIsNone(dataset_job_stage.event_time)
+ self.assertEqual(dataset_job_stage.data_batch.name, 'test_data_batch')
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/test/batch/0')
+ self.assertIsNone(dataset_job_stage.global_configs)
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+
+ # test PSI has batch and stage
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1, name='test_data_batch', dataset_id=2, path='/data/test/batch/0')
+ session.add(data_batch)
+ dataset_job_stage = DatasetJobStage(id=100,
+ name='test_dataset_job',
+ uuid='test_dataset_job uuid',
+ project_id=1,
+ workflow_id=0,
+ dataset_job_id=1,
+ data_batch_id=1,
+ coordinator_id=1)
+ session.add(dataset_job_stage)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_participant(dataset_job_id=1,
+ coordinator_id=1,
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job')
+ self.assertEqual(dataset_job_stage.id, 100)
+
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(2)
+ dataset.dataset_type = DatasetType.STREAMING
+ session.commit()
+
+ # test STREAMING
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_participant(dataset_job_id=1,
+ coordinator_id=1,
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job',
+ event_time=datetime(2022, 1, 1))
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1).replace(tzinfo=timezone.utc))
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/20220101')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1))
+ self.assertIsNone(dataset_job_stage.global_configs)
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+
+ # test STREAMING has batch
+ with db.session_scope() as session:
+ data_batch_1 = DataBatch(id=1,
+ name='test_data_batch 1',
+ dataset_id=2,
+ path='/data/test/batch/20220101',
+ event_time=datetime(2022, 1, 1))
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(id=2,
+ name='test_data_batch 2',
+ dataset_id=2,
+ path='/data/test/batch/20220102',
+ event_time=datetime(2022, 1, 2))
+ session.add(data_batch_2)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_participant(dataset_job_id=1,
+ coordinator_id=1,
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job',
+ event_time=datetime(2022, 1, 1))
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1))
+ self.assertEqual(dataset_job_stage.data_batch.name, 'test_data_batch 1')
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/test/batch/20220101')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1))
+ self.assertIsNone(dataset_job_stage.global_configs)
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+
+ # test STREAMING has batch and stage
+ with db.session_scope() as session:
+ data_batch_1 = DataBatch(id=1,
+ name='test_data_batch 1',
+ dataset_id=2,
+ path='/data/test/batch/20220101',
+ event_time=datetime(2022, 1, 1))
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(id=2,
+ name='test_data_batch 2',
+ dataset_id=2,
+ path='/data/test/batch/20220102',
+ event_time=datetime(2022, 1, 2))
+ session.add(data_batch_2)
+ dataset_job_stage = DatasetJobStage(id=100,
+ name='test_dataset_job',
+ uuid='test_dataset_job uuid',
+ project_id=1,
+ workflow_id=0,
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2022, 1, 2),
+ coordinator_id=1)
+ session.add(dataset_job_stage)
+ session.flush()
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_participant(dataset_job_id=1,
+ coordinator_id=1,
+ event_time=datetime(2022, 1, 1),
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job')
+ self.assertEqual(dataset_job_stage.id, 100)
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.time_range = timedelta(hours=1)
+ session.commit()
+
+ # test STREAMING in hourly level
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageLocalController(
+ session=session).create_data_batch_and_job_stage_as_participant(dataset_job_id=1,
+ coordinator_id=1,
+ event_time=datetime(2022, 1, 1, 8),
+ uuid='test_dataset_job uuid',
+ name='test_dataset_job')
+ session.flush()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1, 8).replace(tzinfo=timezone.utc))
+ self.assertEqual(dataset_job_stage.data_batch.dataset_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch.path, '/data/dataset/123/batch/20220101-08')
+ self.assertEqual(dataset_job_stage.data_batch.event_time, datetime(2022, 1, 1, 8))
+ self.assertIsNone(dataset_job_stage.global_configs)
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/meta_data.py b/web_console_v2/api/fedlearner_webconsole/dataset/meta_data.py
new file mode 100644
index 000000000..fcdd31113
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/meta_data.py
@@ -0,0 +1,207 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+import dateutil.parser
+from typing import List, Optional, Dict, Any
+
+
+class MetaData(object):
+ _DTYPES = 'dtypes'
+ _SAMPLE = 'sample'
+ _FEATURES = 'features'
+ _HIST = 'hist'
+ _COUNT = 'count'
+
+ def __init__(self, metadata: Optional[dict] = None):
+ self.metadata = metadata or {}
+
+ @property
+ def dtypes(self) -> List[Any]:
+ return self.metadata.get(self._DTYPES, [])
+
+ @property
+ def sample(self) -> List[Any]:
+ return self.metadata.get(self._SAMPLE, [])
+
+ @property
+ def metrics(self) -> Dict[str, Dict[Any, Any]]:
+ return self.metadata.get(self._FEATURES, {})
+
+ @property
+ def hist(self) -> Dict[str, Dict[Any, Any]]:
+ return self.metadata.get(self._HIST, {})
+
+ @property
+ def num_feature(self) -> int:
+ return len(self.dtypes)
+
+ @property
+ def num_example(self) -> int:
+ return self.metadata.get(self._COUNT, 0)
+
+ def get_metrics_by_name(self, name: str) -> Dict[Any, Any]:
+ return self.metrics.get(name, {})
+
+ def get_hist_by_name(self, name: str) -> Dict[Any, Any]:
+ return self.hist.get(name, {})
+
+ def get_preview(self) -> dict:
+ """ get the preview data
+ Returns:
+ preview dict format:
+ {
+ 'dtypes': [
+ {'key': 'f01', 'value': 'bigint'}
+ ],
+ 'sample': [
+ [1],
+ [0],
+ ],
+            'count': 1000,
+ 'metrics': {
+ 'f01': {
+ 'count': '2',
+ 'mean': '0.0015716767309123998',
+ 'stddev': '0.03961485047808605',
+ 'min': '0',
+ 'max': '1',
+ 'missing_count': '0'
+ }
+ },
+ 'hist': {
+ 'x': [0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5,
+ 0.6000000000000001, 0.7000000000000001, 0.8, 0.9, 1],
+ 'y': [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
+ }
+ }
+ """
+ preview = {}
+ preview['dtypes'] = self.dtypes
+ preview['sample'] = self.sample
+ preview['num_example'] = self.num_example
+ preview['metrics'] = self.metrics
+ return preview
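+    # Illustrative usage only (not called from this module): given the analyzer output as a
+    # dict (here named metadata_dict as a placeholder), a caller could build the preview with
+    #     MetaData(metadata_dict).get_preview()
+    # which returns the dict documented in the docstring above.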
+
+
+class ImageMetaData(MetaData):
+ _LABEL_COUNT = 'label_count'
+ _THUMBNAIL_EXTENSION = '.png'
+
+ def __init__(self, thumbnail_dir_path: str, metadata: Optional[dict] = None):
+ super().__init__(metadata=metadata)
+ self.thumbnail_dir_path = thumbnail_dir_path
+
+ @property
+ def label_count(self) -> List[Any]:
+ return self.metadata.get(self._LABEL_COUNT, [])
+
+ def _get_column_idx(self, col_name: str):
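+        # linear scan over dtypes; logs a warning and returns -1 when the column is absent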
+ col_idx = -1
+ for index, col_map in enumerate(self.dtypes):
+ if col_map['key'] == col_name:
+ col_idx = index
+ break
+ if col_idx < 0:
+            logging.warning(f'can\'t find the {col_name} column in dtypes: {self.dtypes}')
+ return col_idx
+
+ def _get_thumbnail_file_name(self, file_name: str) -> str:
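+        # thumbnails reuse the part of the file name before the first '.' with a .png extension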
+ thumbnail_file_name = file_name.split('.')[0] + self._THUMBNAIL_EXTENSION
+ return thumbnail_file_name
+
+ def get_preview(self) -> dict:
+ """ get the preview data
+ Returns:
+ preview dict format:
+ {
+ "dtypes": [
+ { "key": "file_name", "value": "string" },
+ { "key": "width", "value": "int" },
+ { "key": "height", "value": "int" },
+ { "key": "nChannels", "value": "int" },
+ { "key": "mode", "value": "int" },
+ { "key": "name", "value": "string" },
+ { "key": "created_at", "value": "string" },
+ { "key": "caption", "value": "string" },
+ { "key": "label", "value": "string" }
+ ],
+ "label_count": [
+ {
+ "label": "B",
+ "count": 1
+ },
+ ],
+ "count": 50,
+ "sample": [
+ [
+ "000000050576.jpg",
+ 640,
+ 480,
+ 3,
+ 16,
+ "000000050576.jpg",
+ "2021-08-30T16:52:15.501516",
+ "A tow truck loading a bank security truck by a building.",
+ "B"
+ ],
+ ...
+ ],
+ "features": {
+ "file_name": {
+ "count": "50",
+ "mean": null,
+ "stddev": null,
+ "min": "000000005756.jpg",
+ "max": "000000562222.jpg",
+ "missing_count": "0"
+ },
+ ...
+ },
+ "hist": {
+ "width": {
+ "x": [ 333.0, 363.7, 394.4, 425.1, 455.8, 486.5, 517.2, 547.9, 578.6, 609.3, 640.0 ],
+ "y": [ 1, 1, 4, 3, 4, 0, 0, 0, 36 ]
+ },
+ ...
+ }
+ }
+ """
+ preview = super().get_preview()
+ display_name_idx = self._get_column_idx('name')
+ file_name_idx = self._get_column_idx('file_name')
+ height_idx = self._get_column_idx('height')
+ width_idx = self._get_column_idx('width')
+ created_at_idx = self._get_column_idx('created_at')
+ label_idx = self._get_column_idx('label')
+ images = []
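+        # flatten each sample row into an image dict using the column indexes resolved above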
+ for sample in self.sample:
+ sample[created_at_idx] = dateutil.parser.isoparse(sample[created_at_idx]).strftime('%Y-%m-%d')
+ image = {
+ 'name': sample[display_name_idx],
+ 'file_name': sample[file_name_idx],
+ 'width': sample[width_idx],
+ 'height': sample[height_idx],
+ 'created_at': sample[created_at_idx],
+ # TODO(wangzeju): hard code for the classification task, need to support more image follow up tasks.
+ 'annotation': {
+ 'label': sample[label_idx]
+ },
+ 'path': os.path.join(self.thumbnail_dir_path, self._get_thumbnail_file_name(sample[file_name_idx]))
+ }
+ images.append(image)
+ preview['images'] = images
+ return preview
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/meta_data_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/meta_data_test.py
new file mode 100644
index 000000000..9d863086f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/meta_data_test.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import json
+from envs import Envs
+from fedlearner_webconsole.dataset.meta_data import ImageMetaData
+
+
+class ImageMetaDataTest(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.maxDiff = None
+ with open(f'{Envs.BASE_DIR}/testing/test_data/image_meta.json', mode='r', encoding='utf-8') as f:
+ self.image_data = json.load(f)
+ with open(f'{Envs.BASE_DIR}/testing/test_data/expected_image_preview.json', mode='r', encoding='utf-8') as f:
+ self.expected_image_preview = json.load(f)
+ self.thumbnail_dir_path = '/fake_dir/'
+
+ def test_image_preview(self):
+ image_meta = ImageMetaData(self.thumbnail_dir_path, self.image_data)
+ image_preview = image_meta.get_preview()
+ self.assertDictEqual(self.expected_image_preview, image_preview)
+
+ def test_empty_meta(self):
+ image_meta = ImageMetaData(self.thumbnail_dir_path, None)
+ image_preview = image_meta.get_preview()
+ expected_response = {'dtypes': [], 'sample': [], 'num_example': 0, 'metrics': {}, 'images': []}
+ self.assertDictEqual(expected_response, image_preview)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/metrics.py b/web_console_v2/api/fedlearner_webconsole/dataset/metrics.py
new file mode 100644
index 000000000..77e10f06c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/metrics.py
@@ -0,0 +1,39 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.dataset.models import DatasetJobKind, DatasetJobState
+
+
+def emit_dataset_job_submission_store(uuid: str, kind: DatasetJobKind, coordinator_id: int):
+ emit_store(name='dataset.job.submission',
+ value=1,
+ tags={
+ 'uuid': uuid,
+ 'kind': kind.name,
+ 'coordinator_id': str(coordinator_id),
+ })
+
+
+def emit_dataset_job_duration_store(duration: int, uuid: str, kind: DatasetJobKind, coordinator_id: int,
+ state: DatasetJobState):
+ emit_store(name='dataset.job.duration',
+ value=duration,
+ tags={
+ 'uuid': uuid,
+ 'kind': kind.name,
+ 'coordinator_id': str(coordinator_id),
+ 'state': state.name
+ })
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/metrics_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/metrics_test.py
new file mode 100644
index 000000000..3df3b1a6c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/metrics_test.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.dataset.metrics import emit_dataset_job_submission_store, emit_dataset_job_duration_store
+from fedlearner_webconsole.dataset.models import DatasetJobKind, DatasetJobState
+
+
+class MetricsTest(unittest.TestCase):
+
+ def test_emit_dataset_job_submission_store(self):
+ with self.assertLogs() as cm:
+ emit_dataset_job_submission_store('uuit-test', DatasetJobKind.IMPORT_SOURCE, 0)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, [
+ '[Metric][Store] dataset.job.submission: 1, tags={\'uuid\': \'uuit-test\', ' \
+ '\'kind\': \'IMPORT_SOURCE\', \'coordinator_id\': \'0\'}',
+ ])
+
+ def test_emit_dataset_job_duration_store(self):
+ with self.assertLogs() as cm:
+ emit_dataset_job_duration_store(1000, 'uuit-test', DatasetJobKind.RSA_PSI_DATA_JOIN, 1,
+ DatasetJobState.SUCCEEDED)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, [
+ '[Metric][Store] dataset.job.duration: 1000, tags={\'uuid\': \'uuit-test\', ' \
+ '\'kind\': \'RSA_PSI_DATA_JOIN\', \'coordinator_id\': \'1\', \'state\': \'SUCCEEDED\'}',
+ ])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/models.py b/web_console_v2/api/fedlearner_webconsole/dataset/models.py
index 981a00f04..c85c73d10 100644
--- a/web_console_v2/api/fedlearner_webconsole/dataset/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/models.py
@@ -1,29 +1,40 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-# coding: utf-8
+#
import enum
+import os
+from typing import Optional
+
from sqlalchemy.sql import func
from sqlalchemy import UniqueConstraint
-from fedlearner_webconsole.db import db
-from fedlearner_webconsole.utils.mixins import to_dict_mixin
+from google.protobuf import text_format
+from fedlearner_webconsole.dataset.consts import ERROR_BATCH_SIZE
from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.proto.dataset_pb2 import (DatasetJobGlobalConfigs, DatasetRef, DatasetMetaInfo,
+ DatasetJobContext, DatasetJobStageContext, TimeRange)
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_and_auth_model import ReviewTicketAndAuthModel
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.utils.base_model.softdelete_model import SoftDeleteModel
+from fedlearner_webconsole.workflow.models import WorkflowExternalState
class DatasetType(enum.Enum):
- PSI = 'PSI'
+    PSI = 'PSI'  # use PSI as the non-streaming dataset type
STREAMING = 'STREAMING'
@@ -32,88 +43,149 @@ class BatchState(enum.Enum):
SUCCESS = 'SUCCESS'
FAILED = 'FAILED'
IMPORTING = 'IMPORTING'
+ UNKNOWN = 'UNKNOWN'
-@to_dict_mixin(
- extras={
- 'data_batches':
- lambda dataset:
- [data_batch.to_dict() for data_batch in dataset.data_batches]
- })
-class Dataset(db.Model):
- __tablename__ = 'datasets_v2'
- __table_args__ = ({
- 'comment': 'This is webconsole dataset table',
- 'mysql_engine': 'innodb',
- 'mysql_charset': 'utf8mb4',
- })
-
- id = db.Column(db.Integer,
- primary_key=True,
- autoincrement=True,
- comment='id')
- name = db.Column(db.String(255), nullable=False, comment='dataset name')
- dataset_type = db.Column(db.Enum(DatasetType, native_enum=False),
- nullable=False,
- comment='data type')
- path = db.Column(db.String(512), comment='dataset path')
- comment = db.Column('cmt',
- db.Text(),
- key='comment',
- comment='comment of dataset')
- created_at = db.Column(db.DateTime(timezone=True),
- server_default=func.now(),
- comment='created time')
- updated_at = db.Column(db.DateTime(timezone=True),
- server_default=func.now(),
- onupdate=func.now(),
- comment='updated time')
- deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted time')
- project_id = db.Column(db.Integer, default=0, comment='project_id')
+# used to represent dataset and data_batch frontend state
+class ResourceState(enum.Enum):
+ PENDING = 'PENDING'
+ PROCESSING = 'PROCESSING'
+ SUCCEEDED = 'SUCCEEDED'
+ FAILED = 'FAILED'
+
+
+class PublishFrontendState(enum.Enum):
+ UNPUBLISHED = 'UNPUBLISHED'
+ TICKET_PENDING = 'TICKET_PENDING'
+ TICKET_DECLINED = 'TICKET_DECLINED'
+ PUBLISHED = 'PUBLISHED'
+
+
+class DatasetFormat(enum.Enum):
+ TABULAR = 0
+ IMAGE = 1
+ NONE_STRUCTURED = 2
+
+
+class ImportType(enum.Enum):
+ COPY = 'COPY'
+ NO_COPY = 'NO_COPY'
+
+
+class DatasetKindV2(enum.Enum):
+ RAW = 'raw'
+ PROCESSED = 'processed'
+ SOURCE = 'source'
+ EXPORTED = 'exported'
+    INTERNAL_PROCESSED = 'internal_processed'  # dataset generated by an internal module, like model or tee
+
+
+class DatasetSchemaChecker(enum.Enum):
+ RAW_ID_CHECKER = 'RAW_ID_CHECKER'
+ NUMERIC_COLUMNS_CHECKER = 'NUMERIC_COLUMNS_CHECKER'
+
+
+class DatasetJobKind(enum.Enum):
+ RSA_PSI_DATA_JOIN = 'RSA_PSI_DATA_JOIN'
+ LIGHT_CLIENT_RSA_PSI_DATA_JOIN = 'LIGHT_CLIENT_RSA_PSI_DATA_JOIN'
+ OT_PSI_DATA_JOIN = 'OT_PSI_DATA_JOIN'
+ LIGHT_CLIENT_OT_PSI_DATA_JOIN = 'LIGHT_CLIENT_OT_PSI_DATA_JOIN'
+ HASH_DATA_JOIN = 'HASH_DATA_JOIN'
+ DATA_JOIN = 'DATA_JOIN'
+ DATA_ALIGNMENT = 'DATA_ALIGNMENT'
+ IMPORT_SOURCE = 'IMPORT_SOURCE'
+ EXPORT = 'EXPORT'
+ ANALYZER = 'ANALYZER'
+
+
+# a micro dataset_job's input and output dataset are the same one
+MICRO_DATASET_JOB = [DatasetJobKind.ANALYZER]
+
+LOCAL_DATASET_JOBS = [
+ DatasetJobKind.IMPORT_SOURCE,
+ DatasetJobKind.ANALYZER,
+ DatasetJobKind.EXPORT,
+ DatasetJobKind.LIGHT_CLIENT_OT_PSI_DATA_JOIN,
+ DatasetJobKind.LIGHT_CLIENT_RSA_PSI_DATA_JOIN,
+]
+
+
+class DatasetJobState(enum.Enum):
+ PENDING = 'PENDING'
+ RUNNING = 'RUNNING'
+ SUCCEEDED = 'SUCCEEDED'
+ FAILED = 'FAILED'
+ STOPPED = 'STOPPED'
+
+
+class StoreFormat(enum.Enum):
+ UNKNOWN = 'UNKNOWN'
+ CSV = 'CSV'
+ TFRECORDS = 'TFRECORDS'
+
+
+class DatasetJobSchedulerState(enum.Enum):
+ PENDING = 'PENDING'
+ RUNNABLE = 'RUNNABLE'
+ STOPPED = 'STOPPED'
+
- data_batches = db.relationship(
- 'DataBatch', primaryjoin='foreign(DataBatch.dataset_id) == Dataset.id')
- project = db.relationship(
- 'Project', primaryjoin='foreign(Dataset.project_id) == Project.id')
+DATASET_STATE_CONVERT_MAP_V2 = {
+ DatasetJobState.PENDING: ResourceState.PENDING,
+ DatasetJobState.RUNNING: ResourceState.PROCESSING,
+ DatasetJobState.SUCCEEDED: ResourceState.SUCCEEDED,
+ DatasetJobState.FAILED: ResourceState.FAILED,
+ DatasetJobState.STOPPED: ResourceState.FAILED,
+}
+
+
+class DataSourceType(enum.Enum):
+ # hdfs datasource path, e.g. hdfs:///home/xxx
+ HDFS = 'hdfs'
+ # nfs datasource path, e.g. file:///data/xxx
+ FILE = 'file'
+
+
+SOURCE_IS_DELETED = 'deleted'
+WORKFLOW_STATUS_STATE_MAPPER = {
+ WorkflowExternalState.COMPLETED: ResourceState.SUCCEEDED,
+ WorkflowExternalState.FAILED: ResourceState.FAILED,
+ WorkflowExternalState.STOPPED: ResourceState.FAILED,
+ WorkflowExternalState.INVALID: ResourceState.FAILED,
+}
+DATASET_JOB_FINISHED_STATE = [DatasetJobState.SUCCEEDED, DatasetJobState.FAILED, DatasetJobState.STOPPED]
-@to_dict_mixin(extras={'details': (lambda batch: batch.get_details())})
class DataBatch(db.Model):
__tablename__ = 'data_batches_v2'
- __table_args__ = (
- UniqueConstraint('event_time',
- 'dataset_id',
- name='uniq_event_time_dataset_id'),
- {
- 'comment': 'This is webconsole dataset table',
- 'mysql_engine': 'innodb',
- 'mysql_charset': 'utf8mb4',
- },
- )
- id = db.Column(db.Integer,
- primary_key=True,
- autoincrement=True,
- comment='id')
- event_time = db.Column(db.TIMESTAMP(timezone=True),
- nullable=False,
- comment='event_time')
+ __table_args__ = (UniqueConstraint('event_time', 'dataset_id', name='uniq_event_time_dataset_id'),
+ default_table_args('This is webconsole dataset table'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ name = db.Column(db.String(255), nullable=True, comment='data_batch name')
+ event_time = db.Column(db.TIMESTAMP(timezone=True), nullable=True, comment='event_time')
dataset_id = db.Column(db.Integer, nullable=False, comment='dataset_id')
path = db.Column(db.String(512), comment='path')
- state = db.Column(db.Enum(BatchState, native_enum=False),
+    # TODO(wangsen.0914): going to be deprecated
+ state = db.Column(db.Enum(BatchState, native_enum=False, create_constraint=False),
default=BatchState.NEW,
comment='state')
+    # the move column will be deprecated after the dataset refactor
move = db.Column(db.Boolean, default=False, comment='move')
# Serialized proto of DatasetBatch
- details = db.Column(db.LargeBinary(), comment='details')
- file_size = db.Column(db.Integer, default=0, comment='file_size')
- num_imported_file = db.Column(db.Integer,
- default=0,
- comment='num_imported_file')
- num_file = db.Column(db.Integer, default=0, comment='num_file')
+ file_size = db.Column(db.BigInteger, default=0, comment='file_size in bytes')
+ num_example = db.Column(db.BigInteger, default=0, comment='num_example')
+ num_feature = db.Column(db.BigInteger, default=0, comment='num_feature')
+ meta_info = db.Column(db.Text(16777215), comment='dataset meta info')
comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
- created_at = db.Column(db.DateTime(timezone=True),
- server_default=func.now(),
- comment='created_at')
+ latest_parent_dataset_job_stage_id = db.Column(db.Integer,
+ nullable=False,
+ server_default=db.text('0'),
+ comment='latest parent dataset_job_stage id')
+ latest_analyzer_dataset_job_stage_id = db.Column(db.Integer,
+ nullable=False,
+ server_default=db.text('0'),
+ comment='latest analyzer dataset_job_stage id')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created_at')
updated_at = db.Column(db.DateTime(timezone=True),
server_default=func.now(),
server_onupdate=func.now(),
@@ -125,30 +197,605 @@ class DataBatch(db.Model):
'foreign(DataBatch.dataset_id)',
back_populates='data_batches')
- def set_details(self, proto):
- self.num_file = len(proto.files)
- num_imported_file = 0
- num_failed_file = 0
+ latest_parent_dataset_job_stage = db.relationship(
+ 'DatasetJobStage',
+ primaryjoin='DatasetJobStage.id == foreign(DataBatch.latest_parent_dataset_job_stage_id)',
+ # To disable the warning of back_populates
+ overlaps='data_batch')
+
+ @property
+ def batch_name(self):
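+        # fall back to the last segment of the batch path when no explicit name was set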
+ return self.name or os.path.basename(os.path.abspath(self.path))
+
+ def get_frontend_state(self) -> ResourceState:
+        # use the dataset_job state in place of the dataset_job_stage state when dataset_job_stage is not supported
+ if self.latest_parent_dataset_job_stage is None:
+ return self.dataset.get_frontend_state()
+ return DATASET_STATE_CONVERT_MAP_V2.get(self.latest_parent_dataset_job_stage.state)
+
+ def is_available(self) -> bool:
+ return self.get_frontend_state() == ResourceState.SUCCEEDED
+
+ def to_proto(self) -> dataset_pb2.DataBatch:
+ proto = dataset_pb2.DataBatch(id=self.id,
+ name=self.batch_name,
+ dataset_id=self.dataset_id,
+ path=self.path,
+ file_size=self.file_size,
+ num_example=self.num_example,
+ num_feature=self.num_feature,
+ comment=self.comment,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ event_time=to_timestamp(self.event_time) if self.event_time else 0,
+ latest_parent_dataset_job_stage_id=self.latest_parent_dataset_job_stage_id,
+ latest_analyzer_dataset_job_stage_id=self.latest_analyzer_dataset_job_stage_id)
+ proto.state = self.get_frontend_state().name
+ return proto
+
+
+class Dataset(SoftDeleteModel, ReviewTicketAndAuthModel, db.Model):
+ __tablename__ = 'datasets_v2'
+ __table_args__ = (default_table_args('This is webconsole dataset table'))
+
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ uuid = db.Column(db.String(255), nullable=True, comment='dataset uuid')
+ is_published = db.Column(db.Boolean, default=False, comment='dataset is published or not')
+ name = db.Column(db.String(255), nullable=False, comment='dataset name')
+ creator_username = db.Column(db.String(255), default='', comment='creator username')
+ dataset_type = db.Column(db.Enum(DatasetType, native_enum=False, create_constraint=False),
+ default=DatasetType.PSI,
+ nullable=False,
+ comment='data type')
+ path = db.Column(db.String(512), comment='dataset path')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment of dataset')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created time')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ server_default=func.now(),
+ onupdate=func.now(),
+ comment='updated time')
+ project_id = db.Column(db.Integer, default=0, comment='project_id')
+ # New version of dataset kind
+ dataset_kind = db.Column(db.Enum(DatasetKindV2, native_enum=False, length=32, create_constraint=False),
+ default=DatasetKindV2.RAW,
+ comment='new version of dataset kind, choices [raw, processed, ...]')
+ # DatasetFormat enum
+ dataset_format = db.Column(db.Integer, default=0, comment='dataset format')
+ # StoreFormat
+ store_format = db.Column(db.Enum(StoreFormat, native_enum=False, length=32, create_constraint=False),
+ default=StoreFormat.TFRECORDS,
+ comment='dataset store format, like CSV, TFRECORDS, ...')
+ meta_info = db.Column(db.Text(16777215), comment='dataset meta info')
+
+ import_type = db.Column(db.Enum(ImportType, length=64, native_enum=False, create_constraint=False),
+ server_default=ImportType.COPY.name,
+ comment='import type')
+
+ data_batches = db.relationship('DataBatch',
+ primaryjoin='foreign(DataBatch.dataset_id) == Dataset.id',
+ order_by='desc(DataBatch.id)')
+ project = db.relationship('Project', primaryjoin='foreign(Dataset.project_id) == Project.id')
+
+    # a dataset has only one main dataset_job as parent_dataset_job, but could have many micro dataset_jobs
+ @property
+ def parent_dataset_job(self):
+ return None if not db.object_session(self) else db.object_session(self).query(DatasetJob).filter(
+ DatasetJob.output_dataset_id == self.id).filter(
+ DatasetJob.kind.not_in(MICRO_DATASET_JOB)).execution_options(include_deleted=True).first()
+
+    # a dataset has only one analyzer dataset_job
+ def analyzer_dataset_job(self):
+ return None if not db.object_session(self) else db.object_session(self).query(DatasetJob).filter(
+ DatasetJob.output_dataset_id == self.id).filter(
+ DatasetJob.kind == DatasetJobKind.ANALYZER).execution_options(include_deleted=True).first()
+
+ # single table inheritance
+ # Ref: https://docs.sqlalchemy.org/en/14/orm/inheritance.html
+ __mapper_args__ = {'polymorphic_identity': DatasetKindV2.RAW, 'polymorphic_on': dataset_kind}
+
+ def get_frontend_state(self) -> ResourceState:
+ # if parent_dataset_job failed to generate, dataset state is failed
+ if self.parent_dataset_job is None:
+ return ResourceState.FAILED
+ return DATASET_STATE_CONVERT_MAP_V2.get(self.parent_dataset_job.state)
+
+ def get_file_size(self) -> int:
file_size = 0
- # Aggregates stats
- for file in proto.files:
- if file.state == dataset_pb2.File.State.COMPLETED:
- num_imported_file += 1
- file_size += file.size
- elif file.state == dataset_pb2.File.State.FAILED:
- num_failed_file += 1
- if num_imported_file + num_failed_file == self.num_file:
- if num_failed_file > 0:
- self.state = BatchState.FAILED
- else:
- self.state = BatchState.SUCCESS
- self.num_imported_file = num_imported_file
- self.file_size = file_size
- self.details = proto.SerializeToString()
-
- def get_details(self):
- if self.details is None:
+ for batch in self.data_batches:
+ if not batch.file_size or batch.file_size == ERROR_BATCH_SIZE:
+ continue
+ file_size += batch.file_size
+ return file_size
+
+ def get_num_example(self) -> int:
+ return sum([batch.num_example or 0 for batch in self.data_batches])
+
+ def get_num_feature(self) -> int:
+ if len(self.data_batches) != 0:
+ # num_feature is decided by the first data_batch
+ return self.data_batches[0].num_feature
+ return 0
+
+ # TODO(hangweiqiang): remove data_source after adapting fedlearner to dataset path
+ def get_data_source(self) -> Optional[str]:
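+        # prefer the uuid of the first dataset_job_stage; fall back to the parent workflow uuid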
+ if self.parent_dataset_job is not None:
+ dataset_job_stage = db.object_session(self).query(DatasetJobStage).filter_by(
+ dataset_job_id=self.parent_dataset_job.id).first()
+ if dataset_job_stage is not None:
+ return f'{dataset_job_stage.uuid}-psi-data-join-job'
+ if self.parent_dataset_job.workflow is not None:
+ return f'{self.parent_dataset_job.workflow.uuid}-psi-data-join-job'
+ return None
+
+ @property
+ def publish_frontend_state(self) -> PublishFrontendState:
+ if not self.is_published:
+ return PublishFrontendState.UNPUBLISHED
+ if self.ticket_status == TicketStatus.APPROVED:
+ return PublishFrontendState.PUBLISHED
+ if self.ticket_status == TicketStatus.DECLINED:
+ return PublishFrontendState.TICKET_DECLINED
+ return PublishFrontendState.TICKET_PENDING
+
+ def to_ref(self) -> DatasetRef:
+ # TODO(liuhehan): this is a lazy update of dataset store_format, remove it after release 2.4
+ if self.dataset_kind in [DatasetKindV2.RAW, DatasetKindV2.PROCESSED] and self.store_format is None:
+ self.store_format = StoreFormat.TFRECORDS
+ # TODO(liuhehan): this is a lazy update for auth status, remove after release 2.4
+ if self.auth_status is None:
+ self.auth_status = AuthStatus.AUTHORIZED
+ return DatasetRef(id=self.id,
+ uuid=self.uuid,
+ project_id=self.project_id,
+ name=self.name,
+ created_at=to_timestamp(self.created_at),
+ state_frontend=self.get_frontend_state().name,
+ path=self.path,
+ is_published=self.is_published,
+ dataset_format=DatasetFormat(self.dataset_format).name,
+ comment=self.comment,
+ dataset_kind=self.dataset_kind.name,
+ file_size=self.get_file_size(),
+ num_example=self.get_num_example(),
+ data_source=self.get_data_source(),
+ creator_username=self.creator_username,
+ dataset_type=self.dataset_type.name,
+ store_format=self.store_format.name if self.store_format else '',
+ import_type=self.import_type.name,
+ publish_frontend_state=self.publish_frontend_state.name,
+ auth_frontend_state=self.auth_frontend_state.name,
+ local_auth_status=self.auth_status.name,
+ participants_info=self.get_participants_info())
+
+ def to_proto(self) -> dataset_pb2.Dataset:
+ # TODO(liuhehan): this is a lazy update of dataset store_format, remove it after release 2.4
+ if self.dataset_kind in [DatasetKindV2.RAW, DatasetKindV2.PROCESSED] and self.store_format is None:
+ self.store_format = StoreFormat.TFRECORDS
+ # TODO(liuhehan): this is a lazy update for auth status, remove after release 2.4
+ if self.auth_status is None:
+ self.auth_status = AuthStatus.AUTHORIZED
+ meta_data = self.get_meta_info()
+ analyzer_dataset_job = self.analyzer_dataset_job()
+        # use the newest data_batch's updated_at as the dataset updated_at time when any data_batch exists
+ updated_at = self.data_batches[0].updated_at if self.data_batches else self.updated_at
+ return dataset_pb2.Dataset(
+ id=self.id,
+ uuid=self.uuid,
+ is_published=self.is_published,
+ project_id=self.project_id,
+ name=self.name,
+ workflow_id=self.parent_dataset_job.workflow_id if self.parent_dataset_job is not None else 0,
+ path=self.path,
+ created_at=to_timestamp(self.created_at),
+ data_source=self.get_data_source(),
+ file_size=self.get_file_size(),
+ num_example=self.get_num_example(),
+ comment=self.comment,
+ num_feature=self.get_num_feature(),
+ updated_at=to_timestamp(updated_at),
+ deleted_at=to_timestamp(self.deleted_at) if self.deleted_at else None,
+ parent_dataset_job_id=self.parent_dataset_job.id if self.parent_dataset_job is not None else 0,
+ dataset_format=DatasetFormat(self.dataset_format).name,
+ analyzer_dataset_job_id=analyzer_dataset_job.id if analyzer_dataset_job is not None else 0,
+ state_frontend=self.get_frontend_state().name,
+ dataset_kind=self.dataset_kind.name,
+ value=meta_data.value,
+ schema_checkers=meta_data.schema_checkers,
+ creator_username=self.creator_username,
+ import_type=self.import_type.name,
+ dataset_type=self.dataset_type.name,
+ store_format=self.store_format.name if self.store_format else '',
+ publish_frontend_state=self.publish_frontend_state.name,
+ auth_frontend_state=self.auth_frontend_state.name,
+ local_auth_status=self.auth_status.name,
+ participants_info=self.get_participants_info())
+
+ def is_tabular(self) -> bool:
+ return self.dataset_format == DatasetFormat.TABULAR.value
+
+ def is_image(self) -> bool:
+ return self.dataset_format == DatasetFormat.IMAGE.value
+
+ def set_meta_info(self, meta: DatasetMetaInfo):
+ if meta is None:
+ meta = DatasetMetaInfo()
+ self.meta_info = text_format.MessageToString(meta)
+
+ def get_meta_info(self) -> DatasetMetaInfo:
+ meta = DatasetMetaInfo()
+ if self.meta_info is not None:
+ meta = text_format.Parse(self.meta_info, DatasetMetaInfo())
+ return meta
+
+ def get_single_batch(self) -> DataBatch:
+ """Get single batch of this dataset
+
+ Returns:
+ DataBatch: according data batch
+
+ Raises:
+ TypeError: when there's no data batch or more than one data batch
+ """
+ if not self.data_batches:
+ raise TypeError(f'there is no data_batch for this dataset {self.id}')
+ if len(self.data_batches) != 1:
+ raise TypeError(f'there is more than one data_batch for this dataset {self.id}')
+ return self.data_batches[0]
+
+
+class DataSource(Dataset):
+
+ __mapper_args__ = {'polymorphic_identity': DatasetKindV2.SOURCE}
+
+ def to_proto(self) -> dataset_pb2.DataSource:
+ meta_info = self.get_meta_info()
+ return dataset_pb2.DataSource(
+ id=self.id,
+ comment=self.comment,
+ uuid=self.uuid,
+ name=self.name,
+ type=meta_info.datasource_type,
+ url=self.path,
+ created_at=to_timestamp(self.created_at),
+ project_id=self.project_id,
+ is_user_upload=meta_info.is_user_upload,
+ is_user_export=meta_info.is_user_export,
+ creator_username=self.creator_username,
+ dataset_format=DatasetFormat(self.dataset_format).name,
+ store_format=self.store_format.name if self.store_format else '',
+ dataset_type=self.dataset_type.name,
+ )
+
+
+class ProcessedDataset(Dataset):
+
+ __mapper_args__ = {'polymorphic_identity': DatasetKindV2.PROCESSED}
+
+
+class ExportedDataset(Dataset):
+
+ __mapper_args__ = {'polymorphic_identity': DatasetKindV2.EXPORTED}
+
+
+class InternalProcessedDataset(Dataset):
+
+ __mapper_args__ = {'polymorphic_identity': DatasetKindV2.INTERNAL_PROCESSED}
+
+ def get_frontend_state(self) -> ResourceState:
+        # we just hardcode the internal_processed dataset state to SUCCEEDED for now
+ return ResourceState.SUCCEEDED
+
+
+class DatasetJob(SoftDeleteModel, db.Model):
+ """ DatasetJob is the abstraction of basic action inside dataset module.
+
+    UseCase 1: An import job from a datasource to a dataset
+ {
+ "id": 1,
+ "uuid": u456,
+ "input_dataset_id": 5,
+ "output_dataset_id": 4,
+ "kind": "import_datasource",
+ "global_configs": map,
+ "workflow_id": 6,
+ "coordinator_id": 0,
+ }
+
+ UseCase 2: A data join job between participants
+    coordinator:
+ {
+ "id": 1,
+ "uuid": u456,
+ "input_dataset_id": 2,
+ "output_dataset_id": 4,
+ "kind": "rsa_psi_data_join",
+ "global_configs": map,
+ "coordinator_id": 0,
+ "workflow_id": 6,
+ }
+
+ participant:
+ {
+ "id": 1,
+ "uuid": u456,
+ "input_dataset_id": 4,
+ "output_dataset_id": 7,
+ "kind": "rsa_psi_data_join",
+ "global_configs": "", # pull from coodinator
+ "coordinator_id": 1,
+ "workflow_id": 7,
+ }
+ """
+ __tablename__ = 'dataset_jobs_v2'
+ __table_args__ = (UniqueConstraint('uuid', name='uniq_dataset_job_uuid'), default_table_args('dataset_jobs_v2'))
+
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id of dataset job')
+ uuid = db.Column(db.String(255), nullable=False, comment='dataset job uuid')
+ name = db.Column(db.String(255), nullable=True, comment='dataset job name')
+
+    # state is updated to stay in sync with the newest dataset_job_stage
+ state = db.Column(db.Enum(DatasetJobState, length=64, native_enum=False, create_constraint=False),
+ nullable=False,
+ default=DatasetJobState.PENDING,
+ comment='dataset job state')
+ project_id = db.Column(db.Integer, nullable=False, comment='project id')
+
+    # If multiple dataset/datasource inputs are supported, the following two columns will be deprecated.
+ # Instead, a new table will be introduced.
+ input_dataset_id = db.Column(db.Integer, nullable=False, comment='input dataset id')
+ output_dataset_id = db.Column(db.Integer, nullable=False, comment='output dataset id')
+
+ kind = db.Column(db.Enum(DatasetJobKind, length=128, native_enum=False, create_constraint=False),
+ nullable=False,
+ comment='dataset job kind')
+ # If batch update mode is supported, this column will be deprecated.
+ # Instead, a new table called DatasetStage and a new Column called Context will be introduced.
+ workflow_id = db.Column(db.Integer, nullable=True, default=0, comment='relating workflow id')
+ context = db.Column(db.Text(), nullable=True, default=None, comment='context info of dataset job')
+
+ global_configs = db.Column(
+ db.Text(), comment='global configs of this job including related participants only appear in coordinator')
+ coordinator_id = db.Column(db.Integer, nullable=False, default=0, comment='participant id of this job coordinator')
+
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created time')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ server_default=func.now(),
+ onupdate=func.now(),
+ comment='updated time')
+ started_at = db.Column(db.DateTime(timezone=True), comment='started_at')
+ finished_at = db.Column(db.DateTime(timezone=True), comment='finished_at')
+
+ # cron_job will use time_range to infer event_time for next data_batch
+ time_range = db.Column(db.Interval(native=False), nullable=True, comment='time_range to create new job_stage')
+ # cron_job will read event_time to get current data_batch,
+    # and update it to event_time + time_range when the next data_batch is created
+ event_time = db.Column(db.DateTime(timezone=True), nullable=True, comment='event_time for current data_batch')
+
+ # scheduler_state will be filter and change by job_scheduler_v2
+ scheduler_state = db.Column(db.Enum(DatasetJobSchedulerState, length=64, native_enum=False,
+ create_constraint=False),
+ nullable=True,
+ default=DatasetJobSchedulerState.PENDING,
+ comment='dataset job scheduler state')
+
+ creator_username = db.Column(db.String(255), nullable=True, comment='creator username')
+
+ workflow = db.relationship('Workflow', primaryjoin='foreign(DatasetJob.workflow_id) == Workflow.id')
+ project = db.relationship('Project', primaryjoin='foreign(DatasetJob.project_id) == Project.id')
+ input_dataset = db.relationship('Dataset', primaryjoin='foreign(DatasetJob.input_dataset_id) == Dataset.id')
+
+ @property
+ def output_dataset(self):
+ return None if not db.object_session(self) else db.object_session(self).query(Dataset).filter(
+ Dataset.id == self.output_dataset_id).execution_options(include_deleted=True).first()
+
+ dataset_job_stages = db.relationship(
+ 'DatasetJobStage',
+ order_by='desc(DatasetJobStage.created_at)',
+ primaryjoin='DatasetJob.id == foreign(DatasetJobStage.dataset_job_id)',
+ # To disable the warning of back_populates
+ overlaps='dataset_job')
+
+ def get_global_configs(self) -> Optional[DatasetJobGlobalConfigs]:
+ # For participant, global_config is empty text.
+ if self.global_configs is None or len(self.global_configs) == 0:
return None
- proto = dataset_pb2.DataBatch()
- proto.ParseFromString(self.details)
+ return text_format.Parse(self.global_configs, DatasetJobGlobalConfigs())
+
+ def set_global_configs(self, global_configs: DatasetJobGlobalConfigs):
+ self.global_configs = text_format.MessageToString(global_configs)
+
+ def get_context(self) -> DatasetJobContext:
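+        # context is persisted as a text-format proto; an empty DatasetJobContext is returned
+        # when nothing has been stored yet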
+ context_pb = DatasetJobContext()
+ if self.context:
+ context_pb = text_format.Parse(self.context, context_pb)
+ return context_pb
+
+ def set_context(self, context: DatasetJobContext):
+ self.context = text_format.MessageToString(context)
+
+ def set_scheduler_message(self, scheduler_message: str):
+ context_pb = self.get_context()
+ context_pb.scheduler_message = scheduler_message
+ self.set_context(context=context_pb)
+
+ @property
+ def time_range_pb(self) -> TimeRange:
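+        # convert the stored timedelta into a TimeRange proto: whole days for a daily cron,
+        # whole hours for anything shorter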
+ time_range_pb = TimeRange()
+ if self.is_daily_cron():
+ time_range_pb.days = self.time_range.days
+ elif self.is_hourly_cron():
+ # convert seconds to hours
+ time_range_pb.hours = int(self.time_range.seconds / 3600)
+ return time_range_pb
+
+ def to_proto(self) -> dataset_pb2.DatasetJob:
+ context = self.get_context()
+ proto = dataset_pb2.DatasetJob(
+ id=self.id,
+ uuid=self.uuid,
+ name=self.name,
+ project_id=self.project_id,
+ workflow_id=self.workflow_id,
+ coordinator_id=self.coordinator_id,
+ kind=self.kind.value,
+ state=self.state.name,
+ global_configs=self.get_global_configs(),
+ input_data_batch_num_example=context.input_data_batch_num_example,
+ output_data_batch_num_example=context.output_data_batch_num_example,
+ has_stages=context.has_stages,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ started_at=to_timestamp(self.started_at) if self.started_at else 0,
+ finished_at=to_timestamp(self.finished_at) if self.finished_at else 0,
+ creator_username=self.creator_username,
+ scheduler_state=self.scheduler_state.name if self.scheduler_state else '',
+ time_range=self.time_range_pb,
+ scheduler_message=context.scheduler_message,
+ )
+ if self.output_dataset:
+ proto.result_dataset_uuid = self.output_dataset.uuid
+ proto.result_dataset_name = self.output_dataset.name
+ if self.workflow_id:
+ proto.is_ready = True
return proto
+
+ def to_ref(self) -> dataset_pb2.DatasetJobRef:
+ return dataset_pb2.DatasetJobRef(
+ uuid=self.uuid,
+ id=self.id,
+ name=self.name,
+ coordinator_id=self.coordinator_id,
+ project_id=self.project_id,
+ kind=self.kind.name,
+ result_dataset_id=self.output_dataset_id,
+ result_dataset_name=self.output_dataset.name if self.output_dataset else '',
+ state=self.state.name,
+ created_at=to_timestamp(self.created_at),
+ has_stages=self.get_context().has_stages,
+ creator_username=self.creator_username,
+ )
+
+ def is_coordinator(self) -> bool:
+ return self.coordinator_id == 0
+
+ def is_finished(self) -> bool:
+ return self.state in DATASET_JOB_FINISHED_STATE
+
+ def is_cron(self) -> bool:
+ return self.time_range is not None
+
+ def is_daily_cron(self) -> bool:
+ if self.time_range is None:
+ return False
+ return self.time_range.days > 0
+
+ def is_hourly_cron(self) -> bool:
+ if self.time_range is None:
+ return False
+ # hourly time_range is less than one day
+ return self.time_range.days == 0
+
+
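+# A DatasetJobStage represents one execution of a DatasetJob that processes a single data_batch.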
+class DatasetJobStage(SoftDeleteModel, db.Model):
+ __tablename__ = 'dataset_job_stages_v2'
+ __table_args__ = (UniqueConstraint('uuid',
+ name='uniq_dataset_job_stage_uuid'), default_table_args('dataset_job_stages_v2'))
+
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id of dataset job stage')
+ uuid = db.Column(db.String(255), nullable=False, comment='dataset job stage uuid')
+ name = db.Column(db.String(255), nullable=True, comment='dataset job stage name')
+ state = db.Column(db.Enum(DatasetJobState, length=64, native_enum=False, create_constraint=False),
+ nullable=False,
+ default=DatasetJobState.PENDING,
+ comment='dataset job stage state')
+ project_id = db.Column(db.Integer, nullable=False, comment='project id')
+ workflow_id = db.Column(db.Integer, nullable=True, default=0, comment='relating workflow id')
+ dataset_job_id = db.Column(db.Integer, nullable=False, comment='dataset_job id')
+ data_batch_id = db.Column(db.Integer, nullable=False, comment='data_batch id')
+ event_time = db.Column(db.DateTime(timezone=True), nullable=True, comment='event_time of data upload')
+    # the dataset_job global_configs are copied into the job_stage when the stage is created by the coordinator
+ global_configs = db.Column(
+ db.Text(), comment='global configs of this stage including related participants only appear in coordinator')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created time')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ server_default=func.now(),
+ onupdate=func.now(),
+ comment='updated time')
+ started_at = db.Column(db.DateTime(timezone=True), comment='started_at')
+ finished_at = db.Column(db.DateTime(timezone=True), comment='finished_at')
+
+    # the dataset_job coordinator might be different from the dataset_job_stage coordinator
+ coordinator_id = db.Column(db.Integer,
+ nullable=False,
+ server_default=db.text('0'),
+ comment='participant id of this dataset_job_stage, 0 if it is coordinator')
+
+ context = db.Column(db.Text(), nullable=True, default=None, comment='context info of dataset job stage')
+
+ workflow = db.relationship('Workflow', primaryjoin='foreign(DatasetJobStage.workflow_id) == Workflow.id')
+ project = db.relationship('Project', primaryjoin='foreign(DatasetJobStage.project_id) == Project.id')
+ dataset_job = db.relationship('DatasetJob', primaryjoin='foreign(DatasetJobStage.dataset_job_id) == DatasetJob.id')
+ data_batch = db.relationship('DataBatch', primaryjoin='foreign(DatasetJobStage.data_batch_id) == DataBatch.id')
+
+ def get_global_configs(self) -> Optional[DatasetJobGlobalConfigs]:
+        # For participants, global_configs is empty text.
+ if self.global_configs is None or len(self.global_configs) == 0:
+ return None
+ return text_format.Parse(self.global_configs, DatasetJobGlobalConfigs())
+
+ def set_global_configs(self, global_configs: DatasetJobGlobalConfigs):
+ self.global_configs = text_format.MessageToString(global_configs)
+
+ def is_finished(self) -> bool:
+ return self.state in DATASET_JOB_FINISHED_STATE
+
+ def to_ref(self) -> dataset_pb2.DatasetJobStageRef:
+ return dataset_pb2.DatasetJobStageRef(id=self.id,
+ name=self.name,
+ dataset_job_id=self.dataset_job_id,
+ output_data_batch_id=self.data_batch_id,
+ project_id=self.project_id,
+ state=self.state.name,
+ created_at=to_timestamp(self.created_at),
+ kind=self.dataset_job.kind.name if self.dataset_job else '')
+
+ def to_proto(self) -> dataset_pb2.DatasetJobStage:
+ context = self.get_context()
+ return dataset_pb2.DatasetJobStage(id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ dataset_job_id=self.dataset_job_id,
+ output_data_batch_id=self.data_batch_id,
+ workflow_id=self.workflow_id,
+ project_id=self.project_id,
+ state=self.state.name,
+ event_time=to_timestamp(self.event_time) if self.event_time else 0,
+ global_configs=self.get_global_configs(),
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ started_at=to_timestamp(self.started_at) if self.started_at else 0,
+ finished_at=to_timestamp(self.finished_at) if self.finished_at else 0,
+ dataset_job_uuid=self.dataset_job.uuid if self.dataset_job else None,
+ is_ready=self.workflow is not None,
+ kind=self.dataset_job.kind.name if self.dataset_job else '',
+ input_data_batch_num_example=context.input_data_batch_num_example,
+ output_data_batch_num_example=context.output_data_batch_num_example,
+ scheduler_message=context.scheduler_message)
+
+ def get_context(self) -> DatasetJobStageContext:
+ context_pb = DatasetJobStageContext()
+ if self.context:
+ context_pb = text_format.Parse(self.context, context_pb)
+ return context_pb
+
+ def set_context(self, context: DatasetJobStageContext):
+ self.context = text_format.MessageToString(context)
+
+ def set_scheduler_message(self, scheduler_message: str):
+ context_pb = self.get_context()
+ context_pb.scheduler_message = scheduler_message
+ self.set_context(context=context_pb)
+
+ def is_coordinator(self) -> bool:
+ return self.coordinator_id == 0
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/models_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/models_test.py
new file mode 100644
index 000000000..53406258d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/models_test.py
@@ -0,0 +1,711 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime, timedelta
+import time
+import unittest
+from unittest.mock import MagicMock, PropertyMock, patch
+from testing.common import NoWebServerTestCase
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import (DATASET_STATE_CONVERT_MAP_V2, Dataset, DataSource, DatasetFormat,
+ DatasetJobSchedulerState, DatasetJobStage, DatasetKindV2, ImportType,
+ PublishFrontendState, ResourceState, StoreFormat, DatasetType,
+ DatasetJob, DataSourceType, DatasetJobKind, DatasetJobState,
+ DataBatch)
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto import dataset_pb2, project_pb2
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobStageContext, DatasetMetaInfo, DatasetJobConfig, \
+ DatasetJobGlobalConfigs, TimeRange
+from google.protobuf.struct_pb2 import Value
+
+
+class DataBatchTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ dataset = Dataset(id=1,
+ uuid=resource_uuid(),
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=False,
+ import_type=ImportType.NO_COPY)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='20220701',
+ dataset_id=1,
+ path='/data/test/batch/20220701',
+ event_time=datetime.strptime('20220701', '%Y%m%d'),
+ file_size=100,
+ num_example=10,
+ num_feature=3,
+ latest_parent_dataset_job_stage_id=1)
+ session.add(data_batch)
+ session.commit()
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ data_batch: DataBatch = session.query(DataBatch).get(1)
+ self.assertPartiallyEqual(
+ to_dict(data_batch.to_proto()),
+ {
+ 'id': 1,
+ 'name': '20220701',
+ 'dataset_id': 1,
+ 'path': '/data/test/batch/20220701',
+ 'event_time': to_timestamp(datetime.strptime('20220701', '%Y%m%d')),
+ 'file_size': 100,
+ 'num_example': 10,
+ 'num_feature': 3,
+ 'comment': '',
+ 'state': 'FAILED',
+ 'latest_parent_dataset_job_stage_id': 1,
+ 'latest_analyzer_dataset_job_stage_id': 0,
+ },
+ ignore_fields=['created_at', 'updated_at'],
+ )
+
+ def test_is_available(self):
+ with db.session_scope() as session:
+ job_stage = DatasetJobStage(id=1,
+ uuid='job stage uuid',
+ name='default dataset job stage',
+ project_id=1,
+ workflow_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2012, 1, 15),
+ state=DatasetJobState.PENDING,
+ coordinator_id=0)
+ session.add(job_stage)
+ session.commit()
+ with db.session_scope() as session:
+ data_batch = session.query(DataBatch).get(1)
+ self.assertFalse(data_batch.is_available())
+ job_stage = session.query(DatasetJobStage).get(1)
+ job_stage.state = DatasetJobState.SUCCEEDED
+ session.flush()
+ self.assertTrue(data_batch.is_available())
+
+
+class DatasetTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ default_dataset = Dataset(id=10,
+ uuid=resource_uuid(),
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=False)
+ session.add(default_dataset)
+ session.commit()
+
+ def test_dataset_meta_info(self):
+ meta_info = DatasetMetaInfo(value=100)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ dataset.set_meta_info(meta_info)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ meta_info_current = dataset.get_meta_info()
+ self.assertEqual(meta_info_current, meta_info)
+
+ def test_get_frontend_state(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=10,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ for state, front_state in DATASET_STATE_CONVERT_MAP_V2.items():
+ dataset.parent_dataset_job.state = state
+ self.assertEqual(dataset.get_frontend_state(), front_state)
+
+ def test_get_single_batch(self):
+ # test no batch
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ with self.assertRaises(TypeError) as cm:
+ dataset.get_single_batch()
+ self.assertEqual(cm.exception.args[0], 'there is no data_batch for this dataset 10')
+
+ # test one batch
+ first_event_time = datetime(year=2000, month=1, day=1)
+ with db.session_scope() as session:
+ batch = DataBatch(dataset_id=10, event_time=first_event_time)
+ session.add(batch)
+ session.commit()
+
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ batch = dataset.get_single_batch()
+ self.assertEqual(batch.event_time, first_event_time)
+
+        # test two batches
+ second_event_time = datetime(year=2000, month=1, day=2)
+ with db.session_scope() as session:
+ batch = DataBatch(dataset_id=10, event_time=second_event_time)
+ session.add(batch)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ with self.assertRaises(TypeError) as cm:
+ dataset.get_single_batch()
+ self.assertEqual(cm.exception.args[0], 'there is more than one data_batch for this dataset 10')
+
+ def test_to_proto(self):
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(10)
+ dataset.auth_status = AuthStatus.AUTHORIZED
+ dataset.set_participants_info(participants_info)
+ dataset_proto = dataset.to_proto()
+ self.assertPartiallyEqual(
+ to_dict(dataset_proto),
+ {
+ 'id': 10,
+ 'project_id': 1,
+ 'name': 'default dataset',
+ 'path': '/data/dataset/123',
+ 'comment': 'test comment',
+ 'dataset_format': 'TABULAR',
+ 'state_frontend': 'FAILED',
+ 'dataset_kind': 'RAW',
+ 'workflow_id': 0,
+ 'data_source': '',
+ 'file_size': 0,
+ 'num_example': 0,
+ 'num_feature': 0,
+ 'deleted_at': 0,
+ 'parent_dataset_job_id': 0,
+ 'analyzer_dataset_job_id': 0,
+ 'is_published': False,
+ 'value': 0,
+ 'schema_checkers': [],
+ 'creator_username': '',
+ 'import_type': 'COPY',
+ 'dataset_type': 'PSI',
+ 'store_format': 'TFRECORDS',
+ 'publish_frontend_state': 'UNPUBLISHED',
+ 'auth_frontend_state': 'AUTH_PENDING',
+ 'local_auth_status': 'AUTHORIZED',
+ 'participants_info': {
+ 'participants_map': {
+ 'test_1': {
+ 'auth_status': 'PENDING',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': '',
+ },
+ 'test_2': {
+ 'auth_status': 'AUTHORIZED',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': '',
+ },
+ }
+ },
+ },
+ ignore_fields=['uuid', 'created_at', 'updated_at'],
+ )
+
+ def test_parent_dataset_job(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=10,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job)
+ micro_dataset_job = DatasetJob(id=2,
+ uuid='micro_dataset_job',
+ project_id=1,
+ input_dataset_id=10,
+ output_dataset_id=10,
+ kind=DatasetJobKind.ANALYZER,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(micro_dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.parent_dataset_job.id, 1)
+
+ def test_publish_frontend_state(self):
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.publish_frontend_state, PublishFrontendState.UNPUBLISHED)
+ dataset.is_published = True
+ dataset.ticket_status = TicketStatus.APPROVED
+ self.assertEqual(dataset.publish_frontend_state, PublishFrontendState.PUBLISHED)
+ dataset.ticket_status = TicketStatus.PENDING
+ self.assertEqual(dataset.publish_frontend_state, PublishFrontendState.TICKET_PENDING)
+ dataset.ticket_status = TicketStatus.DECLINED
+ self.assertEqual(dataset.publish_frontend_state, PublishFrontendState.TICKET_DECLINED)
+
+ def test_updated_at(self):
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1,
+ name='20220701',
+ dataset_id=10,
+ path='/data/test/batch/20220701',
+ event_time=datetime(2022, 7, 1),
+ file_size=100,
+ num_example=10,
+ num_feature=3)
+ session.add(data_batch)
+ session.commit()
+        # make sure the two batches have different updated_at times
+ time.sleep(1)
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=2,
+ name='20220702',
+ dataset_id=10,
+ path='/data/test/batch/20220702',
+ event_time=datetime(2022, 7, 2),
+ file_size=100,
+ num_example=10,
+ num_feature=3)
+ session.add(data_batch)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ data_batch = session.query(DataBatch).get(2)
+ self.assertEqual(dataset.to_proto().updated_at, to_timestamp(data_batch.updated_at))
+
+
+class DataSourceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ default_datasource = DataSource(id=10,
+ uuid=resource_uuid(),
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ is_published=False,
+ creator_username='xiaohang',
+ store_format=StoreFormat.CSV,
+ dataset_format=DatasetFormat.TABULAR.value)
+ default_datasource.set_meta_info(
+ meta=DatasetMetaInfo(datasource_type=DataSourceType.HDFS.value, is_user_upload=False))
+ session.add(default_datasource)
+ session.commit()
+
+ def test_to_data_source(self):
+ with db.session_scope() as session:
+ dataset = session.query(DataSource).get(10)
+ data_source = dataset_pb2.DataSource(id=dataset.id,
+ uuid=dataset.uuid,
+ name=dataset.name,
+ type=DataSourceType.HDFS.value,
+ url=dataset.path,
+ created_at=to_timestamp(dataset.created_at),
+ project_id=dataset.project_id,
+ is_user_upload=False,
+ creator_username='xiaohang',
+ dataset_format='TABULAR',
+ store_format='CSV',
+ dataset_type='PSI',
+ comment='test comment')
+ self.assertEqual(dataset.to_proto(), data_source)
+
+
+class InternalProcessedDatasetTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ default_dataset = Dataset(id=10,
+ uuid=resource_uuid(),
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ dataset_kind=DatasetKindV2.INTERNAL_PROCESSED,
+ is_published=False,
+ auth_status=AuthStatus.AUTHORIZED)
+ session.add(default_dataset)
+ session.commit()
+
+ def test_get_frontend_state(self):
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(10)
+ self.assertEqual(dataset.get_frontend_state(), ResourceState.SUCCEEDED)
+
+
+class DatasetJobTest(NoWebServerTestCase):
+
+ def test_get_set_global_configs(self):
+ dataset_job = DatasetJob()
+ global_configs = DatasetJobGlobalConfigs()
+ global_configs.global_configs['test'].MergeFrom(
+ DatasetJobConfig(dataset_uuid=resource_uuid(),
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ dataset_job.set_global_configs(global_configs)
+ new_global_configs = dataset_job.get_global_configs()
+ self.assertEqual(new_global_configs, global_configs)
+
+ @patch('fedlearner_webconsole.dataset.models.DatasetJob.output_dataset', new_callable=PropertyMock)
+ def test_to_proto(self, mock_output_dataset: MagicMock):
+ uuid = resource_uuid()
+ current_time = now()
+ dataset_job = DatasetJob(id=1,
+ name='test_dataset_job',
+ coordinator_id=2,
+ uuid=uuid,
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ created_at=current_time,
+ updated_at=current_time,
+ creator_username='test user',
+ scheduler_state=DatasetJobSchedulerState.PENDING)
+ global_configs = DatasetJobGlobalConfigs()
+ global_configs.global_configs['test'].MergeFrom(
+ DatasetJobConfig(dataset_uuid=resource_uuid(),
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ dataset_job.set_global_configs(global_configs)
+ dataset_job.set_scheduler_message(scheduler_message='调度信息 🐵')
+
+ mock_output_dataset.return_value = None
+ dataset_job_pb = dataset_pb2.DatasetJob(id=1,
+ name='test_dataset_job',
+ coordinator_id=2,
+ uuid=uuid,
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT.value,
+ state=DatasetJobState.PENDING.value,
+ global_configs=global_configs,
+ created_at=to_timestamp(current_time),
+ updated_at=to_timestamp(current_time),
+ creator_username='test user',
+ scheduler_state=DatasetJobSchedulerState.PENDING.name,
+ time_range=TimeRange(),
+ scheduler_message='调度信息 🐵')
+
+ self.assertEqual(dataset_job.to_proto(), dataset_job_pb)
+
+ dataset_job.workflow = Workflow(id=1, uuid='workflow_uuid')
+ mock_output_dataset.return_value = Dataset(id=1, name='test_dataset', uuid='dataset_uuid')
+ dataset_job.workflow_id = 1
+ dataset_job.output_dataset_id = 1
+ dataset_job_pb.is_ready = True
+ dataset_job_pb.workflow_id = 1
+ dataset_job_pb.result_dataset_name = 'test_dataset'
+ dataset_job_pb.result_dataset_uuid = 'dataset_uuid'
+ self.assertEqual(dataset_job.to_proto(), dataset_job_pb)
+
+ context = dataset_job.get_context()
+ context.input_data_batch_num_example = 1000
+ context.output_data_batch_num_example = 500
+ dataset_job.set_context(context)
+ dataset_job_pb.input_data_batch_num_example = 1000
+ dataset_job_pb.output_data_batch_num_example = 500
+ self.assertEqual(dataset_job.to_proto(), dataset_job_pb)
+
+ @patch('fedlearner_webconsole.dataset.models.DatasetJob.output_dataset', new_callable=PropertyMock)
+ def test_to_ref(self, mock_output_dataset: MagicMock):
+ uuid = resource_uuid()
+ output_dataset = Dataset(name='test_output_dataset', id=1)
+ dataset_job = DatasetJob(id=1,
+ name='test_dataset_job',
+ coordinator_id=2,
+ uuid=uuid,
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ output_dataset_id=1,
+ created_at=now(),
+ creator_username='test user')
+ mock_output_dataset.return_value = output_dataset
+ self.assertPartiallyEqual(to_dict(dataset_job.to_ref()), {
+ 'id': 1,
+ 'name': 'test_dataset_job',
+ 'coordinator_id': 2,
+ 'uuid': uuid,
+ 'project_id': 1,
+ 'kind': DatasetJobKind.DATA_ALIGNMENT.name,
+ 'state': DatasetJobState.PENDING.name,
+ 'result_dataset_id': 1,
+ 'result_dataset_name': 'test_output_dataset',
+ 'has_stages': False,
+ 'creator_username': 'test user',
+ },
+ ignore_fields=['created_at'])
+
+ def test_is_finished(self):
+ dataset_job = DatasetJob(uuid='uuid',
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ created_at=now())
+ self.assertFalse(dataset_job.is_finished())
+ dataset_job.state = DatasetJobState.RUNNING
+ self.assertFalse(dataset_job.is_finished())
+ dataset_job.state = DatasetJobState.SUCCEEDED
+ self.assertTrue(dataset_job.is_finished())
+ dataset_job.state = DatasetJobState.FAILED
+ self.assertTrue(dataset_job.is_finished())
+
+ def test_is_cron(self):
+ dataset_job = DatasetJob(id=1,
+ name='test_dataset_job',
+ coordinator_id=2,
+ uuid='uuid',
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING)
+ self.assertFalse(dataset_job.is_cron())
+ dataset_job.time_range = timedelta(days=1)
+ self.assertTrue(dataset_job.is_cron())
+
+ def test_is_daily_cron(self):
+ dataset_job = DatasetJob(id=1,
+ name='test_dataset_job',
+ coordinator_id=2,
+ uuid='uuid',
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING)
+ self.assertFalse(dataset_job.is_daily_cron())
+ dataset_job.time_range = timedelta(days=1)
+ self.assertTrue(dataset_job.is_daily_cron())
+ dataset_job.time_range = timedelta(hours=1)
+ self.assertFalse(dataset_job.is_daily_cron())
+
+ def test_is_hourly_cron(self):
+ dataset_job = DatasetJob(id=1,
+ name='test_dataset_job',
+ coordinator_id=2,
+ uuid='uuid',
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING)
+ self.assertFalse(dataset_job.is_hourly_cron())
+ dataset_job.time_range = timedelta(days=1)
+ self.assertFalse(dataset_job.is_hourly_cron())
+ dataset_job.time_range = timedelta(hours=1)
+ self.assertTrue(dataset_job.is_hourly_cron())
+
+ def test_set_scheduler_message(self):
+ scheduler_message = '调度信息 🦻'
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ name='test_dataset_job',
+ input_dataset_id=1,
+ output_dataset_id=2,
+ coordinator_id=0,
+ uuid='uuid',
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING)
+ dataset_job.set_scheduler_message(scheduler_message=scheduler_message)
+ session.add(dataset_job)
+ session.commit()
+
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.get_context().scheduler_message, scheduler_message)
+
+
+class DatasetJobStageTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(uuid='dataset_job uuid',
+ project_id=1,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ created_at=now())
+ session.add(dataset_job)
+ job_stage = DatasetJobStage(id=1,
+ uuid='uuid_1',
+ name='default dataset job stage',
+ project_id=1,
+ workflow_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2012, 1, 15),
+ state=DatasetJobState.PENDING,
+ coordinator_id=0)
+ session.add(job_stage)
+ session.commit()
+
+ def test_get_set_global_configs(self):
+ with db.session_scope() as session:
+ job_stage: DatasetJobStage = session.query(DatasetJobStage).get(1)
+ global_configs = DatasetJobGlobalConfigs()
+ global_configs.global_configs['test'].MergeFrom(
+ DatasetJobConfig(dataset_uuid=resource_uuid(),
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ job_stage.set_global_configs(global_configs)
+ new_global_configs = job_stage.get_global_configs()
+ self.assertEqual(new_global_configs, global_configs)
+
+ def test_to_ref(self):
+ with db.session_scope() as session:
+ job_stage: DatasetJobStage = session.query(DatasetJobStage).get(1)
+ job_stage_ref = job_stage.to_ref()
+ self.assertEqual(
+ to_dict(job_stage_ref), {
+ 'id': 1,
+ 'name': 'default dataset job stage',
+ 'dataset_job_id': 1,
+ 'output_data_batch_id': 1,
+ 'project_id': 1,
+ 'state': DatasetJobState.PENDING.name,
+ 'created_at': to_timestamp(datetime(2012, 1, 14, 12, 0, 5)),
+ 'kind': DatasetJobKind.DATA_ALIGNMENT.name,
+ })
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ job_stage: DatasetJobStage = session.query(DatasetJobStage).get(1)
+ context = DatasetJobStageContext(batch_stats_item_name='batch_stats_item_1',
+ input_data_batch_num_example=100,
+ output_data_batch_num_example=50,
+ scheduler_message='错误信息 ✖️')
+ job_stage.set_context(context=context)
+ job_stage_proto = job_stage.to_proto()
+ self.assertPartiallyEqual(
+ to_dict(job_stage_proto),
+ {
+ 'id': 1,
+ 'name': 'default dataset job stage',
+ 'uuid': 'uuid_1',
+ 'dataset_job_id': 1,
+ 'output_data_batch_id': 1,
+ 'workflow_id': 1,
+ 'project_id': 1,
+ 'state': DatasetJobState.PENDING.name,
+ 'event_time': to_timestamp(datetime(2012, 1, 15)),
+ 'created_at': to_timestamp(datetime(2012, 1, 14, 12, 0, 5)),
+ 'dataset_job_uuid': 'dataset_job uuid',
+ 'started_at': 0,
+ 'finished_at': 0,
+ 'is_ready': False,
+ 'kind': DatasetJobKind.DATA_ALIGNMENT.name,
+ 'input_data_batch_num_example': 100,
+ 'output_data_batch_num_example': 50,
+ 'scheduler_message': '错误信息 ✖️',
+ },
+ ignore_fields=['updated_at'],
+ )
+
+ def test_set_and_get_context(self):
+ with db.session_scope() as session:
+ job_stage: DatasetJobStage = session.query(DatasetJobStage).get(1)
+ empty_context = job_stage.get_context()
+ self.assertEqual(empty_context, DatasetJobStageContext())
+ context = DatasetJobStageContext(batch_stats_item_name='batch_stats_item_1',
+ input_data_batch_num_example=100,
+ output_data_batch_num_example=50)
+ job_stage.set_context(context=context)
+ text_context = 'batch_stats_item_name: "batch_stats_item_1"\n' \
+ 'input_data_batch_num_example: 100\noutput_data_batch_num_example: 50\n'
+ self.assertEqual(job_stage.context, text_context)
+ target_context = job_stage.get_context()
+ self.assertEqual(target_context, context)
+
+ def test_set_scheduler_message(self):
+ scheduler_message = '错误信息 ❌'
+ with db.session_scope() as session:
+ job_stage: DatasetJobStage = session.query(DatasetJobStage).get(1)
+ empty_context = job_stage.get_context()
+ self.assertEqual(empty_context, DatasetJobStageContext())
+ job_stage.set_scheduler_message(scheduler_message=scheduler_message)
+ session.commit()
+
+ with db.session_scope() as session:
+ job_stage: DatasetJobStage = session.query(DatasetJobStage).get(1)
+ self.assertEqual(job_stage.get_context().scheduler_message, scheduler_message)
+
+ def test_is_coordinator(self):
+ with db.session_scope() as session:
+ job_stage: DatasetJobStage = session.query(DatasetJobStage).get(1)
+ self.assertTrue(job_stage.is_coordinator())
+ job_stage.coordinator_id = 1
+ self.assertFalse(job_stage.is_coordinator())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/BUILD.bazel
new file mode 100644
index 000000000..672d3773a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/BUILD.bazel
@@ -0,0 +1,250 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "scheduler",
+ srcs = [
+ "dataset_long_period_scheduler.py",
+ "dataset_short_period_scheduler.py",
+ ],
+ imports = ["../../.."],
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:local_controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "dataset_long_period_scheduler_test",
+ size = "small",
+ srcs = [
+ "dataset_long_period_scheduler_test.py",
+ ],
+ imports = ["../../.."],
+ main = "dataset_long_period_scheduler_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ ":scheduler",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_test(
+ name = "dataset_short_period_scheduler_test",
+ size = "small",
+ srcs = [
+ "dataset_short_period_scheduler_test.py",
+ ],
+ imports = ["../../.."],
+ main = "dataset_short_period_scheduler_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ ":scheduler",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "base_executor_lib",
+ srcs = [
+ "base_executor.py",
+ ],
+ imports = ["../../.."],
+ deps = [
+ ":consts_lib",
+ ],
+)
+
+py_library(
+ name = "executors_lib",
+ srcs = [
+ "chained_executor.py",
+ "cron_dataset_job_executor.py",
+ "dataset_job_executor.py",
+ "pending_dataset_job_stage_executor.py",
+ "running_dataset_job_stage_executor.py",
+ "update_auth_status_executor.py",
+ ],
+ imports = ["../../.."],
+ deps = [
+ ":base_executor_lib",
+ ":consts_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/cleanup:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/cleanup:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:local_controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:resource_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "chained_executor_test",
+ size = "small",
+ srcs = [
+ "chained_executor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "chained_executor_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ ],
+)
+
+py_test(
+ name = "cron_dataset_job_executor_test",
+ size = "small",
+ srcs = [
+ "cron_dataset_job_executor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "cron_dataset_job_executor_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_test(
+ name = "update_auth_status_executor_test",
+ size = "small",
+ srcs = [
+ "update_auth_status_executor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "update_auth_status_executor_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_test(
+ name = "dataset_job_executor_test",
+ size = "medium",
+ srcs = [
+ "dataset_job_executor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "dataset_job_executor_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_test(
+ name = "pending_dataset_job_stage_executor_test",
+ size = "small",
+ srcs = [
+ "pending_dataset_job_stage_executor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "pending_dataset_job_stage_executor_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_test(
+ name = "running_dataset_job_stage_executor_test",
+ size = "medium",
+ srcs = [
+ "running_dataset_job_stage_executor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "running_dataset_job_stage_executor_test.py",
+ deps = [
+ ":consts_lib",
+ ":executors_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "consts_lib",
+ srcs = [
+ "consts.py",
+ ],
+ imports = ["../../.."],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/base_executor.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/base_executor.py
new file mode 100644
index 000000000..7f5f3c9ad
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/base_executor.py
@@ -0,0 +1,40 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+from typing import List
+
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+
+
+class BaseExecutor(metaclass=abc.ABCMeta):
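+    """Base class of dataset scheduler executors.
+
+    Subclasses list the ids of the items they should handle and process them one by one.
+    """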
+
+ @abc.abstractmethod
+ def get_item_ids(self) -> List[int]:
+ """Get all items id should be processed in this executor
+
+ Returns:
+ List[int]: all items id
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def run_item(self, item_id: int) -> ExecutorResult:
+ """process item by given id
+
+ Returns:
+ ExecutorResult
+ """
+ raise NotImplementedError()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/chained_executor.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/chained_executor.py
new file mode 100644
index 000000000..56335b3e7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/chained_executor.py
@@ -0,0 +1,57 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorType, ExecutorResult
+from fedlearner_webconsole.dataset.scheduler.base_executor import BaseExecutor
+from fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor import CronDatasetJobExecutor
+from fedlearner_webconsole.dataset.scheduler.update_auth_status_executor import UpdateAuthStatusExecutor
+from fedlearner_webconsole.dataset.scheduler.dataset_job_executor import DatasetJobExecutor
+from fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor import PendingDatasetJobStageExecutor
+from fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor import RunningDatasetJobStageExecutor
+from fedlearner_webconsole.proto.composer_pb2 import ExecutorResults
+
+
+def _get_executor(executor_type: ExecutorType) -> BaseExecutor:
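+    # Resolve the ExecutorType to its concrete executor class; a new instance is created on every call.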
+ executor_map = {
+ ExecutorType.CRON_DATASET_JOB: CronDatasetJobExecutor,
+ ExecutorType.UPDATE_AUTH_STATUS: UpdateAuthStatusExecutor,
+ ExecutorType.DATASET_JOB: DatasetJobExecutor,
+ ExecutorType.PENDING_DATASET_JOB_STAGE: PendingDatasetJobStageExecutor,
+ ExecutorType.RUNNING_DATASET_JOB_STAGE: RunningDatasetJobStageExecutor,
+ }
+    return executor_map[executor_type]()
+
+
+def run_executor(executor_type: ExecutorType) -> ExecutorResults:
+ executor = _get_executor(executor_type=executor_type)
+ item_ids = executor.get_item_ids()
+ succeeded_items = []
+ failed_items = []
+ skip_items = []
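+    # Items are processed independently: an exception raised by one item is logged and counted as a failure
+    # without aborting the remaining items.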
+ for item_id in item_ids:
+ try:
+ executor_result = executor.run_item(item_id=item_id)
+ if executor_result == ExecutorResult.SUCCEEDED:
+ succeeded_items.append(item_id)
+ elif executor_result == ExecutorResult.FAILED:
+ failed_items.append(item_id)
+ else:
+ skip_items.append(item_id)
+        except Exception:  # pylint: disable=broad-except
+ logging.exception(f'[Dataset ChainedExecutor] failed to run {item_id}, executor_type: {executor_type.name}')
+ failed_items.append(item_id)
+ return ExecutorResults(succeeded_item_ids=succeeded_items, failed_item_ids=failed_items, skip_item_ids=skip_items)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/chained_executor_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/chained_executor_test.py
new file mode 100644
index 000000000..1dfd5ff90
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/chained_executor_test.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch
+from fedlearner_webconsole.dataset.scheduler.chained_executor import run_executor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorType
+from testing.dataset import FakeExecutor
+
+
+class ChainedExecutorTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.dataset.scheduler.chained_executor._get_executor')
+ def test_run_executor(self, mock_get_executor: MagicMock):
+ mock_get_executor.return_value = FakeExecutor()
+        executor_results = run_executor(executor_type=ExecutorType.CRON_DATASET_JOB)
+        self.assertEqual(executor_results.succeeded_item_ids, [1])
+        self.assertEqual(executor_results.failed_item_ids, [2, 4])
+        self.assertEqual(executor_results.skip_item_ids, [3])
+ mock_get_executor.assert_called_once_with(executor_type=ExecutorType.CRON_DATASET_JOB)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/consts.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/consts.py
new file mode 100644
index 000000000..278f93852
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/consts.py
@@ -0,0 +1,30 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+
+
+class ExecutorType(enum.Enum):
+ CRON_DATASET_JOB = 'CRON_DATASET_JOB'
+ UPDATE_AUTH_STATUS = 'UPDATE_AUTH_STATUS'
+ PENDING_DATASET_JOB_STAGE = 'PENDING_DATASET_JOB_STAGE'
+ RUNNING_DATASET_JOB_STAGE = 'RUNNING_DATASET_JOB_STAGE'
+ DATASET_JOB = 'DATASET_JOB'
+
+
+class ExecutorResult(enum.Enum):
+ SUCCEEDED = 'SUCCEEDED'
+ FAILED = 'FAILED'
+ SKIP = 'SKIP'
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/cron_dataset_job_executor.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/cron_dataset_job_executor.py
new file mode 100644
index 000000000..9f941dc46
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/cron_dataset_job_executor.py
@@ -0,0 +1,140 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional
+from datetime import datetime, timezone
+import logging
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.scheduler.base_executor import BaseExecutor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, DatasetJob, DatasetJobSchedulerState, \
+ DatasetKindV2
+from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController
+from fedlearner_webconsole.dataset.util import get_oldest_daily_folder_time, parse_event_time_to_daily_folder_name, \
+ check_batch_folder_ready, get_oldest_hourly_folder_time, parse_event_time_to_hourly_folder_name, \
+ get_hourly_folder_not_ready_err_msg, get_daily_folder_not_ready_err_msg, get_hourly_batch_not_ready_err_msg, \
+ get_daily_batch_not_ready_err_msg, get_certain_folder_not_ready_err_msg, get_certain_batch_not_ready_err_msg, \
+ get_cron_succeeded_msg
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.utils.pp_datetime import now
+
+
+class CronDatasetJobExecutor(BaseExecutor):
+
+ def _get_next_event_time(self, session: Session, runnable_dataset_job: DatasetJob) -> Optional[datetime]:
+ """get next event_time.
+
+ 1. If current event_time is None, next event_time is oldest event_time for input_dataset
+ 2. If current event_time is not None, next event_time is event_time + time_range
+ """
+ if runnable_dataset_job.event_time is None:
+ input_dataset: Dataset = runnable_dataset_job.input_dataset
+ if input_dataset.dataset_kind == DatasetKindV2.SOURCE:
+ if runnable_dataset_job.is_hourly_cron():
+ return get_oldest_hourly_folder_time(input_dataset.path)
+ return get_oldest_daily_folder_time(input_dataset.path)
+ oldest_data_batch = session.query(DataBatch).filter(
+ DataBatch.dataset_id == runnable_dataset_job.input_dataset_id).order_by(
+ DataBatch.event_time.asc()).first()
+ return oldest_data_batch.event_time if oldest_data_batch else None
+ return runnable_dataset_job.event_time + runnable_dataset_job.time_range
+
+ def _should_run(self, session: Session, runnable_dataset_job: DatasetJob, next_event_time: datetime) -> bool:
+ """check dependence to decide whether should create next stage.
+
+ 1. for input_dataset is data_source, check folder exists and _SUCCESS file exists
+ 2. for input_dataset is dataset, check data_batch exists and state is SUCCEEDED
+ """
+ input_dataset: Dataset = runnable_dataset_job.input_dataset
+ if input_dataset.dataset_kind == DatasetKindV2.SOURCE:
+ if runnable_dataset_job.is_hourly_cron():
+ batch_name = parse_event_time_to_hourly_folder_name(next_event_time)
+ else:
+ batch_name = parse_event_time_to_daily_folder_name(next_event_time)
+ return check_batch_folder_ready(folder=input_dataset.path, batch_name=batch_name)
+ data_batch: DataBatch = session.query(DataBatch).filter(
+ DataBatch.dataset_id == runnable_dataset_job.input_dataset_id).filter(
+ DataBatch.event_time == next_event_time).first()
+ if data_batch is None:
+ return False
+ return data_batch.is_available()
+
+ def get_item_ids(self) -> List[int]:
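+        # only cron dataset_jobs whose scheduler_state is RUNNABLE are picked up by this executor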
+ with db.session_scope() as session:
+ runnable_dataset_job_ids = session.query(
+ DatasetJob.id).filter(DatasetJob.scheduler_state == DatasetJobSchedulerState.RUNNABLE).all()
+ return [runnable_dataset_job_id for runnable_dataset_job_id, *_ in runnable_dataset_job_ids]
+
+ def run_item(self, item_id: int) -> ExecutorResult:
+ with db.session_scope() as session:
+ # we set isolation_level to SERIALIZABLE to make sure state won't be changed within this session
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ runnable_dataset_job: DatasetJob = session.query(DatasetJob).get(item_id)
+ if runnable_dataset_job.scheduler_state != DatasetJobSchedulerState.RUNNABLE:
+ logging.warning('dataset_job scheduler_state is not runnable, ' \
+ f'dataset_job id: {item_id}')
+ return ExecutorResult.SKIP
+ # check authorization
+ if not AuthService(session=session, dataset_job=runnable_dataset_job).check_participants_authorized():
+                message = '[cron_dataset_job_executor] still waiting for participants to be authorized, ' \
+ f'dataset_job_id: {item_id}'
+ logging.warning(message)
+ return ExecutorResult.SKIP
+
+ next_event_time = self._get_next_event_time(session=session, runnable_dataset_job=runnable_dataset_job)
+ if next_event_time is None:
+ if runnable_dataset_job.input_dataset.dataset_kind == DatasetKindV2.SOURCE:
+ logging.warning(f'input_dataset has no matched streaming folder, dataset_job id: {item_id}')
+ err_msg = get_hourly_folder_not_ready_err_msg() if runnable_dataset_job.is_hourly_cron() \
+ else get_daily_folder_not_ready_err_msg()
+ runnable_dataset_job.set_scheduler_message(scheduler_message=err_msg)
+ else:
+ logging.warning(f'input_dataset has no matched batch, dataset_job id: {item_id}')
+ err_msg = get_hourly_batch_not_ready_err_msg() if runnable_dataset_job.is_hourly_cron() \
+ else get_daily_batch_not_ready_err_msg()
+ runnable_dataset_job.set_scheduler_message(scheduler_message=err_msg)
+ session.commit()
+ return ExecutorResult.SKIP
+            # if next_event_time is 20220801, we won't schedule it until 2022-08-01 00:00:00
+ if next_event_time.replace(tzinfo=timezone.utc) > now():
+ return ExecutorResult.SKIP
+ next_batch_name = parse_event_time_to_hourly_folder_name(event_time=next_event_time) \
+ if runnable_dataset_job.is_hourly_cron() \
+ else parse_event_time_to_daily_folder_name(event_time=next_event_time)
+ if not self._should_run(
+ session=session, runnable_dataset_job=runnable_dataset_job, next_event_time=next_event_time):
+ if runnable_dataset_job.input_dataset.dataset_kind == DatasetKindV2.SOURCE:
+ runnable_dataset_job.set_scheduler_message(scheduler_message=get_certain_folder_not_ready_err_msg(
+ folder_name=next_batch_name))
+ else:
+ runnable_dataset_job.set_scheduler_message(scheduler_message=get_certain_batch_not_ready_err_msg(
+ batch_name=next_batch_name))
+ logging.info(
+                f'[cron_dataset_job_executor] dataset job {item_id} should not run, ' \
+ f'next_event_time: {next_event_time.strftime("%Y%m%d")}'
+ )
+ session.commit()
+ return ExecutorResult.SKIP
+ DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage_as_coordinator(
+ dataset_job_id=item_id,
+ global_configs=runnable_dataset_job.get_global_configs(),
+ event_time=next_event_time)
+ runnable_dataset_job.event_time = next_event_time
+ runnable_dataset_job.set_scheduler_message(scheduler_message=get_cron_succeeded_msg(
+ batch_name=next_batch_name))
+ session.commit()
+ return ExecutorResult.SUCCEEDED
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/cron_dataset_job_executor_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/cron_dataset_job_executor_test.py
new file mode 100644
index 000000000..c95ca9f1e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/cron_dataset_job_executor_test.py
@@ -0,0 +1,365 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=protected-access
+import unittest
+from unittest.mock import patch, MagicMock
+from datetime import datetime, timedelta
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import DataBatch, ResourceState, Dataset, DatasetJob, \
+ DatasetJobKind, DatasetJobSchedulerState, DatasetJobStage, DatasetJobState, DatasetKindV2, DatasetType
+from fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor import CronDatasetJobExecutor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.proto import dataset_pb2
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class CronDatasetJobExecutorTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _WORKFLOW_ID = 1
+ _INPUT_DATASET_ID = 1
+ _OUTPUT_DATASET_ID = 2
+
+ def test_get_item_ids(self):
+ with db.session_scope() as session:
+ dataset_job_1 = DatasetJob(id=1,
+ uuid='dataset_job_1',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE,
+ time_range=timedelta(days=1))
+ dataset_job_1.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_1)
+ dataset_job_2 = DatasetJob(id=2,
+ uuid='dataset_job_2',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=3,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.STOPPED,
+ time_range=timedelta(days=1))
+ dataset_job_2.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_2)
+ dataset_job_3 = DatasetJob(id=3,
+ uuid='dataset_job_3',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=4,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.PENDING,
+ time_range=timedelta(days=1))
+ dataset_job_3.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_3)
+ session.commit()
+ cron_dataset_job_executor = CronDatasetJobExecutor()
+ self.assertEqual(cron_dataset_job_executor.get_item_ids(), [1])
+
+ @patch(
+ 'fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor.CronDatasetJobExecutor.'\
+ '_get_next_event_time'
+ )
+ @patch('fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor.CronDatasetJobExecutor._should_run')
+ def test_run_item(self, mock_should_run: MagicMock, mock_get_next_event_time: MagicMock):
+ with db.session_scope() as session:
+ dataset_job_1 = DatasetJob(id=1,
+ uuid='dataset_job_1',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE,
+ time_range=timedelta(days=1))
+ dataset_job_1.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_1)
+ dataset_job_2 = DatasetJob(id=2,
+ uuid='dataset_job_2',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=3,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.STOPPED,
+ time_range=timedelta(days=1))
+ dataset_job_2.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_2)
+ input_dataset = Dataset(id=self._INPUT_DATASET_ID,
+ uuid='dataset_uuid',
+ name='default dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(input_dataset)
+ output_dataset = Dataset(id=self._OUTPUT_DATASET_ID,
+ uuid='dataset_uuid',
+ name='default dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(output_dataset)
+ session.commit()
+
+ cron_dataset_job_executor = CronDatasetJobExecutor()
+
+        # test next_event_time greater than now
+ mock_should_run.return_value = True
+ mock_get_next_event_time.return_value = datetime(2100, 1, 1)
+ executor_result = cron_dataset_job_executor.run_item(item_id=1)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ mock_get_next_event_time.assert_called_once()
+
+ # test should_run is false
+ mock_should_run.reset_mock()
+ mock_get_next_event_time.reset_mock()
+ mock_should_run.return_value = False
+ mock_get_next_event_time.return_value = datetime(2022, 1, 1)
+ executor_result = cron_dataset_job_executor.run_item(item_id=1)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ mock_get_next_event_time.assert_called_once()
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
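+            # expected scheduler_message (translated from Chinese):
+            # 'data batch 20220101 check failed; please verify the batch naming format and state'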
+ self.assertEqual(dataset_job.get_context().scheduler_message, '数据批次20220101检查失败,请确认该批次命名格式及状态')
+
+ # test should_run is True
+ mock_should_run.reset_mock()
+ mock_get_next_event_time.reset_mock()
+ mock_should_run.return_value = True
+ mock_get_next_event_time.return_value = datetime(2022, 1, 1)
+ executor_result = cron_dataset_job_executor.run_item(item_id=1)
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ mock_get_next_event_time.assert_called_once()
+
+ with db.session_scope() as session:
+ data_batch = session.query(DataBatch).get(1)
+ self.assertEqual(data_batch.event_time, datetime(2022, 1, 1))
+ dataset_job_stage = session.query(DatasetJobStage).get(1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(2022, 1, 1))
+ dataset_job = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.event_time, datetime(2022, 1, 1))
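+            # expected scheduler_message (translated from Chinese):
+            # 'the processing job for batch 20220101 has been started successfully'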
+ self.assertEqual(dataset_job.get_context().scheduler_message, '已成功发起20220101批次处理任务')
+
+ @patch('fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor.get_oldest_hourly_folder_time')
+ @patch('fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor.get_oldest_daily_folder_time')
+ def test_get_next_event_time(self, mock_get_oldest_daily_folder_time: MagicMock,
+ mock_get_oldest_hourly_folder_time: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job_1',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE,
+ time_range=timedelta(days=1))
+ session.add(dataset_job)
+ input_dataset = Dataset(id=self._INPUT_DATASET_ID,
+ uuid='dataset_uuid',
+ name='default dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.SOURCE,
+ is_published=True)
+ session.add(input_dataset)
+ session.commit()
+ cron_dataset_job_executor = CronDatasetJobExecutor()
+ with db.session_scope() as session:
+ # test input_dataset is source
+ mock_get_oldest_daily_folder_time.return_value = datetime(2022, 8, 1)
+ dataset_job = session.query(DatasetJob).get(1)
+ next_event_time = cron_dataset_job_executor._get_next_event_time(session=session,
+ runnable_dataset_job=dataset_job)
+ mock_get_oldest_daily_folder_time.assert_called_once_with('/data/dataset/123')
+ self.assertEqual(next_event_time, datetime(2022, 8, 1))
+
+ # test input_dataset is not source but has no batch
+ mock_get_oldest_daily_folder_time.reset_mock()
+ dataset_job.input_dataset.dataset_kind = DatasetKindV2.RAW
+ next_event_time = cron_dataset_job_executor._get_next_event_time(session=session,
+ runnable_dataset_job=dataset_job)
+ mock_get_oldest_daily_folder_time.assert_not_called()
+ self.assertIsNone(next_event_time)
+
+ # test input_dataset is not source and has batch
+ data_batch_1 = DataBatch(id=1,
+ name='20220801',
+ dataset_id=self._INPUT_DATASET_ID,
+ path='/data/test/batch/20220801',
+ event_time=datetime(2022, 8, 1))
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(id=2,
+ name='20220802',
+ dataset_id=self._INPUT_DATASET_ID,
+ path='/data/test/batch/20220802',
+ event_time=datetime(2022, 8, 2))
+ session.add(data_batch_2)
+ session.flush()
+ next_event_time = cron_dataset_job_executor._get_next_event_time(session=session,
+ runnable_dataset_job=dataset_job)
+ self.assertEqual(next_event_time, datetime(2022, 8, 1))
+
+ # test dataset_job already has event_time
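+            # with a daily time_range, the next event time should advance one day past the current event_time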
+ dataset_job.event_time = datetime(2022, 8, 1)
+ next_event_time = cron_dataset_job_executor._get_next_event_time(session=session,
+ runnable_dataset_job=dataset_job)
+ self.assertEqual(next_event_time, datetime(2022, 8, 2))
+
+ # test hourly level
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.time_range = timedelta(hours=1)
+ session.commit()
+ with db.session_scope() as session:
+ # test input_dataset is source
+ mock_get_oldest_hourly_folder_time.return_value = datetime(2022, 8, 1, 8)
+ dataset_job = session.query(DatasetJob).get(1)
+ next_event_time = cron_dataset_job_executor._get_next_event_time(session=session,
+ runnable_dataset_job=dataset_job)
+ mock_get_oldest_hourly_folder_time.assert_called_once_with('/data/dataset/123')
+ self.assertEqual(next_event_time, datetime(2022, 8, 1, 8))
+
+ # test input_dataset is not source and has batch
+ data_batch_1 = DataBatch(id=1,
+ name='20220801',
+ dataset_id=self._INPUT_DATASET_ID,
+ path='/data/test/batch/20220801',
+ event_time=datetime(2022, 8, 1, 8))
+ session.add(data_batch_1)
+ data_batch_2 = DataBatch(id=2,
+ name='20220802',
+ dataset_id=self._INPUT_DATASET_ID,
+ path='/data/test/batch/20220802',
+ event_time=datetime(2022, 8, 1, 9))
+ session.add(data_batch_2)
+ session.flush()
+ next_event_time = cron_dataset_job_executor._get_next_event_time(session=session,
+ runnable_dataset_job=dataset_job)
+ self.assertEqual(next_event_time, datetime(2022, 8, 1, 8))
+
+ # test dataset_job already has event_time
+ dataset_job.event_time = datetime(2022, 8, 1, 8)
+ next_event_time = cron_dataset_job_executor._get_next_event_time(session=session,
+ runnable_dataset_job=dataset_job)
+ self.assertEqual(next_event_time, datetime(2022, 8, 1, 9))
+
+ @patch('fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor.check_batch_folder_ready')
+ @patch('fedlearner_webconsole.dataset.models.DataBatch.get_frontend_state')
+ def test_should_run(self, mock_get_frontend_state: MagicMock, mock_check_batch_folder_ready: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job uuid',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=0,
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE,
+ time_range=timedelta(days=1))
+ dataset_job.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job)
+ input_dataset = Dataset(id=self._INPUT_DATASET_ID,
+ uuid='dataset_uuid',
+ name='default dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(input_dataset)
+ data_batch = DataBatch(id=1,
+ name='20220701',
+ dataset_id=self._INPUT_DATASET_ID,
+ path='/data/test/batch/20220701',
+ event_time=datetime.strptime('20220701', '%Y%m%d'),
+ file_size=100,
+ num_example=10,
+ num_feature=3,
+ latest_parent_dataset_job_stage_id=1)
+ session.add(data_batch)
+ session.commit()
+
+ cron_dataset_job_executor = CronDatasetJobExecutor()
+
+        # test input_dataset of kind RAW: readiness is checked via the data_batch frontend state
+ with db.session_scope() as session:
+            # test no data_batch for the given event time
+ dataset_job = session.query(DatasetJob).get(1)
+ mock_get_frontend_state.return_value = ResourceState.SUCCEEDED
+ self.assertFalse(cron_dataset_job_executor._should_run(session, dataset_job, datetime(2022, 7, 2)))
+ mock_get_frontend_state.reset_mock()
+ # test data_batch frontend state not succeeded
+ mock_get_frontend_state.return_value = ResourceState.FAILED
+ self.assertFalse(cron_dataset_job_executor._should_run(session, dataset_job, datetime(2022, 7, 1)))
+ mock_get_frontend_state.reset_mock()
+ # test should run
+ mock_get_frontend_state.return_value = ResourceState.SUCCEEDED
+ self.assertTrue(cron_dataset_job_executor._should_run(session, dataset_job, datetime(2022, 7, 1)))
+ mock_check_batch_folder_ready.assert_not_called()
+ mock_get_frontend_state.reset_mock()
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.input_dataset.dataset_kind = DatasetKindV2.SOURCE
+ session.flush()
+ # test streaming_folder not ready
+ mock_check_batch_folder_ready.return_value = False
+ self.assertFalse(cron_dataset_job_executor._should_run(session, dataset_job, datetime(2022, 7, 1)))
+ mock_check_batch_folder_ready.assert_called_once_with(folder='/data/dataset/123', batch_name='20220701')
+ mock_check_batch_folder_ready.reset_mock()
+ # test should run
+ mock_check_batch_folder_ready.return_value = True
+ self.assertTrue(cron_dataset_job_executor._should_run(session, dataset_job, datetime(2022, 7, 1)))
+ mock_check_batch_folder_ready.assert_called_once_with(folder='/data/dataset/123', batch_name='20220701')
+ mock_get_frontend_state.assert_not_called()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_job_executor.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_job_executor.py
new file mode 100644
index 000000000..f568a4da2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_job_executor.py
@@ -0,0 +1,102 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import List
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.scheduler.base_executor import BaseExecutor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobSchedulerState
+from fedlearner_webconsole.dataset.services import DatasetJobService
+from fedlearner_webconsole.dataset.job_configer.dataset_job_configer import DatasetJobConfiger
+from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.flag.models import Flag
+
+
+class DatasetJobExecutor(BaseExecutor):
+
+ def get_item_ids(self) -> List[int]:
+ with db.session_scope() as session:
+ pending_dataset_job_ids = session.query(
+ DatasetJob.id).filter(DatasetJob.scheduler_state == DatasetJobSchedulerState.PENDING).all()
+ return [pending_dataset_job_id for pending_dataset_job_id, *_ in pending_dataset_job_ids]
+
+ def run_item(self, item_id: int) -> ExecutorResult:
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(item_id)
+ # TODO(liuhehan): remove this func after remove DATASET_JOB_STAGE_ENABLED flag
+ if not dataset_job.get_context().has_stages:
+ dataset_job.scheduler_state = DatasetJobSchedulerState.STOPPED
+ session.commit()
+ return ExecutorResult.SKIP
+ if dataset_job.scheduler_state != DatasetJobSchedulerState.PENDING:
+ return ExecutorResult.SKIP
+ if dataset_job.output_dataset.ticket_status != TicketStatus.APPROVED:
+ return ExecutorResult.SKIP
+ if dataset_job.is_coordinator():
+ # create participant dataset_job
+ participants = DatasetJobService(session).get_participants_need_distribute(dataset_job)
+ for participant in participants:
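+                    # send the dataset_job (with its workflow definition) and output dataset info to this participant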
+ client = RpcClient.from_project_and_participant(dataset_job.project.name, dataset_job.project.token,
+ participant.domain_name)
+ dataset_job_parameter = dataset_job.to_proto()
+ dataset_job_parameter.workflow_definition.MergeFrom(
+ DatasetJobConfiger.from_kind(dataset_job.kind, session).get_config())
+ dataset_parameter = dataset_pb2.Dataset(
+ participants_info=dataset_job.output_dataset.get_participants_info())
+ client.create_dataset_job(dataset_job=dataset_job_parameter,
+ ticket_uuid=dataset_job.output_dataset.ticket_uuid,
+ dataset=dataset_parameter)
+                # check flags; if the participant does not check auth status, mark it as authorized directly
+ system_client = SystemServiceClient.from_participant(domain_name=participant.domain_name)
+ flag_resp = system_client.list_flags()
+ if not flag_resp.get(Flag.DATASET_AUTH_STATUS_CHECK_ENABLED.name):
+ AuthService(session=session, dataset_job=dataset_job).update_auth_status(
+ domain_name=participant.pure_domain_name(), auth_status=AuthStatus.AUTHORIZED)
+ else:
+                # a participant's scheduler state is always set to stopped;
+                # the participant never creates the data_batch and dataset_job_stage itself
+ dataset_job.scheduler_state = DatasetJobSchedulerState.STOPPED
+ session.commit()
+ with db.session_scope() as session:
+            # use SERIALIZABLE isolation so the state cannot be changed concurrently within this session
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ dataset_job: DatasetJob = session.query(DatasetJob).get(item_id)
+ if dataset_job.scheduler_state != DatasetJobSchedulerState.PENDING:
+ return ExecutorResult.SKIP
+ if dataset_job.is_cron():
+ # if dataset_job is cron, we set scheduler state to runnable,
+                # and it will be scheduled again by cron_dataset_job_executor
+ dataset_job.scheduler_state = DatasetJobSchedulerState.RUNNABLE
+ else:
+ if dataset_job.get_context().need_create_stage:
+ # check authorization
+ if not AuthService(session=session, dataset_job=dataset_job).check_participants_authorized():
+ message = '[dataset_job_executor] still waiting for participants authorized, ' \
+ f'dataset_job_id: {item_id}'
+ logging.warning(message)
+ return ExecutorResult.SKIP
+ DatasetJobStageLocalController(session).create_data_batch_and_job_stage_as_coordinator(
+ dataset_job_id=dataset_job.id, global_configs=dataset_job.get_global_configs())
+ dataset_job.scheduler_state = DatasetJobSchedulerState.STOPPED
+ session.commit()
+ return ExecutorResult.SUCCEEDED
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_job_executor_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_job_executor_test.py
new file mode 100644
index 000000000..6c0d91ed8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_job_executor_test.py
@@ -0,0 +1,264 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import PropertyMock, patch, MagicMock
+from datetime import timedelta
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetJobSchedulerState, \
+ DatasetJobStage, DatasetJobState, DatasetKindV2, DatasetType
+from fedlearner_webconsole.dataset.scheduler.dataset_job_executor import DatasetJobExecutor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.job_configer.ot_psi_data_join_configer import OtPsiDataJoinConfiger
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.flag.models import _Flag
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.proto import dataset_pb2, project_pb2
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class DatasetJobExecutorTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _PARTICIPANT_ID = 1
+ _WORKFLOW_ID = 1
+ _INPUT_DATASET_ID = 1
+ _OUTPUT_DATASET_ID = 2
+ _DATASET_JOB_ID = 1
+
+ def test_get_item_ids(self):
+ with db.session_scope() as session:
+ dataset_job_1 = DatasetJob(id=1,
+ uuid='dataset_job_1',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE,
+ time_range=timedelta(days=1))
+ dataset_job_1.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_1)
+ dataset_job_2 = DatasetJob(id=2,
+ uuid='dataset_job_2',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=3,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.STOPPED,
+ time_range=timedelta(days=1))
+ dataset_job_2.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_2)
+ dataset_job_3 = DatasetJob(id=3,
+ uuid='dataset_job_3',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=4,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.PENDING,
+ time_range=timedelta(days=1))
+ dataset_job_3.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job_3)
+ session.commit()
+ executor = DatasetJobExecutor()
+ self.assertEqual(executor.get_item_ids(), [3])
+
+ @patch('fedlearner_webconsole.flag.models.Flag.DATASET_AUTH_STATUS_CHECK_ENABLED', new_callable=PropertyMock)
+ @patch('fedlearner_webconsole.dataset.scheduler.dataset_job_executor.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.dataset.local_controllers.DatasetJobStageLocalController.'\
+ 'create_data_batch_and_job_stage_as_coordinator')
+ @patch('fedlearner_webconsole.dataset.scheduler.dataset_job_executor.RpcClient.create_dataset_job')
+ def test_run_item(self, mock_create_dataset_job: MagicMock,
+ mock_create_data_batch_and_job_stage_as_coordinator: MagicMock, mock_list_flags: MagicMock,
+ mock_dataset_auth_status_check_enabled: MagicMock):
+ with db.session_scope() as session:
+ # pylint: disable=protected-access
+ _insert_or_update_templates(session)
+ dataset_job_1 = DatasetJob(id=self._PARTICIPANT_ID,
+ uuid='dataset_job_1',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=self._WORKFLOW_ID,
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE,
+ time_range=timedelta(days=1))
+ dataset_job_1.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={'test_participant_2': dataset_pb2.DatasetJobConfig()}))
+ dataset_job_1.set_context(dataset_pb2.DatasetJobContext(has_stages=True))
+ session.add(dataset_job_1)
+ input_dataset = Dataset(id=self._INPUT_DATASET_ID,
+ uuid='dataset_uuid',
+ name='default dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(input_dataset)
+ output_dataset = Dataset(id=self._OUTPUT_DATASET_ID,
+ uuid='dataset_uuid',
+ name='default dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True,
+ ticket_status=TicketStatus.APPROVED)
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_participant_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_participant_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ output_dataset.set_participants_info(participants_info=participants_info)
+ session.add(output_dataset)
+ project = Project(id=self._PROJECT_ID, name='test-project')
+ session.add(project)
+ participant = Participant(id=self._PARTICIPANT_ID,
+ name='participant_1',
+ domain_name='fl-test_participant_2.com')
+ project_participant = ProjectParticipant(project_id=self._PROJECT_ID, participant_id=self._PARTICIPANT_ID)
+ session.add_all([participant, project_participant])
+ session.commit()
+
+ mock_list_flags.return_value = {'dataset_auth_status_check_enabled': True}
+ mock_dataset_auth_status_check_enabled.return_value = _Flag('dataset_auth_status_check_enabled', False)
+ executor = DatasetJobExecutor()
+ # test not pending
+ executor_result = executor.run_item(self._DATASET_JOB_ID)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+
+ # test not approved
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.scheduler_state = DatasetJobSchedulerState.PENDING
+ dataset_job.output_dataset.ticket_status = TicketStatus.PENDING
+ session.commit()
+ executor_result = executor.run_item(self._DATASET_JOB_ID)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+
+ # test not coordinator
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.coordinator_id = 1
+ dataset_job.output_dataset.ticket_status = TicketStatus.APPROVED
+ session.commit()
+ executor_result = executor.run_item(self._DATASET_JOB_ID)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.STOPPED)
+ mock_create_dataset_job.assert_not_called()
+
+ # test streaming dataset_job and check flag True
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.coordinator_id = 0
+ dataset_job.scheduler_state = DatasetJobSchedulerState.PENDING
+ session.commit()
+ session.flush()
+ dataset_job_parameter = dataset_job.to_proto()
+ dataset_job_parameter.workflow_definition.MergeFrom(OtPsiDataJoinConfiger(session).get_config())
+ executor_result = executor.run_item(self._DATASET_JOB_ID)
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.RUNNABLE)
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_participant_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_participant_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ mock_create_dataset_job.assert_called_once_with(
+ dataset_job=dataset_job_parameter,
+ ticket_uuid=None,
+ dataset=dataset_pb2.Dataset(participants_info=participants_info))
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test_participant_2'].auth_status,
+ AuthStatus.PENDING.name)
+
+ mock_create_dataset_job.reset_mock()
+        # test psi dataset_job with need_create_stage
+ mock_create_data_batch_and_job_stage_as_coordinator.return_value = DatasetJobStage(
+ uuid='mock_stage',
+ project_id=self._PROJECT_ID,
+ coordinator_id=0,
+ dataset_job_id=self._DATASET_JOB_ID,
+ data_batch_id=0)
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.time_range = None
+ dataset_job.scheduler_state = DatasetJobSchedulerState.PENDING
+ dataset_job.output_dataset.dataset_type = DatasetType.PSI
+ context = dataset_job.get_context()
+ context.need_create_stage = True
+ dataset_job.set_context(context)
+ session.commit()
+ executor_result = executor.run_item(self._DATASET_JOB_ID)
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.STOPPED)
+ mock_create_dataset_job.assert_called_once()
+ mock_create_data_batch_and_job_stage_as_coordinator.assert_called_once_with(
+ dataset_job_id=self._DATASET_JOB_ID,
+ global_configs=dataset_pb2.DatasetJobGlobalConfigs(
+ global_configs={'test_participant_2': dataset_pb2.DatasetJobConfig()}),
+ )
+
+ mock_create_dataset_job.reset_mock()
+ mock_create_data_batch_and_job_stage_as_coordinator.reset_mock()
+ mock_dataset_auth_status_check_enabled.reset_mock()
+        # test participant auth status check failing
+ mock_dataset_auth_status_check_enabled.return_value = _Flag('dataset_auth_status_check_enabled', True)
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.time_range = None
+ dataset_job.scheduler_state = DatasetJobSchedulerState.PENDING
+ dataset_job.output_dataset.dataset_type = DatasetType.PSI
+ session.commit()
+ executor_result = executor.run_item(self._DATASET_JOB_ID)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ with db.session_scope() as session:
+ dataset_job: DatasetJob = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.PENDING)
+ mock_create_dataset_job.assert_called_once()
+ self.assertEqual(
+ dataset_job.output_dataset.get_participants_info().participants_map['test_participant_2'].auth_status,
+ AuthStatus.PENDING.name)
+ mock_create_data_batch_and_job_stage_as_coordinator.assert_not_called()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_long_period_scheduler.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_long_period_scheduler.py
new file mode 100644
index 000000000..d201bfb19
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_long_period_scheduler.py
@@ -0,0 +1,39 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Tuple
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.dataset.scheduler.chained_executor import run_executor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorType
+from fedlearner_webconsole.proto.composer_pb2 import DatasetSchedulerOutput, RunnerOutput
+
+
+class DatasetLongPeriodScheduler(IRunnerV2):
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ runner_output = RunnerOutput(dataset_scheduler_output=DatasetSchedulerOutput())
+
+ executor_result = run_executor(executor_type=ExecutorType.CRON_DATASET_JOB)
+ runner_output.dataset_scheduler_output.executor_outputs[ExecutorType.CRON_DATASET_JOB.value].MergeFrom(
+ executor_result)
+
+ executor_result = run_executor(executor_type=ExecutorType.UPDATE_AUTH_STATUS)
+ runner_output.dataset_scheduler_output.executor_outputs[ExecutorType.UPDATE_AUTH_STATUS.value].MergeFrom(
+ executor_result)
+
+ return RunnerStatus.DONE, runner_output
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_long_period_scheduler_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_long_period_scheduler_test.py
new file mode 100644
index 000000000..e6117e3a8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_long_period_scheduler_test.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult, ExecutorType
+from fedlearner_webconsole.dataset.scheduler.dataset_long_period_scheduler import DatasetLongPeriodScheduler
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.proto.composer_pb2 import DatasetSchedulerOutput, ExecutorResults, RunnerInput, RunnerOutput
+
+
+class DatasetLongPeriodSchedulerTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor.CronDatasetJobExecutor.get_item_ids')
+ @patch('fedlearner_webconsole.dataset.scheduler.update_auth_status_executor.UpdateAuthStatusExecutor.get_item_ids')
+ @patch('fedlearner_webconsole.dataset.scheduler.cron_dataset_job_executor.CronDatasetJobExecutor.run_item')
+ @patch('fedlearner_webconsole.dataset.scheduler.update_auth_status_executor.UpdateAuthStatusExecutor.run_item')
+ def test_run(self, mock_update_auth_status_run_item: MagicMock, mock_cron_dataset_job_run_item: MagicMock,
+ mock_update_auth_status_get_item_ids: MagicMock, mock_cron_dataset_job_get_item_ids: MagicMock):
+ mock_cron_dataset_job_get_item_ids.return_value = [1, 2, 3, 4]
+ mock_cron_dataset_job_run_item.side_effect = [
+ ExecutorResult.SUCCEEDED, ExecutorResult.SUCCEEDED, ExecutorResult.FAILED, ExecutorResult.SKIP
+ ]
+ mock_update_auth_status_get_item_ids.return_value = [1, 2]
+ mock_update_auth_status_run_item.side_effect = [ExecutorResult.FAILED, ExecutorResult.FAILED]
+ dataset_long_period_scheduler = DatasetLongPeriodScheduler()
+ status, runner_output = dataset_long_period_scheduler.run(context=RunnerContext(0, RunnerInput()))
+ self.assertEqual(status, RunnerStatus.DONE)
+ expected_runner_output = RunnerOutput(dataset_scheduler_output=DatasetSchedulerOutput(
+ executor_outputs={
+ ExecutorType.CRON_DATASET_JOB.value:
+ ExecutorResults(succeeded_item_ids=[1, 2], failed_item_ids=[3], skip_item_ids=[4]),
+ ExecutorType.UPDATE_AUTH_STATUS.value:
+ ExecutorResults(failed_item_ids=[1, 2]),
+ }))
+ self.assertEqual(runner_output, expected_runner_output)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_short_period_scheduler.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_short_period_scheduler.py
new file mode 100644
index 000000000..54724efa4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_short_period_scheduler.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Tuple
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.dataset.scheduler.chained_executor import run_executor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorType
+from fedlearner_webconsole.proto.composer_pb2 import DatasetSchedulerOutput, RunnerOutput
+
+
+class DatasetShortPeriodScheduler(IRunnerV2):
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ runner_output = RunnerOutput(dataset_scheduler_output=DatasetSchedulerOutput())
+
+ executor_result = run_executor(executor_type=ExecutorType.DATASET_JOB)
+ runner_output.dataset_scheduler_output.executor_outputs[ExecutorType.DATASET_JOB.value].MergeFrom(
+ executor_result)
+
+ executor_result = run_executor(executor_type=ExecutorType.PENDING_DATASET_JOB_STAGE)
+ runner_output.dataset_scheduler_output.executor_outputs[ExecutorType.PENDING_DATASET_JOB_STAGE.value].MergeFrom(
+ executor_result)
+
+ executor_result = run_executor(executor_type=ExecutorType.RUNNING_DATASET_JOB_STAGE)
+ runner_output.dataset_scheduler_output.executor_outputs[ExecutorType.RUNNING_DATASET_JOB_STAGE.value].MergeFrom(
+ executor_result)
+
+ return RunnerStatus.DONE, runner_output
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_short_period_scheduler_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_short_period_scheduler_test.py
new file mode 100644
index 000000000..f2d49cc4b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/dataset_short_period_scheduler_test.py
@@ -0,0 +1,67 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult, ExecutorType
+from fedlearner_webconsole.dataset.scheduler.dataset_short_period_scheduler import DatasetShortPeriodScheduler
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.proto.composer_pb2 import DatasetSchedulerOutput, ExecutorResults, RunnerInput, RunnerOutput
+
+
+class DatasetShortPeriodSchedulerTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.dataset.scheduler.dataset_job_executor.DatasetJobExecutor.get_item_ids')
+ @patch('fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor.'\
+ 'PendingDatasetJobStageExecutor.get_item_ids')
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.'\
+ 'RunningDatasetJobStageExecutor.get_item_ids')
+ @patch('fedlearner_webconsole.dataset.scheduler.dataset_job_executor.DatasetJobExecutor.run_item')
+ @patch('fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor.'\
+ 'PendingDatasetJobStageExecutor.run_item')
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.'\
+ 'RunningDatasetJobStageExecutor.run_item')
+ def test_run(self, mock_running_dataset_job_stage_run_item: MagicMock,
+ mock_pending_dataset_job_stage_run_item: MagicMock, mock_dataset_job_run_item: MagicMock,
+ mock_running_dataset_job_stage_get_item_ids: MagicMock,
+ mock_pending_dataset_job_stage_get_item_ids: MagicMock, mock_dataset_job_get_item_ids: MagicMock):
+ mock_running_dataset_job_stage_get_item_ids.return_value = [1, 2, 3, 4]
+ mock_running_dataset_job_stage_run_item.side_effect = [
+ ExecutorResult.SUCCEEDED, ExecutorResult.SUCCEEDED, ExecutorResult.FAILED, ExecutorResult.SKIP
+ ]
+ mock_pending_dataset_job_stage_get_item_ids.return_value = [1, 2]
+ mock_pending_dataset_job_stage_run_item.side_effect = [ExecutorResult.FAILED, ExecutorResult.FAILED]
+ mock_dataset_job_get_item_ids.return_value = [1, 2]
+ mock_dataset_job_run_item.side_effect = [ExecutorResult.SUCCEEDED, ExecutorResult.SKIP]
+ dataset_short_period_scheduler = DatasetShortPeriodScheduler()
+ status, runner_output = dataset_short_period_scheduler.run(context=RunnerContext(0, RunnerInput()))
+ self.assertEqual(status, RunnerStatus.DONE)
+ expected_runner_output = RunnerOutput(dataset_scheduler_output=DatasetSchedulerOutput(
+ executor_outputs={
+ ExecutorType.RUNNING_DATASET_JOB_STAGE.value:
+ ExecutorResults(succeeded_item_ids=[1, 2], failed_item_ids=[3], skip_item_ids=[4]),
+ ExecutorType.PENDING_DATASET_JOB_STAGE.value:
+ ExecutorResults(failed_item_ids=[1, 2]),
+ ExecutorType.DATASET_JOB.value:
+ ExecutorResults(succeeded_item_ids=[1], skip_item_ids=[2]),
+ }))
+ self.assertEqual(runner_output, expected_runner_output)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/pending_dataset_job_stage_executor.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/pending_dataset_job_stage_executor.py
new file mode 100644
index 000000000..d5cf91922
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/pending_dataset_job_stage_executor.py
@@ -0,0 +1,118 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import List
+import grpc
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.controllers import DatasetJobStageController
+from fedlearner_webconsole.dataset.services import DatasetJobService
+from fedlearner_webconsole.dataset.scheduler.base_executor import BaseExecutor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobSchedulerState, DatasetJobStage, DatasetJobState
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.rpc.v2.job_service_client import JobServiceClient
+
+
+class PendingDatasetJobStageExecutor(BaseExecutor):
+
+ def _process_pending_dataset_job_stage(self, session: Session,
+ dataset_job_stage: DatasetJobStage) -> ExecutorResult:
+ """Schedules pending dataset job stage, same logic as _process_pending_dataset_job.
+
+ 1. If is not coordinator, return
+ 2. check whether participant is ready, if not, try to create it and return
+ 3. try to start the dataset_job_stage
+ """
+ if not dataset_job_stage.is_coordinator():
+ return ExecutorResult.SUCCEEDED
+ # create participant dataset_job_stage
+ dataset_job: DatasetJob = dataset_job_stage.dataset_job
+ participants = DatasetJobService(session).get_participants_need_distribute(dataset_job)
+        # is_peer_ready is True only if every participant that needs distribution has already
+        # created its dataset_job_stage and the related workflow
+ is_peer_ready = True
+ for participant in participants:
+ client = JobServiceClient.from_project_and_participant(domain_name=participant.domain_name,
+ project_name=dataset_job_stage.project.name)
+ try:
+ response = client.get_dataset_job_stage(dataset_job_stage_uuid=dataset_job_stage.uuid)
+ if response.dataset_job_stage.is_ready:
+ logging.info(
+ '[pending dataset_job_stage executor]: participant dataset_job_stage is ready, ' \
+ f'participant name: {participant.name}'
+ )
+ else:
+ is_peer_ready = False
+ except grpc.RpcError as err:
+ if err.code() != grpc.StatusCode.NOT_FOUND:
+ raise InternalException(
+ details=f'failed to call GetDatasetJobStage with status code {err.code()} ' \
+ f'and details {err.details()}'
+ ) from err
+ # participant has no dataset_job_stage
+ logging.info(
+ f'[pending dataset_job_stage executor]: dataset_job_stage in participant {participant.name} ' \
+ 'not found, start to create')
+ is_peer_ready = False
+ client.create_dataset_job_stage(dataset_job_uuid=dataset_job.uuid,
+ dataset_job_stage_uuid=dataset_job_stage.uuid,
+ name=dataset_job_stage.name,
+ event_time=dataset_job_stage.event_time)
+ if not is_peer_ready:
+ return ExecutorResult.SKIP
+        # start the dataset_job_stage
+ try:
+ DatasetJobStageController(session=session).start(uuid=dataset_job_stage.uuid)
+ logging.info(
+ '[pending dataset_job_stage executor]: start dataset_job_stage successfully, ' \
+ f'dataset_job_stage_id: {dataset_job_stage.id}'
+ )
+ except InternalException as e:
+ logging.error(
+ f'[pending dataset_job_stage executor]: start dataset_job_stage {dataset_job_stage.id} failed, ' \
+ f'exception: {e}'
+ )
+ # reset dataset_job_stage state to PENDING,
+ # in order to make sure it will be scheduled to start again next time
+ dataset_job_stage.state = DatasetJobState.PENDING
+ return ExecutorResult.FAILED
+ return ExecutorResult.SUCCEEDED
+
+ def get_item_ids(self) -> List[int]:
+ with db.session_scope() as session:
+ pending_dataset_job_stage_ids = session.query(DatasetJobStage.id).outerjoin(
+ DatasetJob, DatasetJob.id == DatasetJobStage.dataset_job_id).filter(
+ DatasetJobStage.state == DatasetJobState.PENDING).filter(
+ DatasetJob.scheduler_state != DatasetJobSchedulerState.PENDING).all()
+ return [pending_dataset_job_stage_id for pending_dataset_job_stage_id, *_ in pending_dataset_job_stage_ids]
+
+ def run_item(self, item_id: int) -> ExecutorResult:
+ with db.session_scope() as session:
+            # use SERIALIZABLE isolation so the state cannot be changed concurrently within this session
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ dataset_job_stage: DatasetJobStage = session.query(DatasetJobStage).get(item_id)
+ if dataset_job_stage.state != DatasetJobState.PENDING:
+ return ExecutorResult.SKIP
+ if not dataset_job_stage.workflow:
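+                # the stage has no workflow yet; create a ready workflow for it before processing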
+ DatasetJobStageController(session=session).create_ready_workflow(dataset_job_stage)
+ session.commit()
+ with db.session_scope() as session:
+ dataset_job_stage: DatasetJobStage = session.query(DatasetJobStage).get(item_id)
+ executor_result = self._process_pending_dataset_job_stage(session, dataset_job_stage)
+ session.commit()
+ return executor_result
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/pending_dataset_job_stage_executor_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/pending_dataset_job_stage_executor_test.py
new file mode 100644
index 000000000..a0f4b2125
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/pending_dataset_job_stage_executor_test.py
@@ -0,0 +1,269 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+from datetime import datetime, timedelta
+import grpc
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor import PendingDatasetJobStageExecutor
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, DatasetJob, DatasetJobKind, \
+ DatasetJobSchedulerState, DatasetJobStage, DatasetJobState, DatasetKindV2, DatasetType
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.participant.models import ProjectParticipant, Participant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import GetDatasetJobStageResponse
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ProcessDatasetJobStagesTest(NoWebServerTestCase):
+
+ _PROJECT_ID = 1
+ _JOB_ID = 1
+ _WORKFLOW_ID = 1
+ _PARTICIPANT_ID = 1
+ _INPUT_DATASET_ID = 1
+ _STREAMING_OUTPUT_DATASET_ID = 2
+ _PSI_OUTPUT_DATASET_ID = 3
+ _DATA_BATCH_NO_EVENT_TIME_ID = 1
+ _DATA_BATCH_WITH_EVENT_TIME_ID = 2
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ project = Project(id=self._PROJECT_ID, name='test-project')
+ workflow = Workflow(id=self._WORKFLOW_ID,
+ project_id=self._PROJECT_ID,
+ state=WorkflowState.READY,
+ uuid='workflow_uuid')
+ job = Job(id=self._JOB_ID,
+ state=JobState.NEW,
+ job_type=JobType.PSI_DATA_JOIN,
+ workflow_id=self._WORKFLOW_ID,
+ project_id=1)
+ session.add_all([project, workflow, job])
+ input_dataset = Dataset(id=self._INPUT_DATASET_ID,
+ name='input dataset',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ created_at=datetime(2012, 1, 14, 12, 0, 6),
+ dataset_kind=DatasetKindV2.RAW)
+ streaming_output_dataset = Dataset(id=self._STREAMING_OUTPUT_DATASET_ID,
+ name='streaming output dataset',
+ uuid='streaming output_dataset uuid',
+ path='/data/dataset/321',
+ project_id=self._PROJECT_ID,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_type=DatasetType.STREAMING)
+ batch_with_event_time = DataBatch(id=self._DATA_BATCH_WITH_EVENT_TIME_ID,
+ path='/data/dataset/321/batch/20220101',
+ dataset_id=self._STREAMING_OUTPUT_DATASET_ID,
+ event_time=datetime(2022, 1, 1))
+ psi_output_dataset = Dataset(id=self._PSI_OUTPUT_DATASET_ID,
+ name='psi output dataset',
+ uuid='psi output_dataset uuid',
+ path='/data/dataset/321',
+ project_id=self._PROJECT_ID,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_type=DatasetType.PSI)
+ batch_no_event_time = DataBatch(id=self._DATA_BATCH_NO_EVENT_TIME_ID,
+ path='/data/dataset/321/batch/0',
+ dataset_id=self._PSI_OUTPUT_DATASET_ID)
+ session.add_all([
+ input_dataset, streaming_output_dataset, batch_with_event_time, psi_output_dataset, batch_no_event_time
+ ])
+ participant = Participant(id=self._PARTICIPANT_ID, name='participant_1', domain_name='fake_domain_name_1')
+ project_participant = ProjectParticipant(project_id=self._PROJECT_ID, participant_id=self._PARTICIPANT_ID)
+ session.add_all([participant, project_participant])
+ session.commit()
+
+ def _insert_psi_dataset_job_and_stage(self, state: DatasetJobState, job_id: int):
+ with db.session_scope() as session:
+ psi_dataset_job = DatasetJob(id=job_id,
+ uuid=f'psi dataset_job uuid {state.name}',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._PSI_OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=state,
+ scheduler_state=DatasetJobSchedulerState.STOPPED,
+ coordinator_id=0,
+ workflow_id=0)
+ psi_dataset_job_stage = DatasetJobStage(id=job_id,
+ uuid=f'psi dataset_job_stage uuid {state.name}',
+ name='psi dataset job stage',
+ project_id=self._PROJECT_ID,
+ workflow_id=self._WORKFLOW_ID,
+ dataset_job_id=job_id,
+ data_batch_id=self._DATA_BATCH_NO_EVENT_TIME_ID,
+ state=state)
+ session.add_all([psi_dataset_job, psi_dataset_job_stage])
+ session.commit()
+
+ def _insert_streaming_dataset_job_and_stage(self, state: DatasetJobState, job_id: int):
+ with db.session_scope() as session:
+ streaming_dataset_job = DatasetJob(id=job_id,
+ uuid=f'streaming dataset_job uuid {state.name}',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._STREAMING_OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=state,
+ scheduler_state=DatasetJobSchedulerState.STOPPED,
+ coordinator_id=0,
+ workflow_id=0,
+ time_range=timedelta(days=1))
+ streaming_dataset_job_stage = DatasetJobStage(id=job_id,
+ uuid=f'streaming dataset_job_stage uuid {state.name}',
+ name='streaming dataset job stage',
+ project_id=self._PROJECT_ID,
+ workflow_id=self._WORKFLOW_ID,
+ dataset_job_id=job_id,
+ data_batch_id=self._DATA_BATCH_WITH_EVENT_TIME_ID,
+ state=state,
+ event_time=datetime(2022, 1, 1))
+ session.add_all([streaming_dataset_job, streaming_dataset_job_stage])
+ session.commit()
+
+ def test_get_item_ids(self):
+ dataset_job_stage_pending_id = 1
+ dataset_job_stage_running_id = 2
+ dataset_job_stage_succeeded_id = 3
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.PENDING, dataset_job_stage_pending_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.RUNNING, dataset_job_stage_running_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.SUCCEEDED, dataset_job_stage_succeeded_id)
+ executor = PendingDatasetJobStageExecutor()
+ processed_dataset_job_stage_ids = executor.get_item_ids()
+ self.assertEqual(processed_dataset_job_stage_ids, [dataset_job_stage_pending_id])
+
+ @patch(
+ 'fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor.PendingDatasetJobStageExecutor.' \
+ '_process_pending_dataset_job_stage'
+ )
+ @patch('fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor.DatasetJobStageController.'\
+ 'create_ready_workflow')
+ def test_run_item(self, mock_create_ready_workflow: MagicMock, mock_process_pending_dataset_job_stage: MagicMock):
+ dataset_job_stage_pending_id = 1
+ dataset_job_stage_running_id = 2
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.PENDING, dataset_job_stage_pending_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.RUNNING, dataset_job_stage_running_id)
+ executor = PendingDatasetJobStageExecutor()
+
+ # test not pending
+        executor_result = executor.run_item(dataset_job_stage_running_id)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ mock_create_ready_workflow.assert_not_called()
+ mock_process_pending_dataset_job_stage.assert_not_called()
+
+ # test succeeded
+ mock_process_pending_dataset_job_stage.return_value = ExecutorResult.SUCCEEDED
+        executor_result = executor.run_item(dataset_job_stage_pending_id)
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ mock_create_ready_workflow.assert_not_called()
+ mock_process_pending_dataset_job_stage.assert_called_once()
+ self.assertEqual(mock_process_pending_dataset_job_stage.call_args[0][1].id, dataset_job_stage_pending_id)
+
+ # test no workflow and process failed
+ with db.session_scope() as session:
+ dataset_job_stage_pending = session.query(DatasetJobStage).get(dataset_job_stage_pending_id)
+ dataset_job_stage_pending.workflow_id = 0
+ session.commit()
+ mock_process_pending_dataset_job_stage.reset_mock()
+ mock_process_pending_dataset_job_stage.return_value = ExecutorResult.FAILED
+        executor_result = executor.run_item(dataset_job_stage_pending_id)
+ self.assertEqual(executor_result, ExecutorResult.FAILED)
+ mock_create_ready_workflow.assert_called_once()
+ mock_process_pending_dataset_job_stage.assert_called_once()
+ self.assertEqual(mock_process_pending_dataset_job_stage.call_args[0][1].id, dataset_job_stage_pending_id)
+
+ @patch('fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor.JobServiceClient.'\
+ 'get_dataset_job_stage')
+ @patch('fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor.JobServiceClient.'\
+ 'create_dataset_job_stage')
+ @patch('fedlearner_webconsole.dataset.scheduler.pending_dataset_job_stage_executor.DatasetJobStageController.'\
+ 'start')
+ def test_process_pending_dataset_job_stage(self, mock_start: MagicMock, mock_create_dataset_job_stage: MagicMock,
+ mock_get_dataset_job_stage: MagicMock):
+ dataset_job_stage_pending_id = 1
+ self._insert_streaming_dataset_job_and_stage(DatasetJobState.PENDING, dataset_job_stage_pending_id)
+ executor = PendingDatasetJobStageExecutor()
+
+ # test not coordinator
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_pending_id)
+ dataset_job_stage.coordinator_id = 1
+ # pylint: disable=protected-access
+ executor_result = executor._process_pending_dataset_job_stage(session=session,
+ dataset_job_stage=dataset_job_stage)
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ mock_get_dataset_job_stage.assert_not_called()
+
+ # test not_ready
+ mock_get_dataset_job_stage.return_value = GetDatasetJobStageResponse(
+ dataset_job_stage=dataset_pb2.DatasetJobStage(is_ready=False))
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_pending_id)
+ # pylint: disable=protected-access
+ executor_result = executor._process_pending_dataset_job_stage(session=session,
+ dataset_job_stage=dataset_job_stage)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ mock_create_dataset_job_stage.assert_not_called()
+ mock_start.assert_not_called()
+
+ mock_get_dataset_job_stage.reset_mock()
+
+ # test ready and start
+ mock_get_dataset_job_stage.return_value = GetDatasetJobStageResponse(
+ dataset_job_stage=dataset_pb2.DatasetJobStage(is_ready=True))
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_pending_id)
+ # pylint: disable=protected-access
+ executor_result = executor._process_pending_dataset_job_stage(session=session,
+ dataset_job_stage=dataset_job_stage)
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ mock_create_dataset_job_stage.assert_not_called()
+ mock_start.assert_called_once_with(uuid='streaming dataset_job_stage uuid PENDING')
+
+ mock_get_dataset_job_stage.reset_mock()
+ mock_start.reset_mock()
+
+        # test get_dataset_job_stage raising NOT_FOUND
+ e = grpc.RpcError()
+ e.code = lambda: grpc.StatusCode.NOT_FOUND
+ mock_get_dataset_job_stage.side_effect = e
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_pending_id)
+ # pylint: disable=protected-access
+ executor_result = executor._process_pending_dataset_job_stage(session=session,
+ dataset_job_stage=dataset_job_stage)
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ mock_create_dataset_job_stage.assert_called_once_with(
+ dataset_job_uuid='streaming dataset_job uuid PENDING',
+ dataset_job_stage_uuid='streaming dataset_job_stage uuid PENDING',
+ name='streaming dataset job stage',
+ event_time=datetime(2022, 1, 1))
+ mock_start.assert_not_called()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/running_dataset_job_stage_executor.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/running_dataset_job_stage_executor.py
new file mode 100644
index 000000000..be894d5c1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/running_dataset_job_stage_executor.py
@@ -0,0 +1,203 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import timedelta
+import logging
+import os
+from typing import List
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.cleanup.models import ResourceType
+from fedlearner_webconsole.cleanup.services import CleanupService
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetJobStage, \
+ DatasetJobSchedulerState, DatasetJobState, DataBatch, DatasetKindV2, DatasetType, ResourceState
+from fedlearner_webconsole.dataset.scheduler.base_executor import BaseExecutor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.services import DatasetJobService, DatasetJobStageService, DatasetService
+from fedlearner_webconsole.dataset.consts import ERROR_BATCH_SIZE
+from fedlearner_webconsole.rpc.v2.resource_service_client import ResourceServiceClient
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp, now
+from fedlearner_webconsole.utils.workflow import build_job_name
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowExternalState
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupParameter, CleanupPayload
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, BatchStatsInput
+
+SIDE_OUTPUT_CLEANUP_DEFAULT_DELAY = timedelta(days=1)
+
+
+class RunningDatasetJobStageExecutor(BaseExecutor):
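+    """Drives RUNNING dataset job stages to a terminal state based on workflow and batch stats results."""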
+
+ def _process_running_dataset_job_stage(self, session: Session,
+ dataset_job_stage: DatasetJobStage) -> ExecutorResult:
+ """Schedules running dataset job stage, same logic as _process_running_dataset_job.
+
+ 1. If the related workflow is completed
+ - 1.1 Checks the batch stats item if it has been triggered.
+ - 1.1.1 If the runner fails, then mark the job_stage as failed.
+ - 1.1.2 If the runner completes, then mark the job_stage as succeeded.
+ - 1.2 Triggers the batch stats item if it does not exist.
+ 2. If the related workflow is failed, then mark the job_stage as failed.
+ """
+ dataset_job_stage_service = DatasetJobStageService(session=session)
+ workflow_state = dataset_job_stage.workflow.get_state_for_frontend()
+ if workflow_state == WorkflowExternalState.COMPLETED:
+ if not self._need_batch_stats(dataset_job_stage.dataset_job.kind):
+ dataset_job_stage_service.finish_dataset_job_stage(dataset_job_stage=dataset_job_stage,
+ finish_state=DatasetJobState.SUCCEEDED)
+ return ExecutorResult.SUCCEEDED
+ item_name = dataset_job_stage.get_context().batch_stats_item_name
+ executor_result = ExecutorResult.SKIP
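+            # default to SKIP until the batch stats runner reaches a terminal status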
+ if item_name:
+ runners = ComposerService(session).get_recent_runners(item_name, count=1)
+ if len(runners) > 0:
+ if runners[0].status == RunnerStatus.DONE.value:
+ self._set_data_batch_num_example(session, dataset_job_stage)
+ dataset_job_stage_service.finish_dataset_job_stage(dataset_job_stage=dataset_job_stage,
+ finish_state=DatasetJobState.SUCCEEDED)
+ executor_result = ExecutorResult.SUCCEEDED
+ elif runners[0].status == RunnerStatus.FAILED.value:
+ batch = session.query(DataBatch).get(dataset_job_stage.data_batch_id)
+                        # set file_size to an illegal value so the frontend knows batch stats failed
+ batch.file_size = ERROR_BATCH_SIZE
+ session.flush()
+ self._set_data_batch_num_example(session, dataset_job_stage)
+ dataset_job_stage_service.finish_dataset_job_stage(dataset_job_stage=dataset_job_stage,
+ finish_state=DatasetJobState.SUCCEEDED)
+ executor_result = ExecutorResult.SUCCEEDED
+ else:
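+                # no batch stats runner yet: register one with the composer and record its name in the stage context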
+ item_name = f'batch_stats_{dataset_job_stage.data_batch.id}_{dataset_job_stage.id}'
+ runner_input = RunnerInput(batch_stats_input=BatchStatsInput(batch_id=dataset_job_stage.data_batch.id))
+ ComposerService(session).collect_v2(name=item_name, items=[(ItemType.BATCH_STATS, runner_input)])
+ context = dataset_job_stage.get_context()
+ context.batch_stats_item_name = item_name
+ dataset_job_stage.set_context(context)
+ return executor_result
+ if workflow_state in (WorkflowExternalState.FAILED, WorkflowExternalState.STOPPED,
+ WorkflowExternalState.INVALID):
+ dataset_job_stage_service.finish_dataset_job_stage(dataset_job_stage=dataset_job_stage,
+ finish_state=DatasetJobState.FAILED)
+ return ExecutorResult.SUCCEEDED
+ return ExecutorResult.SKIP
+
+ def _process_succeeded_dataset_job_stage(self, session: Session, dataset_job_stage: DatasetJobStage):
+ """Schedules when running dataset job stage succeeded, same logic as _process_succeeded_dataset_job.
+
+ 1. publish output_dataset if needed
+ 2. create transaction for participants
+ 3. delete side_output data
+ """
+ output_dataset: Dataset = dataset_job_stage.dataset_job.output_dataset
+ meta_info = output_dataset.get_meta_info()
+ if meta_info.need_publish:
+ DatasetService(session).publish_dataset(dataset_id=output_dataset.id, value=meta_info.value)
+ logging.info(f'[dataset_job_scheduler] auto publish dataset {output_dataset.id}')
+ # set need_publish to false after publish
+ meta_info.need_publish = False
+ output_dataset.set_meta_info(meta_info)
+ self._delete_side_output(session=session, dataset_job_stage=dataset_job_stage)
+
+ def _process_failed_dataset_job_stage(self, session: Session, dataset_job_stage: DatasetJobStage):
+ """Schedules when running dataset job stage failed.
+
+ 1. delete side_output data
+ """
+ self._delete_side_output(session=session, dataset_job_stage=dataset_job_stage)
+
+    def _need_batch_stats(self, dataset_job_kind: DatasetJobKind) -> bool:
+        # batch sample info is now generated by the analyzer spark task, so data stats need to run again after it
+ return dataset_job_kind in [
+ DatasetJobKind.RSA_PSI_DATA_JOIN, DatasetJobKind.LIGHT_CLIENT_RSA_PSI_DATA_JOIN,
+ DatasetJobKind.OT_PSI_DATA_JOIN, DatasetJobKind.LIGHT_CLIENT_OT_PSI_DATA_JOIN,
+ DatasetJobKind.HASH_DATA_JOIN, DatasetJobKind.DATA_JOIN, DatasetJobKind.DATA_ALIGNMENT,
+ DatasetJobKind.IMPORT_SOURCE, DatasetJobKind.ANALYZER
+ ]
+
+ def _get_single_batch_num_example(self, dataset: Dataset) -> int:
+ try:
+ return dataset.get_single_batch().num_example
+ except TypeError as e:
+ logging.info(f'single data_batch not found, err: {e}')
+ return 0
+
+ def _set_data_batch_num_example(self, session: Session, dataset_job_stage: DatasetJobStage):
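+        # snapshot the input and output batch example counts into the stage context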
+ input_dataset: Dataset = dataset_job_stage.dataset_job.input_dataset
+ if input_dataset.dataset_type == DatasetType.PSI:
+ input_data_batch_num_example = self._get_single_batch_num_example(input_dataset)
+ else:
+            # TODO(liuhehan): also filter the input data_batch by time_range
+ input_data_batch = session.query(DataBatch).filter(DataBatch.dataset_id == input_dataset.id).filter(
+ DataBatch.event_time == dataset_job_stage.event_time).first()
+ input_data_batch_num_example = input_data_batch.num_example if input_data_batch else 0
+ output_data_batch_num_example = dataset_job_stage.data_batch.num_example if dataset_job_stage.data_batch else 0
+ context = dataset_job_stage.get_context()
+ context.input_data_batch_num_example = input_data_batch_num_example
+ context.output_data_batch_num_example = output_data_batch_num_example
+ dataset_job_stage.set_context(context)
+
+ def _delete_side_output(self, session: Session, dataset_job_stage: DatasetJobStage):
+ output_dataset: Dataset = dataset_job_stage.dataset_job.output_dataset
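+        # only raw and processed datasets have side output to clean up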
+ if output_dataset.dataset_kind not in [DatasetKindV2.RAW, DatasetKindV2.PROCESSED]:
+ return
+ batch_name = dataset_job_stage.data_batch.batch_name
+ paths = [DatasetDirectory(output_dataset.path).side_output_path(batch_name)]
+        # hack to also collect rsa_psi side outputs:
+        # raw_data_path: side output of the raw_data_job
+        # psi_data_join_path: side output of the psi_data_join_job; only the psi_output folder is deleted,
+        # as the data_block folder is still used by model training
+ if dataset_job_stage.dataset_job.kind == DatasetJobKind.RSA_PSI_DATA_JOIN:
+ workflow: Workflow = dataset_job_stage.workflow
+ raw_data_folder = build_job_name(workflow.uuid, 'raw-data-job')
+ raw_data_path = os.path.join(workflow.project.get_storage_root_path(None), 'raw_data', raw_data_folder)
+ psi_data_join_folder_name = build_job_name(workflow.uuid, 'psi-data-join-job')
+ psi_data_join_path = os.path.join(workflow.project.get_storage_root_path(None), 'data_source',
+ psi_data_join_folder_name, 'psi_output')
+ paths.extend([raw_data_path, psi_data_join_path])
+ target_start_at = to_timestamp(now() + SIDE_OUTPUT_CLEANUP_DEFAULT_DELAY)
+ cleanup_param = CleanupParameter(resource_id=dataset_job_stage.id,
+ resource_type=ResourceType.DATASET_JOB_STAGE.name,
+ payload=CleanupPayload(paths=paths),
+ target_start_at=target_start_at)
+ CleanupService(session).create_cleanup(cleanup_parmeter=cleanup_param)
+
+ def get_item_ids(self) -> List[int]:
+ with db.session_scope() as session:
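+            # only pick stages whose parent dataset_job has already left the PENDING scheduler state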
+ running_dataset_job_stage_ids = session.query(DatasetJobStage.id).outerjoin(
+ DatasetJob, DatasetJob.id == DatasetJobStage.dataset_job_id).filter(
+ DatasetJobStage.state == DatasetJobState.RUNNING).filter(
+ DatasetJob.scheduler_state != DatasetJobSchedulerState.PENDING).all()
+ return [running_dataset_job_stage_id for running_dataset_job_stage_id, *_ in running_dataset_job_stage_ids]
+
+ def run_item(self, item_id: int) -> ExecutorResult:
+ with db.session_scope() as session:
+            # use SERIALIZABLE isolation so the state cannot be changed by concurrent transactions during this session
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ dataset_job_stage: DatasetJobStage = session.query(DatasetJobStage).get(item_id)
+ if dataset_job_stage.state != DatasetJobState.RUNNING:
+ return ExecutorResult.SKIP
+ executor_result = self._process_running_dataset_job_stage(session, dataset_job_stage)
+ if dataset_job_stage.state == DatasetJobState.SUCCEEDED:
+ self._process_succeeded_dataset_job_stage(session, dataset_job_stage)
+ elif dataset_job_stage.state == DatasetJobState.FAILED:
+ self._process_failed_dataset_job_stage(session, dataset_job_stage)
+ session.commit()
+ return executor_result
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/running_dataset_job_stage_executor_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/running_dataset_job_stage_executor_test.py
new file mode 100644
index 000000000..81b8c7547
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/running_dataset_job_stage_executor_test.py
@@ -0,0 +1,420 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+from datetime import datetime, timedelta, timezone
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor import RunningDatasetJobStageExecutor
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, DatasetFormat, DatasetJob, DatasetJobKind, \
+ DatasetJobSchedulerState, DatasetJobStage, DatasetJobState, DatasetKindV2, DatasetType, ResourceState
+from fedlearner_webconsole.dataset.consts import ERROR_BATCH_SIZE
+from fedlearner_webconsole.flag.models import _Flag
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import RunnerStatus, SchedulerRunner
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowExternalState, WorkflowState
+from fedlearner_webconsole.participant.models import ProjectParticipant, Participant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.proto import dataset_pb2, service_pb2
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupParameter, CleanupPayload
+from fedlearner_webconsole.proto.composer_pb2 import BatchStatsInput, RunnerInput
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+def fake_process_running_dataset_job_stage(self, session: Session,
+ dataset_job_stage: DatasetJobStage) -> ExecutorResult:
+ if dataset_job_stage.id == 1:
+ return ExecutorResult.SKIP
+ if dataset_job_stage.id == 2:
+ dataset_job_stage.state = DatasetJobState.SUCCEEDED
+ return ExecutorResult.SUCCEEDED
+ dataset_job_stage.state = DatasetJobState.FAILED
+ return ExecutorResult.SUCCEEDED
+
+
+class ProcessDatasetJobStagesTest(NoWebServerTestCase):
+
+ _PROJECT_ID = 1
+ _JOB_ID = 1
+ _WORKFLOW_ID = 1
+ _PARTICIPANT_ID = 1
+ _INPUT_DATASET_ID = 1
+ _STREAMING_OUTPUT_DATASET_ID = 2
+ _PSI_OUTPUT_DATASET_ID = 3
+ _DATA_BATCH_NO_EVENT_TIME_ID = 1
+ _DATA_BATCH_WITH_EVENT_TIME_ID = 2
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ project = Project(id=self._PROJECT_ID, name='test-project')
+ workflow = Workflow(id=self._WORKFLOW_ID,
+ project_id=self._PROJECT_ID,
+ state=WorkflowState.READY,
+ uuid='workflow_uuid')
+ job = Job(id=self._JOB_ID,
+ state=JobState.NEW,
+ job_type=JobType.PSI_DATA_JOIN,
+ workflow_id=self._WORKFLOW_ID,
+ project_id=1)
+ session.add_all([project, workflow, job])
+ input_dataset = Dataset(id=self._INPUT_DATASET_ID,
+ name='input dataset',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ created_at=datetime(2012, 1, 14, 12, 0, 6),
+ dataset_kind=DatasetKindV2.RAW)
+ streaming_output_dataset = Dataset(id=self._STREAMING_OUTPUT_DATASET_ID,
+ name='streaming output dataset',
+ uuid='streaming output_dataset uuid',
+ path='/data/dataset/321',
+ project_id=self._PROJECT_ID,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_type=DatasetType.STREAMING)
+ batch_with_event_time = DataBatch(id=self._DATA_BATCH_WITH_EVENT_TIME_ID,
+ path='/data/dataset/321/batch/20220101',
+ dataset_id=self._STREAMING_OUTPUT_DATASET_ID,
+ event_time=datetime(2022, 1, 1))
+ psi_output_dataset = Dataset(id=self._PSI_OUTPUT_DATASET_ID,
+ name='psi output dataset',
+ uuid='psi output_dataset uuid',
+ path='/data/dataset/321',
+ project_id=self._PROJECT_ID,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ dataset_kind=DatasetKindV2.PROCESSED,
+ dataset_type=DatasetType.PSI)
+ batch_no_event_time = DataBatch(id=self._DATA_BATCH_NO_EVENT_TIME_ID,
+ path='/data/dataset/321/batch/0',
+ dataset_id=self._PSI_OUTPUT_DATASET_ID)
+ session.add_all([
+ input_dataset, streaming_output_dataset, batch_with_event_time, psi_output_dataset, batch_no_event_time
+ ])
+ participant = Participant(id=self._PARTICIPANT_ID,
+ name='participant_1',
+ domain_name='fl-fake_domain_name_1.com')
+ project_participant = ProjectParticipant(project_id=self._PROJECT_ID, participant_id=self._PARTICIPANT_ID)
+ session.add_all([participant, project_participant])
+ session.commit()
+
+ def _insert_psi_dataset_job_and_stage(self, state: DatasetJobState, job_id: int):
+ with db.session_scope() as session:
+ psi_dataset_job = DatasetJob(id=job_id,
+ uuid=f'psi dataset_job uuid {job_id}',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._PSI_OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=state,
+ scheduler_state=DatasetJobSchedulerState.STOPPED,
+ coordinator_id=0,
+ workflow_id=0)
+ psi_dataset_job_stage = DatasetJobStage(id=job_id,
+ uuid=f'psi dataset_job_stage uuid {job_id}',
+ name='psi dataset job stage',
+ project_id=self._PROJECT_ID,
+ workflow_id=self._WORKFLOW_ID,
+ dataset_job_id=job_id,
+ data_batch_id=self._DATA_BATCH_NO_EVENT_TIME_ID,
+ state=state)
+ session.add_all([psi_dataset_job, psi_dataset_job_stage])
+ session.commit()
+
+ def _insert_streaming_dataset_job_and_stage(self, state: DatasetJobState, job_id: int):
+ with db.session_scope() as session:
+ streaming_dataset_job = DatasetJob(id=job_id,
+ uuid=f'streaming dataset_job uuid {job_id}',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._STREAMING_OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=state,
+ scheduler_state=DatasetJobSchedulerState.STOPPED,
+ coordinator_id=0,
+ workflow_id=0,
+ time_range=timedelta(days=1))
+ streaming_dataset_job_stage = DatasetJobStage(id=job_id,
+ uuid=f'streaming dataset_job_stage uuid {job_id}',
+ name='streaming dataset job stage',
+ project_id=self._PROJECT_ID,
+ workflow_id=self._WORKFLOW_ID,
+ dataset_job_id=job_id,
+ data_batch_id=self._DATA_BATCH_WITH_EVENT_TIME_ID,
+ state=state,
+ event_time=datetime(2022, 1, 1))
+ session.add_all([streaming_dataset_job, streaming_dataset_job_stage])
+ session.commit()
+
+ def test_get_item_ids(self):
+ dataset_job_stage_pending_id = 1
+ dataset_job_stage_running_id = 2
+ dataset_job_stage_succeeded_id = 3
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.PENDING, dataset_job_stage_pending_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.RUNNING, dataset_job_stage_running_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.SUCCEEDED, dataset_job_stage_succeeded_id)
+ executor = RunningDatasetJobStageExecutor()
+ processed_dataset_job_stage_ids = executor.get_item_ids()
+ self.assertEqual(processed_dataset_job_stage_ids, [dataset_job_stage_running_id])
+
+ @patch(
+ 'fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.RunningDatasetJobStageExecutor' \
+        '._process_running_dataset_job_stage', fake_process_running_dataset_job_stage
+ )
+ @patch(
+ 'fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.RunningDatasetJobStageExecutor' \
+ '._process_succeeded_dataset_job_stage'
+ )
+ @patch(
+ 'fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.RunningDatasetJobStageExecutor' \
+ '._process_failed_dataset_job_stage'
+ )
+    def test_run_item(self, mock_process_failed_dataset_job_stage: MagicMock,
+                      mock_process_succeeded_dataset_job_stage: MagicMock):
+ dataset_job_stage_running_1_id = 1
+ dataset_job_stage_running_2_id = 2
+ dataset_job_stage_running_3_id = 3
+ dataset_job_stage_pending_id = 4
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.PENDING, dataset_job_stage_pending_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.RUNNING, dataset_job_stage_running_1_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.RUNNING, dataset_job_stage_running_2_id)
+ self._insert_psi_dataset_job_and_stage(DatasetJobState.RUNNING, dataset_job_stage_running_3_id)
+ executor = RunningDatasetJobStageExecutor()
+
+ # test not running
+ executor_result = executor.run_item([dataset_job_stage_pending_id])
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ mock_process_succeeded_dataset_job_stage.assert_not_called()
+ mock_process_failed_dataset_job_stage.assert_not_called()
+
+ # test skip
+ executor_result = executor.run_item([dataset_job_stage_running_1_id])
+ self.assertEqual(executor_result, ExecutorResult.SKIP)
+ mock_process_succeeded_dataset_job_stage.assert_not_called()
+ mock_process_failed_dataset_job_stage.assert_not_called()
+
+ # test succeeded
+ executor_result = executor.run_item([dataset_job_stage_running_2_id])
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ mock_process_succeeded_dataset_job_stage.assert_called_once()
+ mock_process_failed_dataset_job_stage.assert_not_called()
+ mock_process_succeeded_dataset_job_stage.reset_mock()
+
+ # test failed
+ executor_result = executor.run_item([dataset_job_stage_running_3_id])
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ mock_process_succeeded_dataset_job_stage.assert_not_called()
+ mock_process_failed_dataset_job_stage.assert_called_once()
+
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.ComposerService.'\
+ 'collect_v2')
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.'\
+ 'RunningDatasetJobStageExecutor._need_batch_stats')
+ @patch('fedlearner_webconsole.workflow.models.Workflow.get_state_for_frontend')
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.ComposerService.'\
+ 'get_recent_runners')
+ def test_process_running_dataset_job_stage(self, mock_get_recent_runners: MagicMock,
+ mock_get_state_for_frontend: MagicMock, mock_need_batch_stats: MagicMock,
+ mock_collect_v2: MagicMock):
+ dataset_job_stage_running_id = 1
+ self._insert_streaming_dataset_job_and_stage(DatasetJobState.RUNNING, dataset_job_stage_running_id)
+ executor = RunningDatasetJobStageExecutor()
+
+ # test workflow failed
+ mock_get_state_for_frontend.return_value = WorkflowExternalState.FAILED
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_running_id)
+ # pylint: disable=protected-access
+ executor._process_running_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.FAILED)
+
+ mock_get_recent_runners.reset_mock()
+ mock_need_batch_stats.reset_mock()
+ mock_get_state_for_frontend.reset_mock()
+
+ # test no need batch stats
+ mock_need_batch_stats.return_value = False
+ mock_get_state_for_frontend.return_value = WorkflowExternalState.COMPLETED
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_running_id)
+ # pylint: disable=protected-access
+ executor._process_running_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.SUCCEEDED)
+
+ mock_get_recent_runners.reset_mock()
+ mock_need_batch_stats.reset_mock()
+ mock_get_state_for_frontend.reset_mock()
+
+ # test need batch stats and runner done
+ mock_need_batch_stats.return_value = True
+ mock_get_state_for_frontend.return_value = WorkflowExternalState.COMPLETED
+ mock_get_recent_runners.return_value = [SchedulerRunner(status=RunnerStatus.DONE.value)]
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_running_id)
+ dataset_job_stage.set_context(dataset_pb2.DatasetJobStageContext(batch_stats_item_name='123'))
+ # pylint: disable=protected-access
+ executor._process_running_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.SUCCEEDED)
+
+ mock_get_recent_runners.reset_mock()
+ mock_need_batch_stats.reset_mock()
+ mock_get_state_for_frontend.reset_mock()
+
+ # test need batch stats and runner failed
+ mock_need_batch_stats.return_value = True
+ mock_get_state_for_frontend.return_value = WorkflowExternalState.COMPLETED
+ mock_get_recent_runners.return_value = [SchedulerRunner(status=RunnerStatus.FAILED.value)]
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_running_id)
+ dataset_job_stage.set_context(dataset_pb2.DatasetJobStageContext(batch_stats_item_name='123'))
+ # pylint: disable=protected-access
+ executor._process_running_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.SUCCEEDED)
+ mock_get_recent_runners.assert_called_once_with('123', count=1)
+ batch = session.query(DataBatch).get(self._DATA_BATCH_WITH_EVENT_TIME_ID)
+ self.assertEqual(batch.file_size, ERROR_BATCH_SIZE)
+
+ mock_get_recent_runners.reset_mock()
+ mock_need_batch_stats.reset_mock()
+ mock_get_state_for_frontend.reset_mock()
+
+ # test need batch stats and runner running
+ mock_need_batch_stats.return_value = True
+ mock_get_state_for_frontend.return_value = WorkflowExternalState.COMPLETED
+ mock_get_recent_runners.return_value = [SchedulerRunner(status=RunnerStatus.RUNNING.value)]
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_running_id)
+ dataset_job_stage.set_context(dataset_pb2.DatasetJobStageContext(batch_stats_item_name='123'))
+ # pylint: disable=protected-access
+ executor._process_running_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.RUNNING)
+
+ mock_get_recent_runners.reset_mock()
+ mock_need_batch_stats.reset_mock()
+ mock_get_state_for_frontend.reset_mock()
+
+ # test no runner
+ mock_need_batch_stats.return_value = True
+ mock_get_state_for_frontend.return_value = WorkflowExternalState.COMPLETED
+ mock_get_recent_runners.return_value = []
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_running_id)
+ # pylint: disable=protected-access
+ executor._process_running_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.RUNNING)
+ runner_input = RunnerInput(batch_stats_input=BatchStatsInput(batch_id=self._DATA_BATCH_WITH_EVENT_TIME_ID))
+ mock_collect_v2.assert_called_once_with(
+ name=f'batch_stats_{self._DATA_BATCH_WITH_EVENT_TIME_ID}_{dataset_job_stage_running_id}',
+ items=[(ItemType.BATCH_STATS, runner_input)])
+
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.'\
+ 'RunningDatasetJobStageExecutor._create_transaction')
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.'\
+ 'RunningDatasetJobStageExecutor._delete_side_output')
+ def test_process_succeeded_dataset_job_stage(self, mock_create_transaction: MagicMock,
+ mock_delete_side_output: MagicMock):
+ dataset_job_stage_succeeded_id = 1
+ self._insert_streaming_dataset_job_and_stage(DatasetJobState.SUCCEEDED, dataset_job_stage_succeeded_id)
+ executor = RunningDatasetJobStageExecutor()
+
+ # test no need publish
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_succeeded_id)
+ dataset_job_stage.dataset_job.output_dataset.dataset_kind = DatasetKindV2.RAW
+ # pylint: disable=protected-access
+ executor._process_succeeded_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertFalse(dataset_job_stage.dataset_job.output_dataset.is_published)
+ mock_create_transaction.assert_called_once()
+ mock_delete_side_output.assert_called_once()
+
+ mock_create_transaction.reset_mock()
+ mock_delete_side_output.reset_mock()
+
+ # test need publish
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_succeeded_id)
+ dataset_job_stage.dataset_job.output_dataset.dataset_kind = DatasetKindV2.RAW
+ dataset_job_stage.dataset_job.output_dataset.set_meta_info(dataset_pb2.DatasetMetaInfo(need_publish=True))
+ # pylint: disable=protected-access
+ executor._process_succeeded_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ self.assertTrue(dataset_job_stage.dataset_job.output_dataset.is_published)
+ self.assertFalse(dataset_job_stage.dataset_job.output_dataset.get_meta_info().need_publish)
+ mock_create_transaction.assert_called_once()
+ mock_delete_side_output.assert_called_once()
+
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.'\
+ 'RunningDatasetJobStageExecutor._delete_side_output')
+ def test_process_failed_dataset_job_stage(self, mock_delete_side_output: MagicMock):
+ dataset_job_stage_failed_id = 1
+ self._insert_streaming_dataset_job_and_stage(DatasetJobState.FAILED, dataset_job_stage_failed_id)
+ executor = RunningDatasetJobStageExecutor()
+
+ # test failed dataset_job_stage
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_failed_id)
+ # pylint: disable=protected-access
+ executor._process_failed_dataset_job_stage(session=session, dataset_job_stage=dataset_job_stage)
+ mock_delete_side_output.assert_called_once()
+
+ @patch('fedlearner_webconsole.project.models.Project.get_storage_root_path')
+ @patch('fedlearner_webconsole.cleanup.services.CleanupService.create_cleanup')
+ @patch('fedlearner_webconsole.dataset.scheduler.running_dataset_job_stage_executor.now',
+ lambda: datetime(2022, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc))
+ def test_delete_side_output(self, cleanup_mock: MagicMock, mock_get_storage_root_path: MagicMock):
+ dataset_job_stage_succeeded_id = 1
+ self._insert_streaming_dataset_job_and_stage(DatasetJobState.SUCCEEDED, dataset_job_stage_succeeded_id)
+ executor = RunningDatasetJobStageExecutor()
+ mock_get_storage_root_path.return_value = '/data'
+
+ # test normal dataset_job
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_succeeded_id)
+ # pylint: disable=protected-access
+ executor._delete_side_output(session=session, dataset_job_stage=dataset_job_stage)
+ payload = CleanupPayload(paths=['/data/dataset/321/side_output/20220101'])
+ cleanup_parmeter = CleanupParameter(resource_id=dataset_job_stage.id,
+ resource_type='DATASET_JOB_STAGE',
+ payload=payload,
+ target_start_at=to_timestamp(
+ datetime(2022, 1, 2, 0, 0, 0, 0, tzinfo=timezone.utc)))
+ cleanup_mock.assert_called_with(cleanup_parmeter=cleanup_parmeter)
+ # test rsa_psi dataset_job
+ cleanup_mock.reset_mock()
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_succeeded_id)
+ dataset_job_stage.dataset_job.kind = DatasetJobKind.RSA_PSI_DATA_JOIN
+ # pylint: disable=protected-access
+ executor._delete_side_output(session=session, dataset_job_stage=dataset_job_stage)
+ payload = CleanupPayload(paths=[
+ '/data/dataset/321/side_output/20220101',
+ '/data/raw_data/workflow_uuid-raw-data-job',
+ '/data/data_source/workflow_uuid-psi-data-join-job/psi_output',
+ ])
+ cleanup_parmeter = CleanupParameter(resource_id=dataset_job_stage.id,
+ resource_type='DATASET_JOB_STAGE',
+ payload=payload,
+ target_start_at=to_timestamp(
+ datetime(2022, 1, 2, 0, 0, 0, 0, tzinfo=timezone.utc)))
+ cleanup_mock.assert_called_with(cleanup_parmeter=cleanup_parmeter)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/update_auth_status_executor.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/update_auth_status_executor.py
new file mode 100644
index 000000000..e5e27374c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/update_auth_status_executor.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+from sqlalchemy import or_
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.controllers import DatasetJobController
+from fedlearner_webconsole.dataset.services import DatasetService
+from fedlearner_webconsole.dataset.scheduler.base_executor import BaseExecutor
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.models import DATASET_JOB_FINISHED_STATE, Dataset, DatasetJob, \
+ DatasetJobSchedulerState
+
+
+class UpdateAuthStatusExecutor(BaseExecutor):
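+    """Refreshes the cached participant auth status for datasets whose parent dataset_job is still in progress."""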
+
+ def get_item_ids(self) -> List[int]:
+ with db.session_scope() as session:
+ datasets = DatasetService(session=session).query_dataset_with_parent_job().filter(
+ Dataset.participants_info.isnot(None)).filter(
+ or_(DatasetJob.state.not_in(DATASET_JOB_FINISHED_STATE),
+ DatasetJob.scheduler_state != DatasetJobSchedulerState.STOPPED)).all()
+ return [dataset.id for dataset in datasets]
+
+ def run_item(self, item_id: int) -> ExecutorResult:
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(item_id)
+            # if all participants are already authorized in the cache, just skip
+ if dataset.is_all_participants_authorized():
+ return ExecutorResult.SKIP
+ DatasetJobController(session=session).update_auth_status_cache(dataset_job=dataset.parent_dataset_job)
+ session.commit()
+ return ExecutorResult.SUCCEEDED
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/update_auth_status_executor_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/update_auth_status_executor_test.py
new file mode 100644
index 000000000..e5088e886
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/scheduler/update_auth_status_executor_test.py
@@ -0,0 +1,110 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.scheduler.consts import ExecutorResult
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetJobSchedulerState, \
+ DatasetJobState, DatasetKindV2, DatasetType
+from fedlearner_webconsole.dataset.scheduler.update_auth_status_executor import UpdateAuthStatusExecutor
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo, ParticipantsInfo
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+
+
+class UpdateAuthStatusExecutorTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _WORKFLOW_ID = 1
+ _INPUT_DATASET_ID = 1
+ _OUTPUT_DATASET_ID = 2
+ _OUTPUT_DATASET_2_ID = 3
+
+ def setUp(self) -> None:
+ super().setUp()
+ with db.session_scope() as session:
+ dataset_job_1 = DatasetJob(id=1,
+ uuid='dataset_job_1 uuid',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_ID,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.SUCCEEDED,
+ coordinator_id=0,
+ workflow_id=0,
+ scheduler_state=DatasetJobSchedulerState.STOPPED)
+ session.add(dataset_job_1)
+ output_dataset_1 = Dataset(id=self._OUTPUT_DATASET_ID,
+ uuid='dataset_1 uuid',
+ name='default dataset_1',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ is_published=True)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'coordinator-domain-name': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'participant-domain-name': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ output_dataset_1.set_participants_info(participants_info=participants_info)
+ session.add(output_dataset_1)
+ dataset_job_2 = DatasetJob(id=2,
+ uuid='dataset_job_2 uuid',
+ project_id=self._PROJECT_ID,
+ input_dataset_id=self._INPUT_DATASET_ID,
+ output_dataset_id=self._OUTPUT_DATASET_2_ID,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=0,
+ workflow_id=0,
+ scheduler_state=DatasetJobSchedulerState.STOPPED)
+ session.add(dataset_job_2)
+ output_dataset_2 = Dataset(id=self._OUTPUT_DATASET_2_ID,
+ uuid='dataset_2 uuid',
+ name='default dataset_1',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=self._PROJECT_ID,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ is_published=True)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'coordinator-domain-name': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'participant-domain-name': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ output_dataset_2.set_participants_info(participants_info=participants_info)
+ session.add(output_dataset_2)
+ session.commit()
+
+ def test_get_item_ids(self):
+ update_auth_status_executor = UpdateAuthStatusExecutor()
+ self.assertEqual(update_auth_status_executor.get_item_ids(), [3])
+
+ @patch('fedlearner_webconsole.dataset.controllers.DatasetJobController.update_auth_status_cache')
+ def test_run_item(self, mock_update_auth_status_cache: MagicMock):
+
+ update_auth_status_executor = UpdateAuthStatusExecutor()
+ with db.session_scope() as session:
+ executor_result = update_auth_status_executor.run_item(3)
+ self.assertEqual(executor_result, ExecutorResult.SUCCEEDED)
+ mock_update_auth_status_cache.assert_called_once()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/services.py b/web_console_v2/api/fedlearner_webconsole/dataset/services.py
index a21b7cbeb..561bc4ed7 100644
--- a/web_console_v2/api/fedlearner_webconsole/dataset/services.py
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/services.py
@@ -1,124 +1,717 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
-# coding: utf-8
import json
import logging
-from typing import List
-
-from sqlalchemy.orm import Session
+import os
+from datetime import datetime, timedelta
+from typing import List, Optional, Tuple, Union
+from sqlalchemy import and_, or_
+from sqlalchemy.orm import Session, joinedload, Query
-from fedlearner_webconsole.dataset.models import Dataset
-from fedlearner_webconsole.dataset.sparkapp.pipeline.util import \
- dataset_meta_path, dataset_features_path, dataset_hist_path
-from fedlearner_webconsole.exceptions import NotFoundException
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.flask_utils import get_current_user
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.workflow import fill_variables
+from fedlearner_webconsole.utils.pp_datetime import from_timestamp, to_timestamp, now
from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.exceptions import (InvalidArgumentException, NotFoundException, MethodNotAllowedException,
+ ResourceConflictException)
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.dataset.models import (DATASET_JOB_FINISHED_STATE, DATASET_STATE_CONVERT_MAP_V2,
+ LOCAL_DATASET_JOBS, MICRO_DATASET_JOB, DatasetFormat, ResourceState,
+ DatasetJobKind, DatasetJobStage, DatasetJobState, DatasetKindV2,
+ StoreFormat, DatasetType, Dataset, ImportType, DataBatch, DatasetJob,
+ DataSource, ProcessedDataset, DatasetJobSchedulerState)
+from fedlearner_webconsole.dataset.meta_data import MetaData, ImageMetaData
+from fedlearner_webconsole.dataset.delete_dependency import DatasetDeleteDependency
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from fedlearner_webconsole.dataset.job_configer.dataset_job_configer import DatasetJobConfiger
+from fedlearner_webconsole.dataset.filter_funcs import dataset_format_filter_op_equal, dataset_format_filter_op_in
+from fedlearner_webconsole.dataset.util import get_dataset_path, parse_event_time_to_daily_folder_name, \
+ parse_event_time_to_hourly_folder_name
+from fedlearner_webconsole.dataset.metrics import emit_dataset_job_submission_store, emit_dataset_job_duration_store
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupParameter, CleanupPayload
+from fedlearner_webconsole.proto.dataset_pb2 import CronType, DatasetJobGlobalConfigs
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterOp
+from fedlearner_webconsole.proto.review_pb2 import TicketDetails, TicketType
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.cleanup.models import ResourceType
+from fedlearner_webconsole.cleanup.services import CleanupService
+
+
+class DataReader(object):
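+    """Reads batch-level meta files produced by the Spark analyzer pipeline."""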
+
+ def __init__(self, dataset_path: str):
+ self._path = dataset_path
+ self._dataset_directory = DatasetDirectory(dataset_path=dataset_path)
+ self._file_manager = FileManager()
+
+ # meta is generated from sparkapp/pipeline/analyzer.py
+ def metadata(self, batch_name: str) -> MetaData:
+ meta_path = self._dataset_directory.batch_meta_file(batch_name=batch_name)
+ try:
+ return MetaData(json.loads(self._file_manager.read(meta_path)))
+ except Exception as e: # pylint: disable=broad-except
+ logging.info(f'failed to read meta file, path: {meta_path}, err: {e}')
+ return MetaData()
+
+ def image_metadata(self, thumbnail_dir_path: str, batch_name: str) -> ImageMetaData:
+ meta_path = self._dataset_directory.batch_meta_file(batch_name=batch_name)
+ try:
+ return ImageMetaData(thumbnail_dir_path, json.loads(self._file_manager.read(meta_path)))
+ except Exception as e: # pylint: disable=broad-except
+ logging.info(f'failed to read meta file, path: {meta_path}, err: {e}')
+ return ImageMetaData(thumbnail_dir_path)
class DatasetService(object):
+
+ DATASET_CLEANUP_DEFAULT_DELAY = timedelta(days=7)
+ PUBLISHED_DATASET_FILTER_FIELDS = {
+ 'uuid':
+ SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'kind':
+ SupportedField(type=FieldType.STRING, ops={
+ FilterOp.IN: None,
+ FilterOp.EQUAL: None
+ }),
+ 'dataset_format':
+ SupportedField(type=FieldType.STRING,
+ ops={
+ FilterOp.IN: dataset_format_filter_op_in,
+ FilterOp.EQUAL: dataset_format_filter_op_equal
+ }),
+ }
+
def __init__(self, session: Session):
self._session = session
self._file_manager = FileManager()
+ self._published_dataset_filter_builder = FilterBuilder(model_class=Dataset,
+ supported_fields=self.PUBLISHED_DATASET_FILTER_FIELDS)
- def get_dataset_preview(self, dataset_id: int = 0) -> dict:
- dataset = self._session.query(Dataset).filter(
- Dataset.id == dataset_id).first()
+ @staticmethod
+ def filter_dataset_state(query: Query, frontend_states: List[ResourceState]) -> Query:
+ if len(frontend_states) == 0:
+ return query
+ dataset_job_states = []
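+        # map the requested frontend states back to the underlying dataset_job states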
+ for k, v in DATASET_STATE_CONVERT_MAP_V2.items():
+ if v in frontend_states:
+ dataset_job_states.append(k)
+ state_filter = DatasetJob.state.in_(dataset_job_states)
+        # internal_processed datasets are currently hard-coded as succeeded,
+        # so include all internal_processed datasets when filtering for succeeded datasets
+ if ResourceState.SUCCEEDED in frontend_states:
+ state_filter = or_(state_filter, Dataset.dataset_kind == DatasetKindV2.INTERNAL_PROCESSED)
+ return query.filter(state_filter)
+
+ def query_dataset_with_parent_job(self) -> Query:
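+        # join each dataset with the dataset_job that produced it,
+        # excluding jobs whose input dataset is the same as the output dataset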
+ return self._session.query(Dataset).outerjoin(
+ DatasetJob, and_(DatasetJob.output_dataset_id == Dataset.id, DatasetJob.input_dataset_id != Dataset.id))
+
+ def create_dataset(self, dataset_parameter: dataset_pb2.DatasetParameter) -> Dataset:
+        # check project existence
+ project = self._session.query(Project).get(dataset_parameter.project_id)
+ if project is None:
+            raise NotFoundException(message=f'cannot find project with id: {dataset_parameter.project_id}')
+
+ # Create dataset
+ dataset = Dataset(
+ name=dataset_parameter.name,
+ uuid=dataset_parameter.uuid or resource_uuid(),
+ is_published=dataset_parameter.is_published,
+ dataset_type=DatasetType(dataset_parameter.type),
+ comment=dataset_parameter.comment,
+ project_id=dataset_parameter.project_id,
+ dataset_kind=DatasetKindV2(dataset_parameter.kind),
+ dataset_format=DatasetFormat[dataset_parameter.format].value,
+            # set the participant dataset's creator_username to empty if the dataset is created by the coordinator
+            # TODO(liuhehan): set the participant dataset's creator_username to the user who authorized it
+ creator_username=get_current_user().username if get_current_user() else '',
+ )
+ if dataset_parameter.path and dataset.dataset_kind in [
+ DatasetKindV2.EXPORTED, DatasetKindV2.INTERNAL_PROCESSED
+ ]:
+ dataset.path = dataset_parameter.path
+ else:
+ dataset.path = get_dataset_path(dataset_name=dataset.name, uuid=dataset.uuid)
+ if dataset_parameter.import_type:
+ dataset.import_type = ImportType(dataset_parameter.import_type)
+ if dataset_parameter.store_format:
+ dataset.store_format = StoreFormat(dataset_parameter.store_format)
+ if dataset_parameter.auth_status:
+ dataset.auth_status = AuthStatus[dataset_parameter.auth_status]
+ if dataset_parameter.creator_username:
+ dataset.creator_username = dataset_parameter.creator_username
+ elif get_current_user():
+ dataset.creator_username = get_current_user().username
+ meta_info = dataset_pb2.DatasetMetaInfo(need_publish=dataset_parameter.need_publish,
+ value=dataset_parameter.value,
+ schema_checkers=dataset_parameter.schema_checkers)
+ dataset.set_meta_info(meta_info)
+ self._session.add(dataset)
+ return dataset
+
+ def get_dataset(self, dataset_id: int = 0) -> Union[dict, dataset_pb2.Dataset]:
+ dataset = self._session.query(Dataset).with_polymorphic([ProcessedDataset,
+ Dataset]).filter(Dataset.id == dataset_id).first()
+ if not dataset:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ return dataset.to_proto()
+
+ def get_dataset_preview(self, dataset_id: int, batch_id: int) -> dict:
+ batch = self._session.query(DataBatch).get(batch_id)
+ if batch is None:
+ raise NotFoundException(f'Failed to find data batch: {batch_id}')
+ dataset = self._session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise NotFoundException(f'Failed to find dataset: {dataset_id}')
- dataset_path = dataset.path
- # meta is generated from sparkapp/pipeline/analyzer.py
- meta_path = dataset_meta_path(dataset_path)
- # data format:
- # {
- # 'dtypes': {
- # 'f01': 'bigint'
- # },
- # 'samples': [
- # [1],
- # [0],
- # ],
- # 'metrics': {
- # 'f01': {
- # 'count': '2',
- # 'mean': '0.0015716767309123998',
- # 'stddev': '0.03961485047808605',
- # 'min': '0',
- # 'max': '1',
- # 'missing_count': '0'
- # }
- # }
- # }
+ reader = DataReader(dataset.path)
+ if dataset.is_image():
+ thumbnail_dir_path = DatasetDirectory(dataset_path=dataset.path).thumbnails_path(
+ batch_name=batch.batch_name)
+ meta = reader.image_metadata(thumbnail_dir_path=thumbnail_dir_path, batch_name=batch.batch_name)
+ else:
+ meta = reader.metadata(batch_name=batch.batch_name)
+ return meta.get_preview()
+
+ def feature_metrics(self, name: str, dataset_id: int, data_batch_id: int) -> dict:
+ dataset = self._session.query(Dataset).get(dataset_id)
+ if dataset is None:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ batch = self._session.query(DataBatch).get(data_batch_id)
+ if batch is None:
+ raise NotFoundException(f'Failed to find data batch: {data_batch_id}')
+ meta = DataReader(dataset.path).metadata(batch_name=batch.batch_name)
val = {}
- try:
- val = json.loads(self._file_manager.read(meta_path))
- except Exception as e: # pylint: disable=broad-except
- logging.info(
- f'failed to read meta file, path: {meta_path}, err: {e}')
- return {}
- # feature is generated from sparkapp/pipeline/analyzer.py
- feature_path = dataset_features_path(dataset_path)
- try:
- val['metrics'] = json.loads(self._file_manager.read(feature_path))
- except Exception as e: # pylint: disable=broad-except
- logging.info(
- f'failed to read feature file, path: {feature_path}, err: {e}')
+ val['name'] = name
+ val['metrics'] = meta.get_metrics_by_name(name)
+ val['hist'] = meta.get_hist_by_name(name)
return val
- def feature_metrics(self, name: str, dataset_id: int = 0) -> dict:
- dataset = self._session.query(Dataset).filter(
- Dataset.id == dataset_id).first()
+ def get_published_datasets(self,
+ project_id: int,
+ kind: Optional[DatasetJobKind] = None,
+ uuid: Optional[str] = None,
+ state: Optional[ResourceState] = None,
+ filter_exp: Optional[FilterExpression] = None,
+ time_range: Optional[timedelta] = None) -> List[dataset_pb2.ParticipantDatasetRef]:
+ query = self.query_dataset_with_parent_job()
+ query = query.options(joinedload(Dataset.data_batches))
+ query = query.filter(Dataset.project_id == project_id)
+ query = query.filter(Dataset.is_published.is_(True))
+ if kind is not None:
+ query = query.filter(Dataset.dataset_kind == kind)
+ if uuid is not None:
+ query = query.filter(Dataset.uuid == uuid)
+ if state is not None:
+ query = self.filter_dataset_state(query, frontend_states=[state])
+ if filter_exp is not None:
+ query = self._published_dataset_filter_builder.build_query(query, filter_exp)
+ if time_range:
+ query = query.filter(DatasetJob.time_range == time_range)
+ query = query.order_by(Dataset.id.desc())
+ datasets_ref = []
+ for dataset in query.all():
+ meta_info = dataset.get_meta_info()
+ dataset_ref = dataset_pb2.ParticipantDatasetRef(
+ uuid=dataset.uuid,
+ name=dataset.name,
+ format=DatasetFormat(dataset.dataset_format).name,
+ file_size=dataset.get_file_size(),
+ updated_at=to_timestamp(dataset.updated_at),
+ value=meta_info.value,
+ dataset_kind=dataset.dataset_kind.name,
+ dataset_type=dataset.dataset_type.name,
+ auth_status=dataset.auth_status.name if dataset.auth_status else '')
+ datasets_ref.append(dataset_ref)
+ return datasets_ref
+
+ def publish_dataset(self, dataset_id: int, value: int = 0) -> Dataset:
+ dataset: Dataset = self._session.query(Dataset).get(dataset_id)
if not dataset:
raise NotFoundException(f'Failed to find dataset: {dataset_id}')
- dataset_path = dataset.path
- feature_path = dataset_features_path(dataset_path)
- # data format:
- # {
- # 'name': 'f01',
- # 'metrics': {
- # 'count': '2',
- # 'mean': '0.0015716767309123998',
- # 'stddev': '0.03961485047808605',
- # 'min': '0',
- # 'max': '1',
- # 'missing_count': '0'
- # },
- # 'hist': {
- # 'x': [0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5,
- # 0.6000000000000001, 0.7000000000000001, 0.8, 0.9, 1],
- # 'y': [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
- # }
- # }
- val = {}
+ if dataset.dataset_kind != DatasetKindV2.RAW:
+ raise MethodNotAllowedException(
+                f'{dataset.dataset_kind.value} dataset cannot be published, dataset_id: {dataset.id}')
+ dataset.is_published = True
+ meta_info = dataset.get_meta_info()
+ meta_info.value = value
+ dataset.set_meta_info(meta_info)
+        # TODO(liuhehan): a hack to add a uuid for old datasets on publish, remove in the future
+ if dataset.uuid is None:
+ dataset.uuid = resource_uuid()
+
+ # create review ticket
+ if dataset.ticket_uuid is None:
+ ticket_helper = get_ticket_helper(session=self._session)
+ ticket_helper.create_ticket(TicketType.PUBLISH_DATASET, TicketDetails(uuid=dataset.uuid))
+
+ return dataset
+
+ def withdraw_dataset(self, dataset_id: int):
+ dataset = self._session.query(Dataset).get(dataset_id)
+ if not dataset:
+ raise NotFoundException(f'Failed to find dataset: {dataset_id}')
+ dataset.is_published = False
+
+ # reset ticket
+ dataset.ticket_uuid = None
+ dataset.ticket_status = None
+
+    def cleanup_dataset(self, dataset: Dataset, delay_time: Optional[timedelta] = None) -> None:
+ """ Register the dataset and underlying files to be cleaned with the cleanup module.
+
+ Args:
+            dataset: the dataset to clean up; the caller should hold an exclusive lock on this row
+            delay_time: delay before the cleanup task is started
+
+ Raises:
+ ResourceConflictException: if the `dataset` can not be deleted
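+
+ Example (illustrative sketch; `session` and `dataset_id` here are placeholders):
+ dataset = session.query(Dataset).with_for_update().populate_existing().get(dataset_id)
+ DatasetService(session).cleanup_dataset(dataset, delay_time=timedelta(hours=1))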
+ """
+ if not delay_time:
+ delay_time = self.DATASET_CLEANUP_DEFAULT_DELAY
+ target_start_at = to_timestamp(now() + delay_time)
+ is_deletable, error_msgs = DatasetDeleteDependency(self._session).is_deletable(dataset)
+ if not is_deletable:
+ error = {dataset.id: error_msgs}
+ raise ResourceConflictException(f'{error}')
+ logging.info(f'will mark dataset {dataset.id} as deleted')
+ payload = CleanupPayload(paths=[dataset.path])
+ dataset_cleanup_parm = CleanupParameter(resource_id=dataset.id,
+ resource_type=ResourceType.DATASET.name,
+ payload=payload,
+ target_start_at=target_start_at)
+ CleanupService(self._session).create_cleanup(cleanup_parmeter=dataset_cleanup_parm)
+ dataset.deleted_at = now()
+ logging.info(f'registered a cleanup for dataset {dataset.id}')
+
+ def get_data_batch(self, dataset: Dataset, event_time: Optional[datetime] = None) -> Optional[DataBatch]:
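+ # A PSI dataset has a single batch, so event_time is ignored;
+ # a streaming dataset's batch is looked up by event_time.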
+ if dataset.dataset_type == DatasetType.PSI:
+ return self._session.query(DataBatch).filter(DataBatch.dataset_id == dataset.id).first()
+ return self._session.query(DataBatch).filter(DataBatch.dataset_id == dataset.id).filter(
+ DataBatch.event_time == event_time).first()
+
+
+class DataSourceService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def create_data_source(self, data_source_parameter: dataset_pb2.DataSource) -> DataSource:
+ # check project existence
+ project = self._session.query(Project).get(data_source_parameter.project_id)
+ if project is None:
+ raise NotFoundException(message=f'cannot find project with id: {data_source_parameter.project_id}')
+
+ data_source = DataSource(
+ name=data_source_parameter.name,
+ comment=data_source_parameter.comment,
+ uuid=resource_uuid(),
+ is_published=False,
+ path=data_source_parameter.url,
+ project_id=data_source_parameter.project_id,
+ creator_username=get_current_user().username,
+ dataset_format=DatasetFormat[data_source_parameter.dataset_format].value,
+ store_format=StoreFormat(data_source_parameter.store_format),
+ dataset_type=DatasetType(data_source_parameter.dataset_type),
+ )
+ meta_info = dataset_pb2.DatasetMetaInfo(datasource_type=data_source_parameter.type,
+ is_user_upload=data_source_parameter.is_user_upload,
+ is_user_export=data_source_parameter.is_user_export)
+ data_source.set_meta_info(meta_info)
+ self._session.add(data_source)
+ return data_source
+
+ def get_data_sources(self, project_id: int) -> List[dataset_pb2.DataSource]:
+ data_sources = self._session.query(DataSource).order_by(DataSource.created_at.desc())
+ if project_id > 0:
+ data_sources = data_sources.filter_by(project_id=project_id)
+ data_source_ref = []
+ for data_source in data_sources.all():
+ # ignore user upload data_source and user export data_source
+ meta_info = data_source.get_meta_info()
+ if not meta_info.is_user_upload and not meta_info.is_user_export:
+ data_source_ref.append(data_source.to_proto())
+ return data_source_ref
+
+ def delete_data_source(self, data_source_id: int):
+ data_source = self._session.query(DataSource).get(data_source_id)
+ if not data_source:
+ raise NotFoundException(message=f'cannot find data_source with id: {data_source_id}')
+ dataset_jobs = self._session.query(DatasetJob).filter_by(input_dataset_id=data_source.id).all()
+ for dataset_job in dataset_jobs:
+ if not dataset_job.is_finished():
+ message = f'data_source {data_source.name} is still being processed by dataset_job {dataset_job.id}'
+ logging.error(message)
+ raise ResourceConflictException(message=message)
+
+ data_source.deleted_at = now()
+
+
+class BatchService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def create_batch(self, batch_parameter: dataset_pb2.BatchParameter) -> DataBatch:
+ dataset: Dataset = self._session.query(Dataset).filter_by(id=batch_parameter.dataset_id).first()
+ if dataset is None:
+ message = f'Failed to find dataset: {batch_parameter.dataset_id}'
+ logging.error(message)
+ raise NotFoundException(message=message)
+ if dataset.dataset_type == DatasetType.PSI:
+ # A PSI dataset has exactly one batch,
+ # so its batch path follows the convention `{dataset_path}/batch/0`.
+ if len(dataset.data_batches) != 0:
+ raise InvalidArgumentException(details='a PSI dataset can only have one batch')
+ batch_folder_name = '0'
+ event_time = None
+ elif dataset.dataset_type == DatasetType.STREAMING:
+ if batch_parameter.event_time == 0:
+ raise InvalidArgumentException(
+ details='event time should be specified when creating a batch of a streaming dataset')
+ event_time = from_timestamp(batch_parameter.event_time)
+ if batch_parameter.cron_type == CronType.DAILY:
+ batch_folder_name = parse_event_time_to_daily_folder_name(event_time=event_time)
+ elif batch_parameter.cron_type == CronType.HOURLY:
+ batch_folder_name = parse_event_time_to_hourly_folder_name(event_time=event_time)
+ else:
+ # old data may not have cron_type, so default to the daily folder name
+ batch_folder_name = parse_event_time_to_daily_folder_name(event_time=event_time)
+ batch_parameter.path = os.path.join(dataset.path, 'batch', batch_folder_name)
+ # Create batch
+ batch = DataBatch(dataset_id=dataset.id,
+ event_time=event_time,
+ comment=batch_parameter.comment,
+ path=batch_parameter.path,
+ name=batch_folder_name)
+ self._session.add(batch)
+
+ return batch
+
+ def get_next_batch(self, data_batch: DataBatch) -> Optional[DataBatch]:
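+ # For a cron dataset_job, the next batch is the one whose event_time equals
+ # this batch's event_time plus the dataset_job's time_range.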
+ parent_dataset_job_stage: DatasetJobStage = data_batch.latest_parent_dataset_job_stage
+ if not parent_dataset_job_stage:
+ logging.warning(f'parent_dataset_job_stage not found, data_batch id: {data_batch.id}')
+ return None
+ parent_dataset_job: DatasetJob = parent_dataset_job_stage.dataset_job
+ if not parent_dataset_job:
+ logging.warning(f'parent_dataset_job not found, data_batch id: {data_batch.id}')
+ return None
+ if not parent_dataset_job.is_cron():
+ logging.warning(f'data_batch {data_batch.id} belongs to a non-cron dataset_job and has no next batch')
+ return None
+ next_time = data_batch.event_time + parent_dataset_job.time_range
+ return self._session.query(DataBatch).filter(DataBatch.dataset_id == data_batch.dataset_id).filter(
+ DataBatch.event_time == next_time).first()
+
+
+class DatasetJobService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def is_local(self, dataset_job_kind: DatasetJobKind) -> bool:
+ return dataset_job_kind in LOCAL_DATASET_JOBS
+
+ def need_distribute(self, dataset_job: DatasetJob) -> bool:
+ # coordinator_id != 0 means it is a participant,
+ # and dataset_job need to distribute when it has participants
+ if dataset_job.coordinator_id != 0:
+ return True
+ return not self.is_local(dataset_job.kind)
+
+ # filter the participants to which the dataset_job needs to be distributed
+ def get_participants_need_distribute(self, dataset_job: DatasetJob) -> List:
+ participants = []
+ if self.need_distribute(dataset_job):
+ participants = ParticipantService(self._session).get_platform_participants_by_project(
+ dataset_job.project_id)
+ return participants
+
+ def create_as_coordinator(self,
+ project_id: int,
+ kind: DatasetJobKind,
+ output_dataset_id: int,
+ global_configs: DatasetJobGlobalConfigs,
+ time_range: Optional[timedelta] = None) -> DatasetJob:
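+ """Creates a dataset_job as the coordinator, auto-filling and validating the global configs."""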
+ my_domain_name = SettingService.get_system_info().pure_domain_name
+ input_dataset_uuid = global_configs.global_configs[my_domain_name].dataset_uuid
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == input_dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(f'failed to find dataset {input_dataset_uuid}')
+ output_dataset = self._session.query(Dataset).get(output_dataset_id)
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset id {output_dataset_id}')
+ configer = DatasetJobConfiger.from_kind(kind, self._session)
+ config = configer.get_config()
try:
- feature_data = json.loads(self._file_manager.read(feature_path))
- val['name'] = name
- val['metrics'] = feature_data.get(name, {})
- except Exception as e: # pylint: disable=broad-except
- logging.info(
- f'failed to read feature file, path: {feature_path}, err: {e}')
- # hist is generated from sparkapp/pipeline/analyzer.py
- hist_path = dataset_hist_path(dataset_path)
+ global_configs = configer.auto_config_variables(global_configs)
+ fill_variables(config, global_configs.global_configs[my_domain_name].variables, dry_run=True)
+ except TypeError as err:
+ raise InvalidArgumentException(details=err.args) from err
+
+ dataset_job = DatasetJob()
+ dataset_job.uuid = resource_uuid()
+ dataset_job.project_id = project_id
+ dataset_job.coordinator_id = 0
+ dataset_job.input_dataset_id = input_dataset.id
+ dataset_job.output_dataset_id = output_dataset_id
+ dataset_job.name = output_dataset.name
+ dataset_job.kind = kind
+ dataset_job.time_range = time_range
+ dataset_job.set_global_configs(global_configs)
+ dataset_job.set_context(dataset_pb2.DatasetJobContext(has_stages=True))
+ current_user = get_current_user()
+ if current_user is not None:
+ dataset_job.creator_username = current_user.username
+
+ self._session.add(dataset_job)
+
+ emit_dataset_job_submission_store(uuid=dataset_job.uuid, kind=dataset_job.kind, coordinator_id=0)
+
+ return dataset_job
+
+ def create_as_participant(self,
+ project_id: int,
+ kind: DatasetJobKind,
+ global_configs: DatasetJobGlobalConfigs,
+ config: WorkflowDefinition,
+ output_dataset_id: int,
+ coordinator_id: int,
+ uuid: str,
+ creator_username: str,
+ time_range: Optional[timedelta] = None) -> DatasetJob:
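+ """Creates a dataset_job as a participant, reusing the uuid and config distributed by the coordinator."""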
+ my_domain_name = SettingService.get_system_info().pure_domain_name
+ my_dataset_job_config = global_configs.global_configs[my_domain_name]
+
+ input_dataset = self._session.query(Dataset).filter(Dataset.uuid == my_dataset_job_config.dataset_uuid).first()
+ if input_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset {my_dataset_job_config.dataset_uuid}')
+ output_dataset = self._session.query(Dataset).get(output_dataset_id)
+ if output_dataset is None:
+ raise InvalidArgumentException(details=f'failed to find dataset id {output_dataset_id}')
try:
- hist_data = json.loads(self._file_manager.read(hist_path))
- val['hist'] = hist_data.get(name, {})
- except Exception as e: # pylint: disable=broad-except
- logging.info(
- f'failed to read hist file, path: {hist_path}, err: {e}')
- return val
+ fill_variables(config, my_dataset_job_config.variables, dry_run=True)
+ except TypeError as err:
+ raise InvalidArgumentException(details=err.args) from err
- def get_datasets(self, project_id: int = 0) -> List[Dataset]:
- q = self._session.query(Dataset).order_by(Dataset.created_at.desc())
- if project_id > 0:
- q = q.filter(Dataset.project_id == project_id)
- return q.all()
+ dataset_job = DatasetJob()
+ dataset_job.uuid = uuid
+ dataset_job.project_id = project_id
+ dataset_job.input_dataset_id = input_dataset.id
+ dataset_job.output_dataset_id = output_dataset_id
+ dataset_job.name = output_dataset.name
+ dataset_job.coordinator_id = coordinator_id
+ dataset_job.kind = kind
+ dataset_job.time_range = time_range
+ dataset_job.creator_username = creator_username
+ dataset_job.set_context(dataset_pb2.DatasetJobContext(has_stages=True))
+
+ self._session.add(dataset_job)
+
+ emit_dataset_job_submission_store(uuid=dataset_job.uuid,
+ kind=dataset_job.kind,
+ coordinator_id=dataset_job.coordinator_id)
+
+ return dataset_job
+
+ def start_dataset_job(self, dataset_job: DatasetJob):
+ dataset_job.state = DatasetJobState.RUNNING
+ dataset_job.started_at = now()
+
+ def finish_dataset_job(self, dataset_job: DatasetJob, finish_state: DatasetJobState):
+ if finish_state not in DATASET_JOB_FINISHED_STATE:
+ raise ValueError(f'got invalid finish state [{finish_state}] when trying to finish dataset_job')
+ dataset_job.state = finish_state
+ dataset_job.finished_at = now()
+ duration = to_timestamp(dataset_job.finished_at) - to_timestamp(dataset_job.created_at)
+ emit_dataset_job_duration_store(duration=duration,
+ uuid=dataset_job.uuid,
+ kind=dataset_job.kind,
+ coordinator_id=dataset_job.coordinator_id,
+ state=finish_state)
+
+ def start_cron_scheduler(self, dataset_job: DatasetJob):
+ if not dataset_job.is_cron():
+ logging.warning(f'[dataset_job_service]: cannot start scheduler for non-cron dataset_job {dataset_job.id}')
+ return
+ dataset_job.scheduler_state = DatasetJobSchedulerState.RUNNABLE
+
+ def stop_cron_scheduler(self, dataset_job: DatasetJob):
+ if not dataset_job.is_cron():
+ logging.warning(f'[dataset_job_service]: cannot stop scheduler for non-cron dataset_job {dataset_job.id}')
+ return
+ dataset_job.scheduler_state = DatasetJobSchedulerState.STOPPED
+
+ def delete_dataset_job(self, dataset_job: DatasetJob):
+ if not dataset_job.is_finished():
+ message = f'Failed to delete dataset_job: {dataset_job.id}; ' \
+ f'reason: dataset_job state is {dataset_job.state.name}'
+ logging.error(message)
+ raise ResourceConflictException(message)
+ dataset_job.deleted_at = now()
+
+
+class DatasetJobStageService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ # TODO(liuhehan): delete in the near future after we use as_coordinator func
+ def create_dataset_job_stage(self,
+ project_id: int,
+ dataset_job_id: int,
+ output_data_batch_id: int,
+ uuid: Optional[str] = None,
+ name: Optional[str] = None):
+ dataset_job: DatasetJob = self._session.query(DatasetJob).get(dataset_job_id)
+ if dataset_job is None:
+ raise InvalidArgumentException(details=f'failed to find dataset_job, id: {dataset_job_id}')
+ output_data_batch: DataBatch = self._session.query(DataBatch).get(output_data_batch_id)
+ if output_data_batch is None:
+ raise InvalidArgumentException(details=f'failed to find output_data_batch, id: {output_data_batch_id}')
+
+ dataset_job_stages: DatasetJobStage = self._session.query(DatasetJobStage).filter(
+ DatasetJobStage.data_batch_id == output_data_batch_id).filter(
+ DatasetJobStage.dataset_job_id == dataset_job_id).order_by(DatasetJobStage.created_at.desc()).all()
+ index = len(dataset_job_stages)
+ if index != 0 and not dataset_job_stages[0].is_finished():
+ raise InvalidArgumentException(
+ details=f'newest dataset_job_stage is still running, id: {dataset_job_stages[0].id}')
+
+ dataset_job_stage = DatasetJobStage()
+ dataset_job_stage.uuid = uuid or resource_uuid()
+ dataset_job_stage.name = name or f'{output_data_batch.name}-stage{index}'
+ dataset_job_stage.event_time = output_data_batch.event_time
+ dataset_job_stage.dataset_job_id = dataset_job_id
+ dataset_job_stage.data_batch_id = output_data_batch_id
+ dataset_job_stage.project_id = project_id
+ if dataset_job.coordinator_id == 0:
+ dataset_job_stage.set_global_configs(dataset_job.get_global_configs())
+ self._session.add(dataset_job_stage)
+
+ self._session.flush()
+ if dataset_job.kind not in MICRO_DATASET_JOB:
+ output_data_batch.latest_parent_dataset_job_stage_id = dataset_job_stage.id
+ elif dataset_job.kind == DatasetJobKind.ANALYZER:
+ output_data_batch.latest_analyzer_dataset_job_stage_id = dataset_job_stage.id
+
+ dataset_job.state = DatasetJobState.PENDING
+
+ return dataset_job_stage
+
+ def create_dataset_job_stage_as_coordinator(self, project_id: int, dataset_job_id: int, output_data_batch_id: int,
+ global_configs: DatasetJobGlobalConfigs):
+ dataset_job: DatasetJob = self._session.query(DatasetJob).get(dataset_job_id)
+ if dataset_job is None:
+ raise InvalidArgumentException(details=f'failed to find dataset_job, id: {dataset_job_id}')
+ output_data_batch: DataBatch = self._session.query(DataBatch).get(output_data_batch_id)
+ if output_data_batch is None:
+ raise InvalidArgumentException(details=f'failed to find output_data_batch, id: {output_data_batch_id}')
+
+ dataset_job_stages: DatasetJobStage = self._session.query(DatasetJobStage).filter(
+ DatasetJobStage.data_batch_id == output_data_batch_id).filter(
+ DatasetJobStage.dataset_job_id == dataset_job_id).order_by(DatasetJobStage.id.desc()).all()
+ index = len(dataset_job_stages)
+ if index != 0 and not dataset_job_stages[0].is_finished():
+ raise InvalidArgumentException(
+ details=f'newest dataset_job_stage is still running, id: {dataset_job_stages[0].id}')
+
+ dataset_job_stage = DatasetJobStage()
+ dataset_job_stage.uuid = resource_uuid()
+ dataset_job_stage.name = f'{output_data_batch.name}-stage{index}'
+ dataset_job_stage.event_time = output_data_batch.event_time
+ dataset_job_stage.dataset_job_id = dataset_job_id
+ dataset_job_stage.data_batch_id = output_data_batch_id
+ dataset_job_stage.project_id = project_id
+ dataset_job_stage.coordinator_id = 0
+ dataset_job_stage.set_global_configs(global_configs)
+ self._session.add(dataset_job_stage)
+
+ self._session.flush()
+ if dataset_job.kind not in MICRO_DATASET_JOB:
+ output_data_batch.latest_parent_dataset_job_stage_id = dataset_job_stage.id
+ elif dataset_job.kind == DatasetJobKind.ANALYZER:
+ output_data_batch.latest_analyzer_dataset_job_stage_id = dataset_job_stage.id
+
+ dataset_job.state = DatasetJobState.PENDING
+
+ return dataset_job_stage
+
+ def create_dataset_job_stage_as_participant(self, project_id: int, dataset_job_id: int, output_data_batch_id: int,
+ uuid: str, name: str, coordinator_id: int):
+ dataset_job: DatasetJob = self._session.query(DatasetJob).get(dataset_job_id)
+ if dataset_job is None:
+ raise InvalidArgumentException(details=f'failed to find dataset_job, id: {dataset_job_id}')
+ output_data_batch: DataBatch = self._session.query(DataBatch).get(output_data_batch_id)
+ if output_data_batch is None:
+ raise InvalidArgumentException(details=f'failed to find output_data_batch, id: {output_data_batch_id}')
+
+ dataset_job_stages: DatasetJobStage = self._session.query(DatasetJobStage).filter(
+ DatasetJobStage.data_batch_id == output_data_batch_id).filter(
+ DatasetJobStage.dataset_job_id == dataset_job_id).order_by(DatasetJobStage.id.desc()).all()
+ index = len(dataset_job_stages)
+ if index != 0 and not dataset_job_stages[0].is_finished():
+ raise InvalidArgumentException(
+ details=f'newest dataset_job_stage is still running, id: {dataset_job_stages[0].id}')
+
+ dataset_job_stage = DatasetJobStage()
+ dataset_job_stage.uuid = uuid
+ dataset_job_stage.name = name
+ dataset_job_stage.event_time = output_data_batch.event_time
+ dataset_job_stage.dataset_job_id = dataset_job_id
+ dataset_job_stage.data_batch_id = output_data_batch_id
+ dataset_job_stage.project_id = project_id
+ dataset_job_stage.coordinator_id = coordinator_id
+ self._session.add(dataset_job_stage)
+
+ self._session.flush()
+ if dataset_job.kind not in MICRO_DATASET_JOB:
+ output_data_batch.latest_parent_dataset_job_stage_id = dataset_job_stage.id
+ elif dataset_job.kind == DatasetJobKind.ANALYZER:
+ output_data_batch.latest_analyzer_dataset_job_stage_id = dataset_job_stage.id
+
+ dataset_job.state = DatasetJobState.PENDING
+
+ return dataset_job_stage
+
+ def start_dataset_job_stage(self, dataset_job_stage: DatasetJobStage):
+ dataset_job_stage.state = DatasetJobState.RUNNING
+ dataset_job_stage.started_at = now()
+
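+ # only the newest stage (by created_at) propagates its state to the parent dataset_job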
+ newest_job_stage_id, *_ = self._session.query(
+ DatasetJobStage.id).filter(DatasetJobStage.dataset_job_id == dataset_job_stage.dataset_job_id).order_by(
+ DatasetJobStage.created_at.desc()).first()
+ if newest_job_stage_id == dataset_job_stage.id:
+ dataset_job_stage.dataset_job.state = DatasetJobState.RUNNING
+
+ def finish_dataset_job_stage(self, dataset_job_stage: DatasetJobStage, finish_state: DatasetJobState):
+ if finish_state not in DATASET_JOB_FINISHED_STATE:
+ raise ValueError(f'got invalid finish state [{finish_state}] when trying to finish dataset_job')
+ dataset_job_stage.state = finish_state
+ dataset_job_stage.finished_at = now()
+
+ newest_job_stage_id, *_ = self._session.query(
+ DatasetJobStage.id).filter(DatasetJobStage.dataset_job_id == dataset_job_stage.dataset_job_id).order_by(
+ DatasetJobStage.created_at.desc()).first()
+ if newest_job_stage_id == dataset_job_stage.id:
+ dataset_job_stage.dataset_job.state = finish_state
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/services_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/services_test.py
new file mode 100644
index 000000000..4b92d69bc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/services_test.py
@@ -0,0 +1,1398 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import tempfile
+import unittest
+from unittest.mock import patch, MagicMock, ANY
+from datetime import datetime, timedelta, timezone
+from google.protobuf.struct_pb2 import Value
+from fedlearner_webconsole.dataset.dataset_directory import DatasetDirectory
+from pathlib import Path
+
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.dataset.meta_data import ImageMetaData, MetaData
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterExpressionKind, SimpleExpression
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition, WorkflowDefinition
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.exceptions import InvalidArgumentException, NotFoundException, ResourceConflictException
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.dataset.models import (DataSourceType, DatasetFormat, DatasetJob, DatasetJobKind,
+ DatasetJobStage, DataBatch, DatasetKindV2, DatasetType, Dataset,
+ DatasetJobState, DataSource, ImportType, ResourceState, StoreFormat,
+ DatasetJobSchedulerState)
+from fedlearner_webconsole.dataset.services import (BatchService, DataReader, DataSourceService, DatasetJobService,
+ DatasetJobStageService, DatasetService)
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp, now
+from fedlearner_webconsole.proto.cleanup_pb2 import CleanupParameter, CleanupPayload
+from testing.common import NoWebServerTestCase
+from testing.dataset import FakeDatasetJobConfiger
+from testing.fake_time_patcher import FakeTimePatcher
+
+
+class DatasetServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.source_dir = tempfile.mkdtemp()
+ self._set_common_dataset()
+
+ def _set_common_dataset(self):
+ with db.session_scope() as session:
+ dataset = Dataset(
+ name='default dataset1',
+ uuid='default uuid',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path=str(self.source_dir),
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ )
+ session.add(dataset)
+ session.commit()
+
+ @patch('envs.Envs.STORAGE_ROOT', '/data')
+ def test_create_dataset(self):
+ dataset_para = dataset_pb2.DatasetParameter(name='test',
+ uuid='fake_uuid',
+ type=DatasetType.PSI.value,
+ comment='this is a comment',
+ project_id=1,
+ kind=DatasetKindV2.RAW.value,
+ format=DatasetFormat.IMAGE.name,
+ is_published=False,
+ import_type=ImportType.NO_COPY.value,
+ store_format=StoreFormat.CSV.value)
+
+ with db.session_scope() as session:
+ with self.assertRaises(NotFoundException):
+ DatasetService(session).create_dataset(dataset_parameter=dataset_para)
+ session.commit()
+
+ with db.session_scope() as session:
+ project = Project()
+ session.add(project)
+ session.commit()
+ dataset_para.project_id = project.id
+
+ with db.session_scope() as session:
+ DatasetService(session).create_dataset(dataset_parameter=dataset_para)
+ session.commit()
+
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).filter(Dataset.name == 'test').one()
+ self.assertEqual(dataset.comment, 'this is a comment')
+ self.assertEqual(dataset.path, 'file:///data/dataset/fake_uuid_test')
+ self.assertEqual(dataset.is_published, False)
+ self.assertEqual(dataset.import_type, ImportType.NO_COPY)
+ self.assertEqual(dataset.store_format, StoreFormat.CSV)
+
+ dataset_para_published = dataset_pb2.DatasetParameter(name='test_published',
+ uuid='fake_uuid_published',
+ type=DatasetType.PSI.value,
+ comment='this is a comment',
+ project_id=1,
+ kind=DatasetKindV2.PROCESSED.value,
+ format=DatasetFormat.IMAGE.name,
+ is_published=True,
+ creator_username='fakeuser')
+
+ with db.session_scope() as session:
+ DatasetService(session).create_dataset(dataset_parameter=dataset_para_published)
+ session.commit()
+
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).filter(Dataset.name == 'test_published').one()
+ self.assertEqual(dataset.comment, 'this is a comment')
+ self.assertEqual(dataset.path, 'file:///data/dataset/fake_uuid_published_test-published')
+ self.assertEqual(dataset.is_published, True)
+ self.assertEqual(dataset.creator_username, 'fakeuser')
+
+ def test_publish_dataset(self):
+ with db.session_scope() as session:
+ unpublished_dataset = Dataset(id=11,
+ uuid='123',
+ name='none_published_dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=False,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value)
+ no_uuid_dataset = Dataset(id=12,
+ name='none_published_dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=False,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(unpublished_dataset)
+ session.add(no_uuid_dataset)
+ session.commit()
+ # test unpublish to publish
+ with db.session_scope() as session:
+ dataset = DatasetService(session=session).publish_dataset(dataset_id=11, value=100)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(11)
+ self.assertTrue(dataset.is_published)
+ self.assertEqual(dataset.get_meta_info().value, 100)
+ self.assertIsNotNone(dataset.ticket_uuid)
+ self.assertEqual(dataset.ticket_status, TicketStatus.APPROVED)
+ # test publish to publish
+ with db.session_scope() as session:
+ dataset = DatasetService(session=session).publish_dataset(dataset_id=11)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(11)
+ self.assertTrue(dataset.is_published)
+ self.assertIsNotNone(dataset.ticket_uuid)
+ # test unknown dataset
+ with db.session_scope() as session:
+ with self.assertRaises(NotFoundException):
+ DatasetService(session=session).publish_dataset(dataset_id=100)
+ # test no uuid dataset
+ with db.session_scope() as session:
+ dataset = DatasetService(session=session).publish_dataset(dataset_id=12)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(12)
+ self.assertIsNotNone(dataset.uuid)
+ self.assertIsNotNone(dataset.ticket_uuid)
+ self.assertEqual(dataset.ticket_status, TicketStatus.APPROVED)
+
+ def test_withdraw_dataset(self):
+ with db.session_scope() as session:
+ published_dataset = Dataset(id=10,
+ uuid='123',
+ name='published_dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=True,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value,
+ ticket_uuid='ticket_uuid',
+ ticket_status=TicketStatus.APPROVED)
+ session.add(published_dataset)
+ session.commit()
+ # test publish to unpublish
+ with db.session_scope() as session:
+ DatasetService(session=session).withdraw_dataset(dataset_id=10)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertFalse(dataset.is_published)
+ self.assertIsNone(dataset.ticket_uuid)
+ self.assertIsNone(dataset.ticket_status)
+ # test unpublish to unpublish
+ with db.session_scope() as session:
+ DatasetService(session=session).withdraw_dataset(dataset_id=10)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ self.assertFalse(dataset.is_published)
+ # test unknown dataset
+ with db.session_scope() as session:
+ with self.assertRaises(NotFoundException):
+ DatasetService(session=session).publish_dataset(dataset_id=100)
+
+ def test_get_published_datasets(self):
+ update_time = datetime(2012, 1, 14, 12, 0, 5)
+ with db.session_scope() as session:
+ dataset1 = Dataset(
+ id=10,
+ uuid='1',
+ name='dataset_1',
+ dataset_type=DatasetType.PSI,
+ project_id=1,
+ path='/data/dataset_1/',
+ is_published=True,
+ dataset_format=DatasetFormat.TABULAR.value,
+ updated_at=update_time,
+ dataset_kind=DatasetKindV2.RAW,
+ )
+ dataset_job1 = DatasetJob(id=10,
+ uuid=resource_uuid(),
+ input_dataset_id=0,
+ output_dataset_id=dataset1.id,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ deleted_at=datetime(2022, 1, 1))
+ session.add_all([dataset1, dataset_job1])
+ dataset2 = Dataset(
+ id=11,
+ uuid='2',
+ name='dataset_2',
+ dataset_type=DatasetType.PSI,
+ project_id=1,
+ path='/data/dataset_2/',
+ is_published=True,
+ dataset_format=DatasetFormat.TABULAR.value,
+ updated_at=update_time,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ )
+ dataset_job2 = DatasetJob(id=11,
+ uuid=resource_uuid(),
+ input_dataset_id=0,
+ output_dataset_id=dataset2.id,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ time_range=timedelta(days=1))
+ session.add_all([dataset2, dataset_job2])
+ data_source = DataSource(
+ id=12,
+ uuid='3',
+ name='dataset_3',
+ dataset_type=DatasetType.PSI,
+ project_id=1,
+ path='/data/dataset_2/',
+ is_published=True,
+ dataset_format=DatasetFormat.TABULAR.value,
+ updated_at=update_time,
+ )
+ session.add(data_source)
+ dataset4 = Dataset(
+ id=13,
+ uuid='4',
+ name='dataset_4',
+ dataset_type=DatasetType.PSI,
+ project_id=1,
+ path='/data/dataset_4/',
+ is_published=True,
+ dataset_format=DatasetFormat.TABULAR.value,
+ updated_at=update_time,
+ )
+ dataset_job4 = DatasetJob(id=13,
+ uuid=resource_uuid(),
+ input_dataset_id=0,
+ output_dataset_id=dataset4.id,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ project_id=1,
+ state=DatasetJobState.STOPPED)
+ session.add_all([dataset4, dataset_job4])
+ session.commit()
+ dataref_1 = dataset_pb2.ParticipantDatasetRef(uuid='1',
+ name='dataset_1',
+ format=DatasetFormat.TABULAR.name,
+ file_size=0,
+ updated_at=to_timestamp(update_time),
+ dataset_kind=DatasetKindV2.RAW.name,
+ dataset_type=DatasetType.PSI.name,
+ auth_status='PENDING')
+ dataref_2 = dataset_pb2.ParticipantDatasetRef(uuid='2',
+ name='dataset_2',
+ format=DatasetFormat.TABULAR.name,
+ file_size=0,
+ updated_at=to_timestamp(update_time),
+ dataset_kind=DatasetKindV2.PROCESSED.name,
+ dataset_type=DatasetType.PSI.name,
+ auth_status='PENDING')
+ with db.session_scope() as session:
+ dataset_service = DatasetService(session=session)
+ self.assertEqual(dataset_service.get_published_datasets(project_id=1, state=ResourceState.SUCCEEDED),
+ [dataref_2, dataref_1])
+ self.assertEqual(
+ dataset_service.get_published_datasets(project_id=1,
+ kind=DatasetKindV2.RAW,
+ state=ResourceState.SUCCEEDED), [dataref_1])
+ self.assertEqual(dataset_service.get_published_datasets(project_id=1, uuid='2'), [dataref_2])
+ filter_exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='uuid', string_value='2'))
+ self.assertEqual(dataset_service.get_published_datasets(project_id=1, filter_exp=filter_exp), [dataref_2])
+ self.assertEqual(
+ dataset_service.get_published_datasets(project_id=1,
+ state=ResourceState.SUCCEEDED,
+ time_range=timedelta(days=1)), [dataref_2])
+
+ @patch('fedlearner_webconsole.cleanup.services.CleanupService.create_cleanup')
+ def test_cleanup_dataset(self, cleanup_mock: MagicMock):
+ # create a test dataset
+ with db.session_scope() as session:
+ published_dataset = Dataset(id=333,
+ uuid='123',
+ name='published_dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=True,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(published_dataset)
+ session.commit()
+ # test cleanup failed case
+ cleanup_mock.reset_mock()
+ cleanup_mock.side_effect = Exception('fake-exception')
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).with_for_update().populate_existing().get(333)
+ service = DatasetService(session)
+ self.assertRaises(Exception, service.cleanup_dataset, dataset)
+ # test cleanup dataset success
+ cleanup_mock.side_effect = None
+ cleanup_mock.return_value = None
+ fake_time = datetime(2022, 4, 14, 0, 0, 0, 0, tzinfo=timezone.utc)
+ time_patcher = FakeTimePatcher()
+ time_patcher.start(fake_time)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).with_for_update().populate_existing().get(333)
+ service = DatasetService(session)
+ service.cleanup_dataset(dataset)
+ session.commit()
+ expected_target_start_at = to_timestamp(fake_time + DatasetService.DATASET_CLEANUP_DEFAULT_DELAY)
+ expected_payload = CleanupPayload(paths=['/data/dataset/123'])
+ expected_cleanup_param = CleanupParameter(resource_id=333,
+ resource_type='DATASET',
+ target_start_at=expected_target_start_at,
+ payload=expected_payload)
+ cleanup_mock.assert_called_with(cleanup_parmeter=expected_cleanup_param)
+ with db.session_scope() as session:
+ self.assertRaises(NotFoundException, service.get_dataset, 333)
+ time_patcher.stop()
+
+ def test_query_dataset_with_parent_job(self):
+ with db.session_scope() as session:
+ query = DatasetService(session).query_dataset_with_parent_job()
+ statement = self.generate_mysql_statement(query)
+ expected_statement = 'FROM datasets_v2 LEFT OUTER JOIN dataset_jobs_v2 ' \
+ 'ON dataset_jobs_v2.output_dataset_id = datasets_v2.id ' \
+ 'AND dataset_jobs_v2.input_dataset_id != datasets_v2.id'
+ self.assertTrue(expected_statement in statement)
+
+ @patch('fedlearner_webconsole.dataset.services.DataReader.metadata')
+ @patch('fedlearner_webconsole.dataset.services.DataReader.image_metadata')
+ def test_get_dataset_preview(self, mock_image_metadata: MagicMock, mock_metadata: MagicMock):
+ mock_metadata.return_value = MetaData()
+ mock_image_metadata.return_value = ImageMetaData(
+ thumbnail_dir_path='/data/dataset/123/meta/20220101/thumbnails')
+ with db.session_scope() as session:
+ dataset = Dataset(id=10,
+ uuid='dataset uuid',
+ name='dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=True,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='20220101',
+ path='/data/dataset/123/batch/20220101',
+ dataset_id=10,
+ event_time=datetime(year=2022, month=1, day=1))
+ session.add(data_batch)
+ session.commit()
+ with db.session_scope() as session:
+ dataset_service = DatasetService(session=session)
+ dataset_service.get_dataset_preview(dataset_id=10, batch_id=1)
+ mock_metadata.assert_called_once_with(batch_name='20220101')
+ dataset = session.query(Dataset).get(10)
+ dataset.dataset_format = DatasetFormat.IMAGE.value
+ session.flush()
+ dataset_service.get_dataset_preview(dataset_id=10, batch_id=1)
+ mock_image_metadata.assert_called_once_with(batch_name='20220101',
+ thumbnail_dir_path='/data/dataset/123/meta/20220101/thumbnails')
+
+ @patch('fedlearner_webconsole.dataset.services.DataReader.metadata')
+ def test_feature_metrics(self, mock_metadata: MagicMock):
+ mock_metadata.return_value = MetaData()
+ with db.session_scope() as session:
+ dataset = Dataset(id=10,
+ uuid='dataset uuid',
+ name='dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=True,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='20220101',
+ path='/data/dataset/123/batch/20220101',
+ dataset_id=10,
+ event_time=datetime(year=2022, month=1, day=1))
+ session.add(data_batch)
+ session.commit()
+ with db.session_scope() as session:
+ dataset_service = DatasetService(session=session)
+ val = dataset_service.feature_metrics(name='raw_id', dataset_id=10, data_batch_id=1)
+ expected_val = {
+ 'name': 'raw_id',
+ 'metrics': {},
+ 'hist': {},
+ }
+ self.assertEqual(val, expected_val)
+ mock_metadata.assert_called_once_with(batch_name='20220101')
+
+ def test_get_data_batch(self):
+ with db.session_scope() as session:
+ dataset = Dataset(id=10,
+ uuid='dataset uuid',
+ name='dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ is_published=True,
+ project_id=1,
+ dataset_format=DatasetFormat.TABULAR.value)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='20220101',
+ path='/data/dataset/123/batch/20220101',
+ dataset_id=10,
+ event_time=datetime(year=2022, month=1, day=1))
+ session.add(data_batch)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(10)
+ dataset_service = DatasetService(session=session)
+ data_batch = dataset_service.get_data_batch(dataset=dataset, event_time=datetime(2022, 1, 1))
+ self.assertEqual(data_batch.id, 1)
+ data_batch = dataset_service.get_data_batch(dataset=dataset, event_time=datetime(2022, 1, 2))
+ self.assertIsNone(data_batch)
+ dataset.dataset_type = DatasetType.PSI
+ data_batch = dataset_service.get_data_batch(dataset=dataset)
+ self.assertEqual(data_batch.id, 1)
+
+
+class DataSourceServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.source_dir = tempfile.mkdtemp()
+ self._set_default_project()
+ self._set_common_dataset()
+
+ def _set_default_project(self):
+ with db.session_scope() as session:
+ project = Project(id=1, name='default_project')
+ session.add(project)
+ session.commit()
+
+ def _set_common_dataset(self):
+ with db.session_scope() as session:
+ dataset = Dataset(
+ name='default dataset1',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path=str(self.source_dir),
+ project_id=1,
+ )
+ session.add(dataset)
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.services.get_current_user', lambda: User(id=1, username='xiaohang'))
+ def test_create_data_source(self):
+ data_source_parameter = dataset_pb2.DataSource(name='default data_source',
+ url='hdfs:///fack_url',
+ project_id=1,
+ type=DataSourceType.HDFS.value,
+ is_user_upload=False,
+ dataset_format=DatasetFormat.TABULAR.name,
+ store_format=StoreFormat.CSV.value,
+ dataset_type=DatasetType.PSI.value)
+ with db.session_scope() as session:
+ data_source = DataSourceService(session=session).create_data_source(
+ data_source_parameter=data_source_parameter)
+ session.commit()
+ with db.session_scope() as session:
+ data_source = session.query(DataSource).filter_by(name='default data_source').first()
+ self.assertEqual(data_source.name, data_source_parameter.name)
+ self.assertEqual(data_source.path, data_source_parameter.url)
+ self.assertEqual(data_source.project_id, data_source_parameter.project_id)
+ self.assertEqual(data_source.creator_username, 'xiaohang')
+ self.assertEqual(data_source.get_meta_info().datasource_type, data_source_parameter.type)
+ self.assertEqual(data_source.get_meta_info().is_user_upload, data_source_parameter.is_user_upload)
+ self.assertEqual(data_source.dataset_format, DatasetFormat.TABULAR.value)
+ self.assertEqual(data_source.store_format, StoreFormat.CSV)
+ self.assertEqual(data_source.dataset_type, DatasetType.PSI)
+ self.assertIsNotNone(data_source.id)
+ self.assertIsNotNone(data_source.created_at)
+
+ def test_get_data_sources(self):
+ with db.session_scope() as session:
+ datasource_1 = DataSource(id=100,
+ uuid='data_source_1_uuid',
+ name='datasource_1',
+ path='hdfs:///data/fake_path_1',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ is_published=False,
+ store_format=StoreFormat.TFRECORDS,
+ dataset_format=DatasetFormat.IMAGE.value,
+ dataset_type=DatasetType.STREAMING)
+ datasource_1.set_meta_info(meta=dataset_pb2.DatasetMetaInfo(
+ datasource_type=DataSourceType.HDFS.value, is_user_upload=False, is_user_export=False))
+ datasource_2 = DataSource(id=101,
+ uuid='data_source_2_uuid',
+ name='datasource_2',
+ path='file:///data/fake_path_2',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 6),
+ is_published=False,
+ store_format=StoreFormat.CSV,
+ dataset_format=DatasetFormat.TABULAR.value,
+ dataset_type=DatasetType.PSI)
+ datasource_2.set_meta_info(meta=dataset_pb2.DatasetMetaInfo(
+ datasource_type=DataSourceType.FILE.value, is_user_upload=False, is_user_export=False))
+ datasource_3 = DataSource(id=102,
+ uuid='data_source_3_uuid',
+ name='datasource_3',
+ path='/upload/fake_path_3',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ is_published=False)
+ datasource_3.set_meta_info(meta=dataset_pb2.DatasetMetaInfo(
+ datasource_type=DataSourceType.FILE.value, is_user_upload=True, is_user_export=False))
+ datasource_4 = DataSource(id=103,
+ uuid='data_source_4_uuid',
+ name='datasource_4',
+ path='/upload/fake_path_4',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 8),
+ is_published=False)
+ datasource_4.set_meta_info(meta=dataset_pb2.DatasetMetaInfo(
+ datasource_type=DataSourceType.FILE.value, is_user_upload=True, is_user_export=True))
+ session.add(datasource_1)
+ session.add(datasource_2)
+ session.add(datasource_3)
+ session.add(datasource_4)
+ session.commit()
+ with db.session_scope() as session:
+ expected_datasources = [
+ dataset_pb2.DataSource(id=101,
+ uuid='data_source_2_uuid',
+ name='datasource_2',
+ url='file:///data/fake_path_2',
+ project_id=1,
+ created_at=to_timestamp(datetime(2012, 1, 14, 12, 0, 6)),
+ type=DataSourceType.FILE.value,
+ is_user_upload=False,
+ dataset_format='TABULAR',
+ store_format='CSV',
+ dataset_type='PSI'),
+ dataset_pb2.DataSource(id=100,
+ uuid='data_source_1_uuid',
+ name='datasource_1',
+ url='hdfs:///data/fake_path_1',
+ project_id=1,
+ created_at=to_timestamp(datetime(2012, 1, 14, 12, 0, 5)),
+ type=DataSourceType.HDFS.value,
+ is_user_upload=False,
+ dataset_format='IMAGE',
+ store_format='TFRECORDS',
+ dataset_type='STREAMING')
+ ]
+ data_sources = DataSourceService(session=session).get_data_sources(project_id=1)
+ self.assertEqual(data_sources, expected_datasources)
+
+ def test_delete_data_source(self):
+ with db.session_scope() as session:
+ datasource = DataSource(id=100,
+ uuid=resource_uuid(),
+ name='datasource',
+ path='hdfs:///data/fake_path',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ is_published=False)
+ session.add(datasource)
+ session.commit()
+ with db.session_scope() as session:
+ DataSourceService(session=session).delete_data_source(data_source_id=100)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(DataSource).execution_options(include_deleted=True).get(100)
+ self.assertIsNotNone(dataset.deleted_at)
+ with self.assertRaises(NotFoundException):
+ DataSourceService(session=session).delete_data_source(data_source_id=102)
+
+ with db.session_scope() as session:
+ datasource = DataSource(id=101,
+ uuid=resource_uuid(),
+ name='datasource',
+ path='hdfs:///data/fake_path',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ is_published=False)
+ session.add(datasource)
+ dataset_job = DatasetJob(id=10,
+ uuid='test-uuid',
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=datasource.id,
+ output_dataset_id=2,
+ coordinator_id=1,
+ state=DatasetJobState.RUNNING)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ with self.assertRaises(ResourceConflictException):
+ DataSourceService(session=session).delete_data_source(data_source_id=101)
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(10)
+ dataset_job.state = DatasetJobState.SUCCEEDED
+ session.commit()
+ with db.session_scope() as session:
+ DataSourceService(session=session).delete_data_source(data_source_id=101)
+ session.commit()
+ with db.session_scope() as session:
+ dataset = session.query(DataSource).execution_options(include_deleted=True).get(101)
+ self.assertIsNotNone(dataset.deleted_at)
+
+
+class BatchServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test-project')
+ session.add(project)
+ session.commit()
+
+ def test_create_orphan_batch(self):
+ batch_para = dataset_pb2.BatchParameter(comment='this is a comment for batch',
+ path='/data/dataset/test/batch/batch_1')
+
+ with db.session_scope() as session:
+ with self.assertRaises(NotFoundException):
+ BatchService(session).create_batch(batch_parameter=batch_para)
+ session.commit()
+
+ @patch('envs.Envs.STORAGE_ROOT', '/data')
+ def test_create_psi_batch(self):
+ batch_para = dataset_pb2.BatchParameter()
+ with db.session_scope() as session:
+ dataset_para = dataset_pb2.DatasetParameter(name='test',
+ uuid='fake_uuid',
+ type=DatasetType.PSI.value,
+ comment='this is a comment',
+ project_id=1,
+ path='/data/dataset/test/',
+ kind=DatasetKindV2.EXPORTED.value,
+ format=DatasetFormat.IMAGE.name)
+ dataset = DatasetService(session).create_dataset(dataset_parameter=dataset_para)
+ session.commit()
+ batch_para.dataset_id = dataset.id
+
+ with db.session_scope() as session:
+ batch = BatchService(session).create_batch(batch_parameter=batch_para)
+ session.commit()
+ self.assertEqual(batch.dataset_id, batch_para.dataset_id)
+ self.assertIsNone(batch.event_time)
+ self.assertEqual(batch.path, '/data/dataset/test/batch/0')
+ self.assertEqual(batch.name, '0')
+
+ with db.session_scope() as session:
+ with self.assertRaises(InvalidArgumentException):
+ batch = BatchService(session).create_batch(batch_parameter=batch_para)
+ session.commit()
+
+ @patch('envs.Envs.STORAGE_ROOT', '/data')
+ def test_create_streaming_batch(self):
+ batch_para = dataset_pb2.BatchParameter()
+ with db.session_scope() as session:
+ dataset_para = dataset_pb2.DatasetParameter(name='test',
+ uuid='fake_uuid',
+ type=DatasetType.STREAMING.value,
+ comment='this is a comment',
+ project_id=1,
+ kind=DatasetKindV2.RAW.value,
+ format=DatasetFormat.IMAGE.name)
+ dataset = DatasetService(session).create_dataset(dataset_parameter=dataset_para)
+ session.commit()
+ batch_para.dataset_id = dataset.id
+
+ with db.session_scope() as session:
+ with self.assertRaises(InvalidArgumentException):
+ batch_para.event_time = 0
+ batch = BatchService(session).create_batch(batch_parameter=batch_para)
+ session.commit()
+
+ with db.session_scope() as session:
+ event_time = now()
+ batch_para.event_time = to_timestamp(event_time)
+ batch = BatchService(session).create_batch(batch_parameter=batch_para)
+ session.flush()
+ self.assertEqual(batch.dataset_id, batch_para.dataset_id)
+ self.assertEqual(batch.event_time, event_time.replace(microsecond=0))
+ self.assertEqual(batch.path, f'file:///data/dataset/fake_uuid_test/batch/{event_time.strftime("%Y%m%d")}')
+ self.assertEqual(batch.name, event_time.strftime('%Y%m%d'))
+
+ with db.session_scope() as session:
+ event_time = now()
+ batch_para.event_time = to_timestamp(event_time)
+ batch_para.cron_type = dataset_pb2.CronType.HOURLY
+ batch = BatchService(session).create_batch(batch_parameter=batch_para)
+ session.flush()
+ self.assertEqual(batch.dataset_id, batch_para.dataset_id)
+ self.assertEqual(batch.event_time, event_time.replace(microsecond=0))
+ self.assertEqual(batch.path,
+ f'file:///data/dataset/fake_uuid_test/batch/{event_time.strftime("%Y%m%d-%H")}')
+ self.assertEqual(batch.name, event_time.strftime('%Y%m%d-%H'))
+
+ def test_get_next_batch(self):
+ with db.session_scope() as session:
+ dataset = Dataset(id=1,
+ name='output dataset',
+ uuid=resource_uuid(),
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment2',
+ path='/data/dataset/123',
+ project_id=1)
+ session.add(dataset)
+
+ data_batch = DataBatch(id=1,
+ name='20220101-08',
+ dataset_id=1,
+ event_time=datetime(year=2000, month=1, day=1, hour=8),
+ latest_parent_dataset_job_stage_id=1)
+ session.add(data_batch)
+ session.commit()
+
+ # test no stage
+ with db.session_scope() as session:
+ data_batch = session.query(DataBatch).get(1)
+ self.assertIsNone(BatchService(session).get_next_batch(data_batch=data_batch))
+
+ # test no dataset_job
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStage(name='20220101-stage0',
+ dataset_job_id=1,
+ data_batch_id=1,
+ project_id=1,
+ event_time=datetime(year=2000, month=1, day=1),
+ uuid=resource_uuid())
+ session.add(dataset_job_stage)
+ session.commit()
+ with db.session_scope() as session:
+ data_batch = session.query(DataBatch).get(1)
+ self.assertIsNone(BatchService(session).get_next_batch(data_batch=data_batch))
+
+ # test no next batch
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='test-uuid',
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ workflow_id=0,
+ input_dataset_id=0,
+ output_dataset_id=1,
+ coordinator_id=0,
+ time_range=timedelta(hours=1))
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ data_batch = session.query(DataBatch).get(1)
+ self.assertIsNone(BatchService(session).get_next_batch(data_batch=data_batch))
+
+ # test get next batch
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=2,
+ name='20220102',
+ dataset_id=1,
+ event_time=datetime(year=2000, month=1, day=1, hour=9))
+ session.add(data_batch)
+ session.commit()
+ with db.session_scope() as session:
+ data_batch = session.query(DataBatch).get(1)
+ next_data_batch = session.query(DataBatch).get(2)
+ self.assertEqual(BatchService(session).get_next_batch(data_batch=data_batch), next_data_batch)
+
+
+class DatasetJobServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test-project')
+ session.add(project)
+ session.flush([project])
+
+ input_dataset = DataSource(name='input dataset',
+ uuid=resource_uuid(),
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=project.id)
+ session.add(input_dataset)
+
+ output_dataset = Dataset(name='output dataset',
+ uuid=resource_uuid(),
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=project.id)
+ session.add(output_dataset)
+
+ session.commit()
+ self.project_id = project.id
+ self.input_dataset_id = input_dataset.id
+ self.output_dataset_id = output_dataset.id
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobConfiger.from_kind',
+ lambda *args: FakeDatasetJobConfiger(None))
+ @patch('fedlearner_webconsole.dataset.services.get_current_user', lambda: User(id=1, username='test user'))
+ @patch('fedlearner_webconsole.dataset.services.emit_dataset_job_submission_store')
+ def test_create_dataset_job_as_coordinator(self, mock_emit_dataset_job_submission_store: MagicMock):
+ with db.session_scope() as session:
+ input_dataset = session.query(DataSource).get(self.input_dataset_id)
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs()
+ global_configs.global_configs['test_domain'].MergeFrom(
+ dataset_pb2.DatasetJobConfig(dataset_uuid=input_dataset.uuid,
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ dataset_job = DatasetJobService(session).create_as_coordinator(project_id=self.project_id,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ output_dataset_id=self.output_dataset_id,
+ global_configs=global_configs,
+ time_range=timedelta(days=1))
+ session.commit()
+ self.assertEqual(len(dataset_job.get_global_configs().global_configs['test_domain'].variables), 2)
+ self.assertEqual(dataset_job.name, 'output dataset')
+ self.assertTrue(dataset_job.get_context().has_stages)
+ self.assertEqual(dataset_job.time_range, timedelta(days=1))
+ self.assertEqual(dataset_job.creator_username, 'test user')
+ mock_emit_dataset_job_submission_store.assert_called_once_with(uuid=ANY,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ coordinator_id=0)
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(name='hahaha', domain_name='fl-test_domain.com', pure_domain_name='test_domain'))
+ @patch('fedlearner_webconsole.dataset.services.emit_dataset_job_submission_store')
+ def test_create_dataset_job_as_participant(self, mock_emit_dataset_job_submission_store: MagicMock):
+ with db.session_scope() as session:
+ input_dataset = session.query(Dataset).get(self.input_dataset_id)
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs()
+ global_configs.global_configs['test_domain'].MergeFrom(
+ dataset_pb2.DatasetJobConfig(dataset_uuid=input_dataset.uuid,
+ variables=[
+ Variable(name='hello',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=1)),
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ ]))
+ config = WorkflowDefinition(variables=[
+ Variable(name='hello', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=1))
+ ],
+ job_definitions=[
+ JobDefinition(variables=[
+ Variable(name='hello_from_job',
+ value_type=Variable.ValueType.NUMBER,
+ typed_value=Value(number_value=3))
+ ])
+ ])
+
+ dataset_job = DatasetJobService(session).create_as_participant(project_id=self.project_id,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ config=config,
+ output_dataset_id=self.output_dataset_id,
+ coordinator_id=1,
+ uuid='u12345',
+ global_configs=global_configs,
+ creator_username='test user')
+ session.commit()
+ self.assertTrue(dataset_job.get_context().has_stages)
+ mock_emit_dataset_job_submission_store.assert_called_once_with(uuid='u12345',
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ coordinator_id=1)
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).filter(DatasetJob.uuid == 'u12345').first()
+ self.assertIsNone(dataset_job.global_configs)
+ self.assertEqual(dataset_job.output_dataset_id, self.output_dataset_id)
+ self.assertEqual(dataset_job.kind, DatasetJobKind.IMPORT_SOURCE)
+ self.assertEqual(dataset_job.project_id, self.project_id)
+ self.assertEqual(dataset_job.coordinator_id, 1)
+ self.assertEqual(dataset_job.name, 'output dataset')
+ self.assertEqual(dataset_job.creator_username, 'test user')
+ self.assertIsNone(dataset_job.time_range)
+
+ # test with time_range
+ mock_emit_dataset_job_submission_store.reset_mock()
+ with db.session_scope() as session:
+ dataset_job = DatasetJobService(session).create_as_participant(project_id=self.project_id,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ config=config,
+ output_dataset_id=self.output_dataset_id,
+ coordinator_id=1,
+ uuid='u12345 with time_range',
+ global_configs=global_configs,
+ creator_username='test user',
+ time_range=timedelta(days=1))
+ session.commit()
+ self.assertTrue(dataset_job.get_context().has_stages)
+ mock_emit_dataset_job_submission_store.assert_called_once_with(uuid='u12345 with time_range',
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ coordinator_id=1)
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).filter(DatasetJob.uuid == 'u12345 with time_range').first()
+ self.assertEqual(dataset_job.time_range, timedelta(days=1))
+
+ def test_is_local(self):
+ with db.session_scope() as session:
+ service = DatasetJobService(session=session)
+ self.assertTrue(service.is_local(dataset_job_kind=DatasetJobKind.IMPORT_SOURCE))
+ self.assertFalse(service.is_local(dataset_job_kind=DatasetJobKind.DATA_ALIGNMENT))
+
+ def test_need_distribute(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='test-uuid',
+ kind=DatasetJobKind.EXPORT,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ coordinator_id=1)
+ service = DatasetJobService(session=session)
+ self.assertTrue(service.need_distribute(dataset_job=dataset_job))
+ dataset_job.coordinator_id = 0
+ self.assertFalse(service.need_distribute(dataset_job=dataset_job))
+ dataset_job.kind = DatasetJobKind.OT_PSI_DATA_JOIN
+ self.assertTrue(service.need_distribute(dataset_job=dataset_job))
+
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobService.need_distribute')
+ @patch('fedlearner_webconsole.dataset.services.ParticipantService.get_platform_participants_by_project')
+ def test_get_participants_need_distribute(self, mock_get_platform_participants_by_project: MagicMock,
+ mock_need_distribute: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='test-uuid',
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ coordinator_id=1)
+ service = DatasetJobService(session=session)
+
+ # test no need to distribute
+ mock_need_distribute.return_value = False
+ self.assertEqual(service.get_participants_need_distribute(dataset_job), [])
+
+            # test no platform participant
+ mock_need_distribute.return_value = True
+ mock_get_platform_participants_by_project.return_value = []
+ self.assertEqual(service.get_participants_need_distribute(dataset_job), [])
+
+ # test get platform participants
+ mock_need_distribute.return_value = True
+ mock_get_platform_participants_by_project.return_value = ['participants1', 'participants2']
+ self.assertEqual(service.get_participants_need_distribute(dataset_job), ['participants1', 'participants2'])
+
+ @patch('fedlearner_webconsole.dataset.services.now', lambda: datetime(2022, 1, 1, 12, 0, 0))
+ def test_start(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ coordinator_id=2,
+ uuid='uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING)
+ DatasetJobService(session).start_dataset_job(dataset_job)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.state, DatasetJobState.RUNNING)
+ self.assertEqual(dataset_job.started_at, datetime(2022, 1, 1, 12, 0, 0))
+
+ @patch('fedlearner_webconsole.dataset.services.now', lambda: datetime(2022, 1, 1, 12, 0, 0))
+ @patch('fedlearner_webconsole.dataset.services.emit_dataset_job_duration_store')
+ def test_finish(self, mock_emit_dataset_job_duration_store: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ coordinator_id=2,
+ uuid='uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2022, 1, 1, 11, 0, 0))
+ dataset_job_service = DatasetJobService(session)
+ with self.assertRaises(ValueError):
+ dataset_job_service.finish_dataset_job(dataset_job=dataset_job, finish_state=DatasetJobState.RUNNING)
+ self.assertEqual(dataset_job.state, DatasetJobState.PENDING)
+ self.assertIsNone(dataset_job.finished_at)
+ mock_emit_dataset_job_duration_store.assert_not_called()
+ dataset_job_service.finish_dataset_job(dataset_job=dataset_job, finish_state=DatasetJobState.SUCCEEDED)
+ mock_emit_dataset_job_duration_store.assert_called_once_with(duration=3600,
+ uuid='uuid',
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ coordinator_id=2,
+ state=DatasetJobState.SUCCEEDED)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.state, DatasetJobState.SUCCEEDED)
+ self.assertEqual(dataset_job.finished_at, datetime(2022, 1, 1, 12, 0, 0))
+
+ def test_start_cron_scheduler(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ coordinator_id=2,
+ uuid='uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2022, 1, 1, 11, 0, 0),
+ scheduler_state=DatasetJobSchedulerState.STOPPED)
+ DatasetJobService(session=session).start_cron_scheduler(dataset_job=dataset_job)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.STOPPED)
+ dataset_job.time_range = timedelta(days=1)
+ DatasetJobService(session=session).start_cron_scheduler(dataset_job=dataset_job)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.RUNNABLE)
+
+ def test_stop_cron_scheduler(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ coordinator_id=2,
+ uuid='uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ created_at=datetime(2022, 1, 1, 11, 0, 0),
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE)
+ DatasetJobService(session=session).stop_cron_scheduler(dataset_job=dataset_job)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.RUNNABLE)
+ dataset_job.time_range = timedelta(days=1)
+ DatasetJobService(session=session).stop_cron_scheduler(dataset_job=dataset_job)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.STOPPED)
+
+ def test_delete_dataset_job(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='test-uuid',
+ kind=DatasetJobKind.DATA_ALIGNMENT,
+ state=DatasetJobState.PENDING,
+ project_id=1,
+ workflow_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ coordinator_id=1)
+
+ # test dataset_job is not finished:
+ with self.assertRaises(ResourceConflictException):
+ DatasetJobService(session).delete_dataset_job(dataset_job=dataset_job)
+ self.assertIsNone(dataset_job.deleted_at)
+
+            # test delete dataset_job successfully
+ dataset_job.state = DatasetJobState.SUCCEEDED
+ DatasetJobService(session).delete_dataset_job(dataset_job=dataset_job)
+ self.assertIsNotNone(dataset_job.deleted_at)
+
+
+class DatasetJobStageServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='test-project')
+ session.add(project)
+ session.flush([project])
+
+ input_dataset = DataSource(name='input dataset',
+ uuid=resource_uuid(),
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/data_source/123',
+ project_id=project.id)
+ session.add(input_dataset)
+
+ output_dataset = Dataset(name='output dataset',
+ uuid=resource_uuid(),
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment2',
+ path='/data/dataset/123',
+ project_id=project.id)
+ session.add(output_dataset)
+ session.flush()
+
+ dataset_job = DatasetJob(uuid='test-uuid',
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ state=DatasetJobState.SUCCEEDED,
+ project_id=project.id,
+ workflow_id=0,
+ input_dataset_id=input_dataset.id,
+ output_dataset_id=output_dataset.id,
+ coordinator_id=0)
+ dataset_job.set_global_configs(
+ dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()}))
+ session.add(dataset_job)
+
+ output_data_batch = DataBatch(name='20220101',
+ dataset_id=output_dataset.id,
+ event_time=datetime(year=2000, month=1, day=1))
+ session.add(output_data_batch)
+ session.flush()
+
+ dataset_job_stage = DatasetJobStage(name='20220101-stage0',
+ dataset_job_id=dataset_job.id,
+ data_batch_id=output_data_batch.id,
+ project_id=project.id,
+ event_time=datetime(year=2000, month=1, day=1),
+ uuid=resource_uuid())
+ session.add(dataset_job_stage)
+
+ session.commit()
+ self.project_id = project.id
+ self.input_dataset_id = input_dataset.id
+ self.output_dataset_id = output_dataset.id
+ self.dataset_job_id = dataset_job.id
+ self.output_data_batch_id = output_data_batch.id
+ self.dataset_job_stage_id = dataset_job_stage.id
+
+ def test_create_dataset_job_stage(self):
+ with db.session_scope() as session:
+ with self.assertRaises(InvalidArgumentException):
+ dataset_job_stage = DatasetJobStageService(session).create_dataset_job_stage(
+ project_id=self.project_id,
+ dataset_job_id=self.dataset_job_id,
+ output_data_batch_id=self.output_data_batch_id)
+ dataset_job_stage = session.query(DatasetJobStage).get(self.dataset_job_stage_id)
+ dataset_job_stage.state = DatasetJobState.SUCCEEDED
+ session.commit()
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageService(session).create_dataset_job_stage(
+ project_id=self.project_id,
+ dataset_job_id=self.dataset_job_id,
+ output_data_batch_id=self.output_data_batch_id)
+ session.commit()
+ dataset_job_stage_id = dataset_job_stage.id
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_id)
+ self.assertEqual(dataset_job_stage.name, '20220101-stage1')
+ self.assertEqual(dataset_job_stage.event_time, datetime(year=2000, month=1, day=1))
+ self.assertEqual(dataset_job_stage.dataset_job_id, self.dataset_job_id)
+ self.assertEqual(dataset_job_stage.data_batch_id, self.output_data_batch_id)
+ self.assertEqual(dataset_job_stage.project_id, self.project_id)
+ dataset_job = session.query(DatasetJob).get(self.dataset_job_id)
+ self.assertEqual(dataset_job.state, DatasetJobState.PENDING)
+ data_batch = session.query(DataBatch).get(self.output_data_batch_id)
+ self.assertEqual(data_batch.latest_parent_dataset_job_stage_id, 2)
+
+ def test_create_dataset_job_stage_as_coordinator(self):
+ global_configs = dataset_pb2.DatasetJobGlobalConfigs(global_configs={'test': dataset_pb2.DatasetJobConfig()})
+ with db.session_scope() as session:
+ with self.assertRaises(InvalidArgumentException):
+ dataset_job_stage = DatasetJobStageService(session).create_dataset_job_stage_as_coordinator(
+ project_id=self.project_id,
+ dataset_job_id=self.dataset_job_id,
+ output_data_batch_id=self.output_data_batch_id,
+ global_configs=global_configs)
+ dataset_job_stage = session.query(DatasetJobStage).get(self.dataset_job_stage_id)
+ dataset_job_stage.state = DatasetJobState.SUCCEEDED
+ session.commit()
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageService(session).create_dataset_job_stage_as_coordinator(
+ project_id=self.project_id,
+ dataset_job_id=self.dataset_job_id,
+ output_data_batch_id=self.output_data_batch_id,
+ global_configs=global_configs)
+ session.commit()
+ dataset_job_stage_id = dataset_job_stage.id
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_id)
+ self.assertEqual(dataset_job_stage.name, '20220101-stage1')
+ self.assertEqual(dataset_job_stage.event_time, datetime(year=2000, month=1, day=1))
+ self.assertEqual(dataset_job_stage.dataset_job_id, self.dataset_job_id)
+ self.assertEqual(dataset_job_stage.data_batch_id, self.output_data_batch_id)
+ self.assertEqual(dataset_job_stage.project_id, self.project_id)
+ self.assertEqual(dataset_job_stage.get_global_configs(), global_configs)
+ self.assertTrue(dataset_job_stage.is_coordinator())
+ dataset_job = session.query(DatasetJob).get(self.dataset_job_id)
+ self.assertEqual(dataset_job.state, DatasetJobState.PENDING)
+ data_batch = session.query(DataBatch).get(self.output_data_batch_id)
+ self.assertEqual(data_batch.latest_parent_dataset_job_stage_id, 2)
+
+ def test_create_dataset_job_stage_as_participant(self):
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(self.dataset_job_stage_id)
+ dataset_job_stage.state = DatasetJobState.SUCCEEDED
+ session.commit()
+ with db.session_scope() as session:
+ dataset_job_stage = DatasetJobStageService(session).create_dataset_job_stage_as_participant(
+ project_id=self.project_id,
+ dataset_job_id=self.dataset_job_id,
+ output_data_batch_id=self.output_data_batch_id,
+ uuid='test dataset_job_stage uuid',
+ name='test dataset_job_stage',
+ coordinator_id=1)
+ session.commit()
+ dataset_job_stage_id = dataset_job_stage.id
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(dataset_job_stage_id)
+ self.assertEqual(dataset_job_stage.name, 'test dataset_job_stage')
+ self.assertEqual(dataset_job_stage.uuid, 'test dataset_job_stage uuid')
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+ self.assertEqual(dataset_job_stage.event_time, datetime(year=2000, month=1, day=1))
+ self.assertEqual(dataset_job_stage.dataset_job_id, self.dataset_job_id)
+ self.assertEqual(dataset_job_stage.data_batch_id, self.output_data_batch_id)
+ self.assertEqual(dataset_job_stage.project_id, self.project_id)
+ dataset_job = session.query(DatasetJob).get(self.dataset_job_id)
+ self.assertEqual(dataset_job.state, DatasetJobState.PENDING)
+ data_batch = session.query(DataBatch).get(self.output_data_batch_id)
+ self.assertEqual(data_batch.latest_parent_dataset_job_stage_id, 2)
+
+ def test_start_dataset_job_stage(self):
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(self.dataset_job_stage_id)
+ DatasetJobStageService(session).start_dataset_job_stage(dataset_job_stage)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.RUNNING)
+ self.assertEqual(dataset_job_stage.dataset_job.state, DatasetJobState.RUNNING)
+
+ def test_finish_dataset_job_stage(self):
+ with db.session_scope() as session:
+ dataset_job_stage = session.query(DatasetJobStage).get(self.dataset_job_stage_id)
+ DatasetJobStageService(session).finish_dataset_job_stage(dataset_job_stage, DatasetJobState.STOPPED)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.STOPPED)
+ self.assertEqual(dataset_job_stage.dataset_job.state, DatasetJobState.STOPPED)
+
+ with self.assertRaises(ValueError):
+ DatasetJobStageService(session).finish_dataset_job_stage(dataset_job_stage, DatasetJobState.RUNNING)
+
+
+class DataReaderTest(unittest.TestCase):
+
+ def test_metadata(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ dataset_path = f'{temp_dir}/dataset'
+ batch_name = '20220101'
+ meta_file = DatasetDirectory(dataset_path=dataset_path).batch_meta_file(batch_name=batch_name)
+
+ # test no meta
+ reader = DataReader(dataset_path=dataset_path).metadata(batch_name=batch_name)
+ self.assertEqual(reader.metadata, {})
+
+ # test get meta
+ meta_info = {
+ 'dtypes': [{
+ 'key': 'f01',
+ 'value': 'bigint'
+ }],
+ 'sample': [
+ [1],
+ [0],
+ ],
+ 'count': 1000,
+ 'metrics': {
+ 'f01': {
+ 'count': '2',
+ 'mean': '0.0015716767309123998',
+ 'stddev': '0.03961485047808605',
+ 'min': '0',
+ 'max': '1',
+ 'missing_count': '0'
+ }
+ },
+ 'hist': {
+ 'x': [
+ 0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6000000000000001, 0.7000000000000001, 0.8, 0.9,
+ 1
+ ],
+ 'y': [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
+ }
+ }
+ Path(meta_file.split('/_META')[0]).mkdir(parents=True)
+ with open(meta_file, 'w', encoding='utf-8') as f:
+ f.write(json.dumps(meta_info))
+ reader = DataReader(dataset_path=dataset_path).metadata(batch_name=batch_name)
+ self.assertEqual(reader.metadata, meta_info)
+
+ def test_image_metadata(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ dataset_path = f'{temp_dir}/dataset'
+ batch_name = '20220101'
+ dataset_directory = DatasetDirectory(dataset_path=dataset_path)
+ meta_file = dataset_directory.batch_meta_file(batch_name=batch_name)
+ thumbnail_dir_path = dataset_directory.thumbnails_path(batch_name=batch_name)
+
+ # test no meta
+ reader = DataReader(dataset_path=dataset_path).image_metadata(thumbnail_dir_path=thumbnail_dir_path,
+ batch_name=batch_name)
+ self.assertEqual(reader.metadata, {})
+ self.assertEqual(reader.thumbnail_dir_path, thumbnail_dir_path)
+
+ # test get meta
+ meta_info = {
+ 'dtypes': [{
+ 'key': 'f01',
+ 'value': 'bigint'
+ }],
+ 'sample': [
+ [1],
+ [0],
+ ],
+ 'count': 1000,
+ 'metrics': {
+ 'f01': {
+ 'count': '2',
+ 'mean': '0.0015716767309123998',
+ 'stddev': '0.03961485047808605',
+ 'min': '0',
+ 'max': '1',
+ 'missing_count': '0'
+ }
+ },
+ 'hist': {
+ 'x': [
+ 0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6000000000000001, 0.7000000000000001, 0.8, 0.9,
+ 1
+ ],
+ 'y': [12070, 0, 0, 0, 0, 0, 0, 0, 0, 19]
+ }
+ }
+ Path(meta_file.split('/_META')[0]).mkdir(parents=True)
+ with open(meta_file, 'w', encoding='utf-8') as f:
+ f.write(json.dumps(meta_info))
+ reader = DataReader(dataset_path=dataset_path).image_metadata(thumbnail_dir_path=thumbnail_dir_path,
+ batch_name=batch_name)
+ self.assertEqual(reader.metadata, meta_info)
+ self.assertEqual(reader.thumbnail_dir_path, thumbnail_dir_path)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/analyzer.py b/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/analyzer.py
deleted file mode 100644
index 5759b334e..000000000
--- a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/analyzer.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-
-import os
-import sys
-import json
-import logging
-
-import fsspec
-import pandas
-
-from pyspark.sql import SparkSession
-from pyspark.sql.functions import col, lit, sum
-from util import dataset_features_path, dataset_meta_path, dataset_hist_path
-
-
-def analyze(dataset_path: str, wildcard: str):
- # for example:
- # dataset_path: /data/fl_v2_fish_fooding/dataset/20210527_221741_pipeline/
- # wildcard: rds/**
- spark = SparkSession.builder.getOrCreate()
- files = os.path.join(dataset_path, wildcard)
- logging.info(f'### loading df..., input files path: {files}')
- df = spark.read.format('tfrecords').load(files)
- # df_stats
- df_missing = df.select(*(sum(col(c).isNull().cast('int')).alias(c)
- for c in df.columns)).withColumn(
- 'summary', lit('missing_count'))
- df_stats = df.describe().unionByName(df_missing)
- df_stats = df_stats.toPandas().set_index('summary').transpose()
- features_path = dataset_features_path(dataset_path)
- logging.info(f'### writing features, features path is {features_path}')
- content = json.dumps(df_stats.to_dict(orient='index'))
- with fsspec.open(features_path, mode='w') as f:
- f.write(content)
- # meta
- meta = {}
- # dtypes
- logging.info('### loading dtypes...')
- dtypes = {}
- for d in df.dtypes:
- k, v = d # (feature, type)
- dtypes[k] = v
- meta['dtypes'] = dtypes
- # sample count
- logging.info('### loading count...')
- meta['count'] = df.count()
- # sample
- logging.info('### loading sample...')
- meta['sample'] = df.head(20)
- # meta
- meta_path = dataset_meta_path(dataset_path)
- logging.info(f'### writing meta, path is {meta_path}')
- with fsspec.open(meta_path, mode='w') as f:
- f.write(json.dumps(meta))
- # feature histogram
- logging.info('### loading hist...')
- hist = {}
- for c in df.columns:
- # TODO: histogram is too slow and needs optimization
- x, y = df.select(c).rdd.flatMap(lambda x: x).histogram(10)
- hist[c] = {'x': x, 'y': y}
- hist_path = dataset_hist_path(dataset_path)
- logging.info(f'### writing hist, path is {hist_path}')
- with fsspec.open(hist_path, mode='w') as f:
- f.write(json.dumps(hist))
-
- spark.stop()
-
-
-if __name__ == '__main__':
- logging.basicConfig(level=logging.INFO)
- if len(sys.argv) != 3:
- logging.error(
- f'spark-submit {sys.argv[0]} [dataset_path] [file_wildcard]')
- sys.exit(-1)
-
- dataset_path, wildcard = sys.argv[1], sys.argv[2]
- analyze(dataset_path, wildcard)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/converter.py b/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/converter.py
deleted file mode 100644
index 248210d4d..000000000
--- a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/converter.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-
-import sys
-import os
-import logging
-
-from pyspark.sql import SparkSession
-from util import dataset_rds_path
-
-
-def convert(dataset_path: str, wildcard: str):
- # for example:
- # dataset_path: /data/fl_v2_fish_fooding/dataset/20210527_221741_pipeline/
- # wildcard: batch/**/*.csv
- files = os.path.join(dataset_path, wildcard)
- logging.info(f'### input files path: {files}')
- spark = SparkSession.builder.getOrCreate()
- if wildcard.endswith('*.csv'):
- df = spark.read.format('csv').option('header', 'true').option(
- 'inferSchema', 'true').load(files)
- elif wildcard.endswith('*.rd') or wildcard.endswith('*.tfrecords'):
- df = spark.read.format('tfrecords').load(files)
- else:
- logging.error(f'### no valid file wildcard, wildcard: {wildcard}')
- return
-
- df.printSchema()
- save_path = dataset_rds_path(dataset_path)
- logging.info(f'### saving to {save_path}, in tfrecords')
- df.write.format('tfrecords').save(save_path, mode='overwrite')
- spark.stop()
-
-
-if __name__ == '__main__':
- logging.basicConfig(level=logging.INFO)
- if len(sys.argv) != 3:
- logging.error(
- f'spark-submit {sys.argv[0]} [dataset_path] [file_wildcard]')
- sys.exit(-1)
-
- dataset_path, wildcard = sys.argv[1], sys.argv[2]
- convert(dataset_path, wildcard)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/transformer.py b/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/transformer.py
deleted file mode 100644
index 4c6620de0..000000000
--- a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/transformer.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import os
-import json
-import sys
-import logging
-
-from pyspark.sql import SparkSession
-from util import dataset_transformer_path
-
-
-def transform(dataset_path: str, wildcard: str, conf: str):
- # for example:
- # dataset_path: /data/fl_v2_fish_fooding/dataset/20210527_221741_pipeline/
- # wildcard: rds/** or data_block/**/*.data
- # conf: {"f00001": 0.0, "f00002": 1.0}
- spark = SparkSession.builder.getOrCreate()
- files = os.path.join(dataset_path, wildcard)
- conf_dict = json.loads(conf)
- logging.info(f'### input files path: {files}, config: {conf_dict}')
- df = spark.read.format('tfrecords').load(files)
- filled_df = df.fillna(conf_dict)
- save_path = dataset_transformer_path(dataset_path)
- logging.info(f'### saving to {save_path}')
- filled_df.write.format('tfrecords').save(save_path, mode='overwrite')
- spark.stop()
-
-
-if __name__ == '__main__':
- logging.basicConfig(level=logging.INFO)
- if len(sys.argv) != 4:
- logging.error(
- f'spark-submit {sys.argv[0]} [dataset_path] [wildcard] [config]')
- sys.exit(-1)
-
- dataset_path, wildcard, conf = sys.argv[1], sys.argv[2], sys.argv[3]
- transform(dataset_path, wildcard, conf)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/util.py b/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/util.py
deleted file mode 100644
index 14085e93e..000000000
--- a/web_console_v2/api/fedlearner_webconsole/dataset/sparkapp/pipeline/util.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import os
-
-
-def dataset_rds_path(dataset_path: str) -> str:
- return os.path.join(dataset_path, 'rds/')
-
-
-def dataset_features_path(dataset_path: str) -> str:
- return os.path.join(dataset_path, '_FEATURES')
-
-
-def dataset_meta_path(dataset_path: str) -> str:
- return os.path.join(dataset_path, '_META')
-
-
-def dataset_hist_path(dataset_path: str) -> str:
- return os.path.join(dataset_path, '_HIST')
-
-
-def dataset_transformer_path(dataset_path: str) -> str:
- return os.path.join(dataset_path, 'fe/')
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/util.py b/web_console_v2/api/fedlearner_webconsole/dataset/util.py
new file mode 100644
index 000000000..f3dcf1047
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/util.py
@@ -0,0 +1,182 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Functions from web_console_v2/inspection/util.py that are used by webconsole
+import os
+from typing import Optional, Tuple
+from envs import Envs
+from slugify import slugify
+from urllib.parse import urlparse
+from datetime import datetime
+import enum
+
+from fedlearner_webconsole.dataset.consts import PLACEHOLDER, CRON_SCHEDULER_BATCH_NOT_READY_ERROR_MESSAGE, \
+ CRON_SCHEDULER_CERTAIN_BATCH_NOT_READY_ERROR_MESSAGE, CRON_SCHEDULER_CERTAIN_FOLDER_NOT_READY_ERROR_MESSAGE, \
+ CRON_SCHEDULER_FOLDER_NOT_READY_ERROR_MESSAGE, CRON_SCHEDULER_SUCCEEDED_MESSAGE
+from fedlearner_webconsole.utils.file_manager import FileManager
+
+DEFAULT_SCHEME_TYPE = 'file'
+
+
+class CronInterval(enum.Enum):
+ DAYS = 'DAYS'
+ HOURS = 'HOURS'
+
+
+def get_dataset_path(dataset_name: str, uuid: str):
+ root_dir = add_default_url_scheme(Envs.STORAGE_ROOT)
+    # Builds a path for the dataset according to the dataset name
+    # Example: '/data/dataset/xxxxxxxxxxxxx_test-dataset'
+ return f'{root_dir}/dataset/{uuid}_{slugify(dataset_name)[:32]}'
+
+
+def get_export_dataset_name(index: int, input_dataset_name: str, input_data_batch_name: Optional[str] = None):
+ if input_data_batch_name:
+ return f'export-{input_dataset_name}-{input_data_batch_name}-{index}'
+ return f'export-{input_dataset_name}-{index}'
+
+
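+# e.g. '/data/dataset' becomes 'file:///data/dataset', while URLs that already carry a scheme
+# (such as 'hdfs:///data/dataset') and relative paths are returned unchanged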
+def add_default_url_scheme(url: str) -> str:
+ url_parser = urlparse(url)
+ data_source_type = url_parser.scheme
+ # set default source_type if no source_type found
+ if data_source_type == '' and url.startswith('/'):
+ url = f'{DEFAULT_SCHEME_TYPE}://{url}'
+ return url
+
+
+def _is_daily(file_name: str) -> bool:
+ # YYYYMMDD format, like '20220701'
+    # format must be YYYYMMDD, but a date without zero padding like '202271' can still be parsed by strptime,
+    # so we require the length of file_name to be exactly 8
+ # ref: https://docs.python.org/3.6/library/datetime.html#strftime-strptime-behavior
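+    # e.g. strptime would still accept '202271' (as 2022-07-01), so the length check below rejects it first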
+ if len(file_name) != 8:
+ return False
+ try:
+ datetime.strptime(file_name, '%Y%m%d')
+ return True
+ except ValueError:
+ return False
+
+
+def _is_hourly(file_name: str) -> bool:
+ # YYYYMMDD-HH format, like '20220701-01'
+    # format must be YYYYMMDD-HH, but a time without zero padding like '202271-1' can still be parsed by strptime,
+    # so we require the length of file_name to be exactly 11
+ # ref: https://docs.python.org/3.6/library/datetime.html#strftime-strptime-behavior
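+    # e.g. strptime would still accept '202271-1', so the length check below rejects it first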
+ if len(file_name) != 11:
+ return False
+ try:
+ datetime.strptime(file_name, '%Y%m%d-%H')
+ return True
+ except ValueError:
+ return False
+
+
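+# A valid streaming folder contains only date-named sub-folders,
+# e.g. <folder>/20220701/ ... (daily) or <folder>/20220701-01/ ... (hourly)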
+def is_streaming_folder(folder: str) -> Tuple[bool, str]:
+ fm = FileManager()
+ file_names = fm.listdir(folder)
+ if len(file_names) == 0:
+        return False, f'streaming data_path should contain folders with the correct format, but path {folder} is empty'
+ for file_name in file_names:
+ if not fm.isdir(path=os.path.join(folder, file_name)):
+            return False, f'data_source_url can only contain dirs as subpaths, {file_name} is not a dir'
+ if not _is_daily(file_name) and not _is_hourly(file_name):
+ return False, f'illegal dir format: {file_name}'
+ return True, ''
+
+
+def get_oldest_daily_folder_time(folder: str) -> Optional[datetime]:
+ fm = FileManager()
+    folder_names = fm.listdir(folder)
+    oldest_folder_time = None
+    for folder_name in folder_names:
+        if not fm.isdir(path=os.path.join(folder, folder_name)):
+            continue
+        if _is_daily(folder_name):
+            folder_time = datetime.strptime(folder_name, '%Y%m%d')
+            if oldest_folder_time is None:
+                oldest_folder_time = folder_time
+            else:
+                oldest_folder_time = min(oldest_folder_time, folder_time)
+ return oldest_folder_time
+
+
+def get_oldest_hourly_folder_time(folder: str) -> Optional[datetime]:
+ fm = FileManager()
+    folder_names = fm.listdir(folder)
+    oldest_folder_time = None
+    for folder_name in folder_names:
+        if not fm.isdir(path=os.path.join(folder, folder_name)):
+            continue
+        if _is_hourly(folder_name):
+            folder_time = datetime.strptime(folder_name, '%Y%m%d-%H')
+            if oldest_folder_time is None:
+                oldest_folder_time = folder_time
+            else:
+                oldest_folder_time = min(oldest_folder_time, folder_time)
+ return oldest_folder_time
+
+
+def parse_event_time_to_daily_folder_name(event_time: datetime) -> str:
+ return event_time.strftime('%Y%m%d')
+
+
+def parse_event_time_to_hourly_folder_name(event_time: datetime) -> str:
+ return event_time.strftime('%Y%m%d-%H')
+
+
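+# A batch folder is considered ready when it exists and contains a _SUCCESS marker file,
+# e.g. <folder>/20220101/_SUCCESS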
+def check_batch_folder_ready(folder: str, batch_name: str) -> bool:
+ batch_path = os.path.join(folder, batch_name)
+ file_manager = FileManager()
+ if not file_manager.isdir(batch_path):
+ return False
+ # TODO(liuhehan): add is_file func to file_manager and check is_file here
+ if not file_manager.exists(os.path.join(batch_path, '_SUCCESS')):
+ return False
+ return True
+
+
+# ====================================
+# scheduler message funcs
+# ====================================
+
+
+def get_daily_folder_not_ready_err_msg() -> str:
+ return CRON_SCHEDULER_FOLDER_NOT_READY_ERROR_MESSAGE.replace(PLACEHOLDER, 'YYYYMMDD')
+
+
+def get_hourly_folder_not_ready_err_msg() -> str:
+ return CRON_SCHEDULER_FOLDER_NOT_READY_ERROR_MESSAGE.replace(PLACEHOLDER, 'YYYYMMDD-HH')
+
+
+def get_certain_folder_not_ready_err_msg(folder_name: str) -> str:
+ return CRON_SCHEDULER_CERTAIN_FOLDER_NOT_READY_ERROR_MESSAGE.replace(PLACEHOLDER, folder_name)
+
+
+def get_daily_batch_not_ready_err_msg() -> str:
+ return CRON_SCHEDULER_BATCH_NOT_READY_ERROR_MESSAGE.replace(PLACEHOLDER, 'YYYYMMDD')
+
+
+def get_hourly_batch_not_ready_err_msg() -> str:
+ return CRON_SCHEDULER_BATCH_NOT_READY_ERROR_MESSAGE.replace(PLACEHOLDER, 'YYYYMMDD-HH')
+
+
+def get_certain_batch_not_ready_err_msg(batch_name: str) -> str:
+ return CRON_SCHEDULER_CERTAIN_BATCH_NOT_READY_ERROR_MESSAGE.replace(PLACEHOLDER, batch_name)
+
+
+def get_cron_succeeded_msg(batch_name: str) -> str:
+ return CRON_SCHEDULER_SUCCEEDED_MESSAGE.replace(PLACEHOLDER, batch_name)
diff --git a/web_console_v2/api/fedlearner_webconsole/dataset/util_test.py b/web_console_v2/api/fedlearner_webconsole/dataset/util_test.py
new file mode 100644
index 000000000..9726403f4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/dataset/util_test.py
@@ -0,0 +1,190 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+from unittest.mock import patch
+from datetime import datetime
+import fsspec
+
+from fedlearner_webconsole.dataset.util import get_oldest_daily_folder_time, get_oldest_hourly_folder_time, \
+ check_batch_folder_ready, get_dataset_path, add_default_url_scheme, _is_daily, _is_hourly, \
+ is_streaming_folder, parse_event_time_to_daily_folder_name, parse_event_time_to_hourly_folder_name, \
+ get_export_dataset_name, get_certain_batch_not_ready_err_msg, get_certain_folder_not_ready_err_msg, \
+ get_cron_succeeded_msg, get_daily_batch_not_ready_err_msg, get_daily_folder_not_ready_err_msg, \
+ get_hourly_batch_not_ready_err_msg, get_hourly_folder_not_ready_err_msg
+
+
+class UtilsTest(unittest.TestCase):
+
+ @patch('envs.Envs.STORAGE_ROOT', '/test')
+ def test_get_dataset_path(self):
+ res = get_dataset_path('fake_dataset', 'fake_uuid')
+ self.assertEqual(res, 'file:///test/dataset/fake_uuid_fake-dataset')
+
+ def test_get_export_dataset_name(self):
+ self.assertEqual(get_export_dataset_name(index=0, input_dataset_name='fake_dataset'), 'export-fake_dataset-0')
+ self.assertEqual(
+ get_export_dataset_name(index=0, input_dataset_name='fake_dataset', input_data_batch_name='20220101'),
+ 'export-fake_dataset-20220101-0')
+
+ def test_add_default_url_scheme(self):
+ path = add_default_url_scheme('')
+ self.assertEqual(path, '')
+
+ path = add_default_url_scheme('/')
+ self.assertEqual(path, 'file:///')
+
+ path = add_default_url_scheme('/test/123')
+ self.assertEqual(path, 'file:///test/123')
+
+ path = add_default_url_scheme('test/123')
+ self.assertEqual(path, 'test/123')
+
+ path = add_default_url_scheme('hdfs:///test/123')
+ self.assertEqual(path, 'hdfs:///test/123')
+
+ def test_is_daily(self):
+ self.assertTrue(_is_daily('20220701'))
+ self.assertFalse(_is_daily('2022711'))
+ self.assertFalse(_is_daily('20221711'))
+ self.assertFalse(_is_daily('2022x711'))
+
+ def test_is_hourly(self):
+ self.assertTrue(_is_hourly('20220701-01'))
+ self.assertFalse(_is_hourly('20220711'))
+ self.assertFalse(_is_hourly('20220701-1'))
+ self.assertFalse(_is_hourly('20220701-25'))
+
+ def test_is_streaming_folder(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220701'))
+ fs.mkdirs(os.path.join(test_path, '20220702'))
+ fs.mkdirs(os.path.join(test_path, '20220703'))
+ res, _ = is_streaming_folder(test_path)
+ self.assertTrue(res)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220701-01'))
+ fs.mkdirs(os.path.join(test_path, '20220701-02'))
+ fs.mkdirs(os.path.join(test_path, '20220701-03'))
+ res, _ = is_streaming_folder(test_path)
+ self.assertTrue(res)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(test_path)
+ res, _ = is_streaming_folder(test_path)
+ self.assertFalse(res)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20221331-25'))
+ res, _ = is_streaming_folder(test_path)
+ self.assertFalse(res)
+
+ def test_get_oldest_daily_folder_time(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220701'))
+ fs.mkdirs(os.path.join(test_path, '20220702'))
+ fs.mkdirs(os.path.join(test_path, '20220703'))
+ event_time = get_oldest_daily_folder_time(test_path)
+ self.assertEqual(event_time, datetime(2022, 7, 1))
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220701-01'))
+ event_time = get_oldest_daily_folder_time(test_path)
+ self.assertIsNone(event_time)
+
+ def test_get_oldest_hourly_folder_time(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220701-01'))
+ fs.mkdirs(os.path.join(test_path, '20220701-02'))
+ fs.mkdirs(os.path.join(test_path, '20220701-03'))
+ event_time = get_oldest_hourly_folder_time(test_path)
+ self.assertEqual(event_time, datetime(2022, 7, 1, 1))
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220701'))
+ event_time = get_oldest_hourly_folder_time(test_path)
+ self.assertIsNone(event_time)
+
+ def test_parse_event_time_to_daily_folder_name(self):
+ self.assertEqual(parse_event_time_to_daily_folder_name(datetime(2022, 1, 1)), '20220101')
+
+ def test_parse_event_time_to_hourly_folder_name(self):
+ self.assertEqual(parse_event_time_to_hourly_folder_name(datetime(2022, 1, 1, 1)), '20220101-01')
+
+ def test_check_batch_folder_ready(self):
+ # test no batch_path
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ self.assertFalse(check_batch_folder_ready(folder=test_path, batch_name='20220101'))
+
+ # test no _SUCCESS file
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220101'))
+ self.assertFalse(check_batch_folder_ready(folder=test_path, batch_name='20220101'))
+
+ # test ready
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ test_path = os.path.join(tmp_dir, 'test')
+ fs, _ = fsspec.core.url_to_fs(test_path)
+ fs.mkdirs(os.path.join(test_path, '20220101'))
+ fs.touch(os.path.join(test_path, '20220101', '_SUCCESS'))
+ self.assertTrue(check_batch_folder_ready(folder=test_path, batch_name='20220101'))
+
+ def test_get_daily_folder_not_ready_err_msg(self):
+ self.assertEqual(get_daily_folder_not_ready_err_msg(), '数据源下未找到满足格式要求的文件夹,请确认文件夹以YYYYMMDD格式命名')
+
+ def test_get_hourly_folder_not_ready_err_msg(self):
+ self.assertEqual(get_hourly_folder_not_ready_err_msg(), '数据源下未找到满足格式要求的文件夹,请确认文件夹以YYYYMMDD-HH格式命名')
+
+ def test_get_daily_batch_not_ready_err_msg(self):
+ self.assertEqual(get_daily_batch_not_ready_err_msg(), '未找到满足格式要求的数据批次,请确保输入数据集有YYYYMMDD格式命名的数据批次')
+
+    def test_get_hourly_batch_not_ready_err_msg(self):
+ self.assertEqual(get_hourly_batch_not_ready_err_msg(), '未找到满足格式要求的数据批次,请确保输入数据集有YYYYMMDD-HH格式命名的数据批次')
+
+ def test_get_certain_folder_not_ready_err_msg(self):
+ self.assertEqual(get_certain_folder_not_ready_err_msg(folder_name='20220101-08'),
+ '20220101-08文件夹检查失败,请确认数据源下存在以20220101-08格式命名的文件夹,且文件夹下有_SUCCESS文件')
+
+ def test_get_certain_batch_not_ready_err_msg(self):
+ self.assertEqual(get_certain_batch_not_ready_err_msg(batch_name='20220101-08'),
+ '数据批次20220101-08检查失败,请确认该批次命名格式及状态')
+
+ def test_get_cron_succeeded_msg(self):
+ self.assertEqual(get_cron_succeeded_msg(batch_name='20220101-08'), '已成功发起20220101-08批次处理任务')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/db.py b/web_console_v2/api/fedlearner_webconsole/db.py
index b40ff033b..edca200ca 100644
--- a/web_console_v2/api/fedlearner_webconsole/db.py
+++ b/web_console_v2/api/fedlearner_webconsole/db.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,29 +15,28 @@
# coding: utf-8
import os
from contextlib import contextmanager
-from typing import ContextManager, Callable
-
+from typing import ContextManager
+from pymysql.constants.CLIENT import FOUND_ROWS
import sqlalchemy as sa
-
+from sqlalchemy import orm, event, null
from sqlalchemy.engine import Engine, create_engine
-from sqlalchemy.ext.declarative.api import DeclarativeMeta, declarative_base
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import sessionmaker, DeclarativeMeta, declarative_base
from sqlalchemy.orm.session import Session
-from flask_sqlalchemy import SQLAlchemy
from envs import Envs
+from fedlearner_webconsole.utils.base_model.softdelete_model import SoftDeleteModel
+
BASE_DIR = Envs.BASE_DIR
# Explicitly set autocommit and autoflush
# Disables autocommit to make developers to commit manually
# Enables autoflush to make changes visible in the same session
# Disable expire_on_commit to make it possible that object can detach
-SESSION_OPTIONS = {
- 'autocommit': False,
- 'autoflush': True,
- 'expire_on_commit': False
-}
-ENGINE_OPTIONS = {}
+SESSION_OPTIONS = {'autocommit': False, 'autoflush': True, 'expire_on_commit': False}
+# Add the FOUND_ROWS flag to make UPDATE statements return the number of matched rows rather than changed rows.
+# When using SQLAlchemy, this flag must be set so that rowcount validation on UPDATE statements works correctly.
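+# (Without FOUND_ROWS, MySQL reports only rows whose values actually changed, so an UPDATE that writes
+# identical values yields rowcount 0, which SQLAlchemy's rowcount validation can misread as a stale update.)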
+MYSQL_OPTIONS = {'connect_args': {'client_flag': FOUND_ROWS}}
+SQLITE_OPTIONS = {}
def default_table_args(comment: str) -> dict:
@@ -48,7 +47,22 @@ def default_table_args(comment: str) -> dict:
}
-def _turn_db_timezone_to_utc(original_uri: str) -> str:
+# An option is added to all SELECT statements that limits queries against soft-delete models to rows with deleted_at == null.
+# Global WHERE/ON criteria, e.g.: https://docs.sqlalchemy.org/en/14/_modules/examples/extending_query/filter_public.html
+# Normal ORM execution won't return soft-deleted data, e.g.: session.query(A).get(1)
+# Use execution options to also get soft-deleted data, e.g.: session.query(A).execution_options(include_deleted=True).get(1)
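+# include_aliases=True below applies the same criteria to aliased entities (e.g. in joins) as well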
+@event.listens_for(Session, 'do_orm_execute')
+def _add_filtering_criteria(execute_state):
+ if (not execute_state.is_column_load and not execute_state.execution_options.get('include_deleted', False)):
+ execute_state.statement = execute_state.statement.options(
+ orm.with_loader_criteria(
+ SoftDeleteModel,
+ lambda cls: cls.deleted_at == null(),
+ include_aliases=True,
+ ))
+
+
+def turn_db_timezone_to_utc(original_uri: str) -> str:
""" string operator that make any db into utc timezone
Args:
@@ -101,17 +115,15 @@ def get_database_uri() -> str:
Returns:
str: database uri with utc timezone
"""
- uri = ''
- if 'SQLALCHEMY_DATABASE_URI' in os.environ:
- uri = os.getenv('SQLALCHEMY_DATABASE_URI')
- else:
- uri = 'sqlite:///{}?check_same_thread=False'.format(
- os.path.join(BASE_DIR, 'app.db'))
- return _turn_db_timezone_to_utc(uri)
+ uri = Envs.SQLALCHEMY_DATABASE_URI
+ if not uri:
+ db_path = os.path.join(BASE_DIR, 'app.db')
+ uri = f'sqlite:///{db_path}?check_same_thread=False'
+ return turn_db_timezone_to_utc(uri)
-def get_engine(database_uri: str) -> Engine:
- """get engine according to database uri
+def _get_engine(database_uri: str) -> Engine:
+ """Gets engine according to database uri.
Args:
database_uri (str): database uri used for create engine
@@ -119,7 +131,12 @@ def get_engine(database_uri: str) -> Engine:
Returns:
Engine: engine used for managing connections
"""
- return create_engine(database_uri, **ENGINE_OPTIONS)
+ engine_options = {}
+ if database_uri.startswith('mysql'):
+ engine_options = MYSQL_OPTIONS
+ elif database_uri.startswith('sqlite'):
+ engine_options = SQLITE_OPTIONS
+ return create_engine(database_uri, **engine_options)
@contextmanager
@@ -133,8 +150,8 @@ def get_session(db_engine: Engine) -> ContextManager[Session]:
"""
try:
session: Session = sessionmaker(bind=db_engine, **SESSION_OPTIONS)()
- except Exception:
- raise Exception('unknown db engine')
+ except Exception as e:
+ raise Exception('unknown db engine') from e
try:
yield session
@@ -145,40 +162,12 @@ def get_session(db_engine: Engine) -> ContextManager[Session]:
session.close()
-def make_session_context() -> Callable[[], ContextManager[Session]]:
- """A functional closure that will store engine
- Call it n times if you want to n connection pools
-
- Returns:
- Callable[[], Callable[[], ContextManager[Session]]]
- a function that return a contextmanager
-
-
- Examples:
- # First initialize a connection pool,
- # when you want to a new connetion pool
- session_context = make_session_context()
- ...
- # You use it multiple times as follows.
- with session_context() as session:
- session.query(SomeMapperClass).filter_by(id=1).one()
- """
- engine = None
-
- def wrapper_get_session():
- nonlocal engine
- if engine is None:
- engine = get_engine(get_database_uri())
- return get_session(engine)
-
- return wrapper_get_session
-
-
class DBHandler(object):
+
def __init__(self) -> None:
super().__init__()
- self.engine: Engine = get_engine(get_database_uri())
+ self.engine: Engine = _get_engine(get_database_uri())
self.Model: DeclarativeMeta = declarative_base(bind=self.engine)
for module in sa, sa.orm:
for key in module.__all__:
@@ -193,7 +182,7 @@ def metadata(self) -> DeclarativeMeta:
return self.Model.metadata
def rebind(self, database_uri: str):
- self.engine = get_engine(database_uri)
+ self.engine = _get_engine(database_uri)
self.Model = declarative_base(bind=self.engine, metadata=self.metadata)
def create_all(self):
@@ -203,9 +192,6 @@ def drop_all(self):
return self.metadata.drop_all()
-# now db_handler and db are alive at the same time
-# db will be replaced by db_handler in the near future
-db_handler = DBHandler()
-db = SQLAlchemy(session_options=SESSION_OPTIONS,
- engine_options=ENGINE_OPTIONS,
- metadata=db_handler.metadata)
+# db is now a DBHandler instance; the flask_sqlalchemy wrapper it used to coexist with has been removed
+db = DBHandler()
diff --git a/web_console_v2/api/fedlearner_webconsole/db_test.py b/web_console_v2/api/fedlearner_webconsole/db_test.py
new file mode 100644
index 000000000..c1eb82deb
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/db_test.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import patch
+
+from fedlearner_webconsole.db import get_database_uri, turn_db_timezone_to_utc
+
+
+class EngineSessionTest(unittest.TestCase):
+
+ def test_turn_db_timezone_to_utc(self):
+ sqlite_uri = 'sqlite:///app.db'
+ self.assertEqual(turn_db_timezone_to_utc(sqlite_uri), 'sqlite:///app.db')
+
+ mysql_uri_naive = 'mysql+pymysql://root:root@localhost:33600/fedlearner'
+ self.assertEqual(
+ turn_db_timezone_to_utc(mysql_uri_naive),
+ 'mysql+pymysql://root:root@localhost:33600/fedlearner?init_command=SET SESSION time_zone=\'%2B00:00\'')
+
+ mysql_uri_with_init_command = 'mysql+pymysql://root:root@localhost:33600/fedlearner?init_command=HELLO'
+ self.assertEqual(
+ turn_db_timezone_to_utc(mysql_uri_with_init_command),
+ 'mysql+pymysql://root:root@localhost:33600/fedlearner?init_command=SET SESSION time_zone=\'%2B00:00\';HELLO'
+ )
+
+ mysql_uri_with_other_args = 'mysql+pymysql://root:root@localhost:33600/fedlearner?charset=utf8mb4'
+ self.assertEqual(
+ turn_db_timezone_to_utc(mysql_uri_with_other_args),
+ 'mysql+pymysql://root:root@localhost:33600/fedlearner?init_command=SET SESSION time_zone=\'%2B00:00\'&&charset=utf8mb4' # pylint: disable=line-too-long
+ )
+
+ mysql_uri_with_set_time_zone = 'mysql+pymysql://root:root@localhost:33600/fedlearner?init_command=SET SESSION time_zone=\'%2B08:00\'' # pylint: disable=line-too-long
+ self.assertEqual(
+ turn_db_timezone_to_utc(mysql_uri_with_set_time_zone),
+ 'mysql+pymysql://root:root@localhost:33600/fedlearner?init_command=SET SESSION time_zone=\'%2B00:00\'')
+
+ def test_get_database_uri(self):
+ # test with environmental variable
+ with patch('fedlearner_webconsole.db.Envs.SQLALCHEMY_DATABASE_URI',
+ 'mysql+pymysql://root:root@localhost:33600/fedlearner'):
+ self.assertTrue(get_database_uri().startswith('mysql+pymysql://root:root@localhost:33600/fedlearner'))
+
+ # test with fallback options
+ self.assertTrue(get_database_uri().startswith('sqlite:///'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/debug/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/debug/BUILD.bazel
new file mode 100644
index 000000000..5eb4742b0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/debug/BUILD.bazel
@@ -0,0 +1,37 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_cache_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:tfrecords_reader_lib",
+ "@common_flask_restful//:pkg",
+ "@common_pyyaml//:pkg",
+ "@common_tensorflow//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/debug/__init__.py b/web_console_v2/api/fedlearner_webconsole/debug/__init__.py
index 3e28547fe..c13b80f8f 100644
--- a/web_console_v2/api/fedlearner_webconsole/debug/__init__.py
+++ b/web_console_v2/api/fedlearner_webconsole/debug/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/web_console_v2/api/fedlearner_webconsole/debug/apis.py b/web_console_v2/api/fedlearner_webconsole/debug/apis.py
index 4c74e9a80..ab58f220f 100644
--- a/web_console_v2/api/fedlearner_webconsole/debug/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/debug/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,80 +13,154 @@
# limitations under the License.
# coding: utf-8
+import datetime
import json
-from flask_restful import Resource, Api, request
+import tensorflow as tf
+import yaml
+from flask_restful import Resource, Api, request, reqparse
-from fedlearner_webconsole.composer.composer import composer
-from fedlearner_webconsole.composer.runner import MemoryItem
-from fedlearner_webconsole.dataset.data_pipeline import DataPipelineItem, \
- DataPipelineType
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.composer.models import SchedulerRunner, \
+ SchedulerItem
+from fedlearner_webconsole.utils.tfrecords_reader import tf_record_reader
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.k8s.k8s_cache import k8s_cache
+from fedlearner_webconsole.k8s.k8s_client import k8s_client
+from fedlearner_webconsole.db import db
-class ComposerApi(Resource):
+class DebugComposerApi(Resource):
+
def get(self, name):
- interval = request.args.get('interval', -1)
+ cron_config = request.args.get('cron_config')
finish = request.args.get('finish', 0)
- if int(finish) == 1:
- composer.finish(name)
- else:
- composer.collect(
- name,
- [MemoryItem(1), MemoryItem(2)],
- { # meta data
- 1: {
- 'input': 'fs://data/memory_1',
- },
- 2: {
- 'input': 'fs://data/memory_2',
- }
- },
- interval=int(interval),
- )
- return {'data': {'name': name}}
-
-
-class DataPipelineApi(Resource):
- def get(self, name: str):
- # '/data/fl_v2_fish_fooding/dataset/20210527_221741_pipeline'
- input_dir = request.args.get('input_dir', None)
- if not input_dir:
- return {'msg': 'no input dir'}
- if 'pipe' in name:
- composer.collect(
- name,
- [DataPipelineItem(1), DataPipelineItem(2)],
- { # meta data
- 1: { # convertor
- 'sparkapp_name': '1',
- 'task_type': DataPipelineType.CONVERTER.value,
- 'input': [input_dir, 'batch/**/*.csv'],
- },
- 2: { # analyzer
- 'sparkapp_name': '2',
- 'task_type': DataPipelineType.ANALYZER.value,
- 'input': [input_dir, 'rds/**'],
- },
- },
- )
- elif 'fe' in name:
- composer.collect(
- name,
- [DataPipelineItem(1)],
- { # meta data
- 1: { # transformer
- 'sparkapp_name': '1',
- 'task_type': DataPipelineType.TRANSFORMER.value,
- 'input': [input_dir, 'rds/**', json.dumps({
- 'f00000': 1.0,
- 'f00010': 0.0,
- })],
- },
- },
- )
- return {'data': {'name': name}}
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ if int(finish) == 1:
+ service.finish(name)
+ session.commit()
+ return {'data': {'name': name}}
+
+
+class DebugSparkAppApi(Resource):
+
+ def post(self, name: str):
+ data = yaml.load(f"""
+apiVersion: "sparkoperator.k8s.io/v1beta2"
+kind: SparkApplication
+metadata:
+ name: {name}
+ namespace: default
+spec:
+ type: Python
+ pythonVersion: "3"
+ mode: cluster
+ image: "registry.cn-beijing.aliyuncs.com/fedlearner/spark-tfrecord:latest"
+ imagePullPolicy: Always
+ volumes:
+ - name: data
+ persistentVolumeClaim:
+ claimName: pvc-fedlearner-default
+ mainApplicationFile: local:///data/sparkapp_test/tyt_test/schema_check.py
+ arguments:
+ - /data/sparkapp_test/tyt_test/test.csv
+ - /data/sparkapp_test/tyt_test/schema.json
+ sparkVersion: "3.0.0"
+ restartPolicy:
+ type: OnFailure
+ onFailureRetries: 3
+ onFailureRetryInterval: 10
+ onSubmissionFailureRetries: 5
+ onSubmissionFailureRetryInterval: 20
+ driver:
+ cores: 1
+ coreLimit: "1200m"
+ memory: "512m"
+ labels:
+ version: 3.0.0
+ serviceAccount: spark
+ volumeMounts:
+ - name: data
+ mountPath: /data
+ executor:
+ cores: 1
+ instances: 1
+ memory: "512m"
+ labels:
+ version: 3.0.0
+ volumeMounts:
+ - name: data
+ mountPath: /data
+""",
+ Loader=None)
+ data = k8s_client.create_sparkapplication(data)
+ return {'data': data}
+
+
+class DebugK8sCacheApi(Resource):
+
+ def get(self):
+
+ def default(o):
+ if isinstance(o, (datetime.date, datetime.datetime)):
+ return o.isoformat()
+ return str(o)
+
+ return {'data': json.dumps(k8s_cache.inspect(), default=default)}
+
+
+class DebugTfRecordApi(Resource):
+
+ def get(self):
+ path = request.args.get('path', None)
+
+ if path is None or not tf.io.gfile.exists(path):
+ raise InvalidArgumentException('path is not found')
+
+ lines = request.args.get('lines', 25, int)
+ tf_matrix = tf_record_reader(path, lines, matrix_view=True)
+
+ return {'data': tf_matrix}
+
+
+class DebugSchedulerItemsApi(Resource):
+
+ def get(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument('status', type=int, location='args', required=False, choices=[0, 1])
+ parser.add_argument('id', type=int, location='args', required=False)
+ data = parser.parse_args()
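+        # with an item id, return that item's 10 most recent runners;
+        # otherwise list scheduler items, optionally filtered by status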
+ with db.session_scope() as session:
+ items = session.query(SchedulerItem)
+ if data['status'] is not None:
+ items = items.filter_by(status=data['status'])
+ if data['id'] is not None:
+ runners = session.query(SchedulerRunner).filter_by(item_id=data['id']).order_by(
+ SchedulerRunner.updated_at.desc()).limit(10).all()
+ return {'data': [runner.to_dict() for runner in runners]}
+ items = items.order_by(SchedulerItem.created_at.desc()).all()
+ return {'data': [item.to_dict() for item in items]}
+
+
+class DebugSchedulerRunnersApi(Resource):
+
+ def get(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument('status', type=int, location='args', required=False, choices=[0, 1, 2, 3])
+ data = parser.parse_args()
+ with db.session_scope() as session:
+ runners = session.query(SchedulerRunner)
+ if data['status'] is not None:
+ runners = runners.filter_by(status=data['status'])
+ runners = runners.order_by(SchedulerRunner.updated_at.desc()).all()
+ return {'data': [runner.to_dict() for runner in runners]}
def initialize_debug_apis(api: Api):
- api.add_resource(ComposerApi, '/debug/composer/')
- api.add_resource(DataPipelineApi, '/debug/pipeline/')
+ api.add_resource(DebugComposerApi, '/debug/composer/')
+ api.add_resource(DebugSparkAppApi, '/debug/sparkapp/')
+ api.add_resource(DebugK8sCacheApi, '/debug/k8scache/')
+ api.add_resource(DebugTfRecordApi, '/debug/tfrecord')
+ api.add_resource(DebugSchedulerItemsApi, '/debug/scheduler_items')
+ api.add_resource(DebugSchedulerRunnersApi, '/debug/scheduler_runners')
diff --git a/web_console_v2/api/fedlearner_webconsole/debug/apis_test.py b/web_console_v2/api/fedlearner_webconsole/debug/apis_test.py
new file mode 100644
index 000000000..8b0b33e51
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/debug/apis_test.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import unittest
+from http import HTTPStatus
+
+from testing.common import BaseTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.models import (SchedulerItem, SchedulerRunner, ItemStatus, RunnerStatus)
+
+
+class DebugSchedulerApiTest(BaseTestCase):
+
+ _ITEM_ON_ID = 123
+ _PRESET_SCHEDULER_ITEM = [
+ 'test_item_off',
+ 'test_item_on',
+ 'workflow_scheduler_v2',
+ 'job_scheduler_v2',
+ 'cleanup_cron_job',
+ 'dataset_short_period_scheduler',
+ 'dataset_long_period_scheduler',
+ 'project_scheduler_v2',
+ 'tee_create_runner',
+ 'tee_resource_check_runner',
+ 'model_job_scheduler_runner',
+ 'model_job_group_scheduler_runner',
+ 'model_job_group_long_period_scheduler_runner',
+ ]
+
+ def setUp(self):
+ super().setUp()
+ scheduler_item_on = SchedulerItem(id=self._ITEM_ON_ID, name='test_item_on', status=ItemStatus.ON.value)
+ scheduler_item_off = SchedulerItem(name='test_item_off', status=ItemStatus.OFF.value)
+ with db.session_scope() as session:
+ session.add_all([scheduler_item_on, scheduler_item_off])
+ session.commit()
+ scheduler_runner_init = SchedulerRunner(id=0, item_id=self._ITEM_ON_ID, status=RunnerStatus.INIT.value)
+ scheduler_runner_running_1 = SchedulerRunner(id=1, item_id=self._ITEM_ON_ID, status=RunnerStatus.RUNNING.value)
+ scheduler_runner_running_2 = SchedulerRunner(id=2, item_id=self._ITEM_ON_ID, status=RunnerStatus.RUNNING.value)
+
+ with db.session_scope() as session:
+ session.add_all([scheduler_runner_init, scheduler_runner_running_1, scheduler_runner_running_2])
+ session.commit()
+
+ def test_get_scheduler_item(self):
+        # test getting all scheduler items
+        data = self.get_response_data(self.get_helper('/api/v2/debug/scheduler_items'))
+        # preset scheduler items exist in addition to the two created in setUp
+ self.assertCountEqual([d['name'] for d in data], self._PRESET_SCHEDULER_ITEM)
+ # test get scheduler item with status
+ data = self.get_response_data(self.get_helper('/api/v2/debug/scheduler_items?status=0'))
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'test_item_off')
+ # test get recent scheduler runners
+ data = self.get_response_data(self.get_helper(f'/api/v2/debug/scheduler_items?id={self._ITEM_ON_ID}'))
+ self.assertEqual(len(data), 3)
+ self.assertEqual(data[0]['status'], RunnerStatus.INIT.value)
+
+ def test_get_scheduler_runner(self):
+ # test get running runners
+ response = self.get_helper('/api/v2/debug/scheduler_runners?status=1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['id'], 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/e2e/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/e2e/BUILD.bazel
new file mode 100644
index 000000000..27174d95c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/e2e/BUILD.bazel
@@ -0,0 +1,74 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_pyyaml//:pkg",
+ ],
+)
+
+py_library(
+ name = "controllers_lib",
+ srcs = ["controllers.py"],
+ imports = ["../.."],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_client_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_kubernetes//:pkg",
+ ],
+)
+
+py_test(
+ name = "controllers_test",
+ size = "small",
+ srcs = [
+ "controllers_test.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":controllers_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_flasgger//:pkg",
+ "@common_flask_restful//:pkg",
+ "@common_kubernetes//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_test",
+ size = "medium",
+ srcs = ["apis_test.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/e2e/__init__.py b/web_console_v2/api/fedlearner_webconsole/e2e/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/e2e/apis.py b/web_console_v2/api/fedlearner_webconsole/e2e/apis.py
new file mode 100644
index 000000000..404c276ac
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/e2e/apis.py
@@ -0,0 +1,145 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from flask_restful import Api, Resource
+from kubernetes.client import ApiException
+from marshmallow import validate, post_load
+from webargs.flaskparser import use_args
+from flasgger import Schema, fields
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.e2e.controllers import ROLES_MAPPING, get_job, get_job_logs, initiate_all_tests
+from fedlearner_webconsole.exceptions import NotFoundException, InvalidArgumentException
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.e2e_pb2 import InitiateE2eJobsParameter
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required
+
+
+class E2eJobsApi(Resource):
+
+ @credentials_required
+ @admin_required
+ def get(self, job_name: str):
+ """Get an existing job
+ ---
+ tags:
+ - e2e
+ description: get a job
+ parameters:
+ - in: path
+ name: job_name
+ required: true
+ schema:
+ type: string
+ description: The name of the job
+ responses:
+ 200:
+ description: The corresponding job
+ content:
+ application/json:
+ schema:
+                  type: object
+                  properties:
+                    job_name:
+                      type: string
+                      description: The name of the job
+                    status:
+                      type: object
+                      description: The status of the job
+                    log:
+                      type: array
+                      items:
+                        type: string
+                      description: The log of the job; if the job is still active, an empty list is returned
+ 404:
+ description: The job is not found
+ """
+ try:
+ status = get_job(job_name)['status']
+ except ApiException as e:
+ raise NotFoundException(f'failed to find job {job_name}') from e
+ # if the pod is still running, do not query for logs
+ log = get_job_logs(job_name) if 'active' not in status else []
+ return make_flask_response(data={'job_name': job_name, 'status': status, 'log': log})
+
+
+# If a schema class name ends with "Schema", that suffix is stripped when it is registered, so this schema
+# is referenced as "#/definitions/InitiateE2eJobsParameter".
+class InitiateE2eJobsParameterSchema(Schema):
+ role = fields.String(required=True, validate=validate.OneOf(ROLES_MAPPING.keys()))
+ name_prefix = fields.String(required=True, validate=validate.Length(min=5))
+ project_name = fields.String(required=True, validate=validate.Length(min=1))
+ e2e_image_uri = fields.String(required=True, validate=lambda x: 'fedlearner_e2e:' in x)
+ fedlearner_image_uri = fields.String(required=True, validate=lambda x: 'fedlearner:' in x)
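+    # the default endpoint points at the in-cluster web console service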
+ platform_endpoint = fields.String(required=False,
+ load_default='http://fedlearner-fedlearner-web-console-v2-http:1989',
+ validate=validate.URL(require_tld=False))
+
+ @post_load
+ def make_initiate_e2e_jobs_parameter(self, data, **kwargs) -> InitiateE2eJobsParameter:
+ del kwargs
+ return InitiateE2eJobsParameter(**data)
+
+
+class InitiateE2eJobsApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_args(InitiateE2eJobsParameterSchema(), location='json')
+ def post(self, params: InitiateE2eJobsParameter):
+ """Initiate a series of E2e jobs
+ ---
+ tags:
+ - e2e
+ description: initiate a series of E2e jobs
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/InitiateE2eJobsParameter'
+ responses:
+ 201:
+ description: Jobs are launched and job names are returned
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: object
+ properties:
+ job_type:
+ type: string
+ job_name:
+ type: string
+ """
+ with db.session_scope() as session:
+ project = session.query(Project).filter(Project.name == params.project_name).first()
+ if project is None:
+ raise InvalidArgumentException(f'failed to find project with name={params.project_name}')
+ try:
+ jobs = initiate_all_tests(params)
+ except ValueError as e:
+ raise InvalidArgumentException(str(e)) from e
+ return make_flask_response(jobs)
+
+
+def initialize_e2e_apis(api: Api):
+ api.add_resource(E2eJobsApi, '/e2e_jobs/')
+ api.add_resource(InitiateE2eJobsApi, '/e2e_jobs:initiate')
+ schema_manager.append(InitiateE2eJobsParameterSchema)
diff --git a/web_console_v2/api/fedlearner_webconsole/e2e/apis_test.py b/web_console_v2/api/fedlearner_webconsole/e2e/apis_test.py
new file mode 100644
index 000000000..0eaff12d2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/e2e/apis_test.py
@@ -0,0 +1,110 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from json import loads
+from unittest import main
+from unittest.mock import patch
+
+from testing.common import BaseTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+
+
+class InitiateE2eJobsApiTest(BaseTestCase):
+
+ def test_post(self):
+ self.signin_as_admin()
+ response = self.post_helper(
+ '/api/v2/e2e_jobs:initiate', {
+ 'role': 'some_role',
+ 'name_prefix': 'test',
+ 'project_name': '',
+ 'e2e_image_uri': 'invalid',
+ 'fedlearner_image_uri': 'invalid',
+ 'platform_endpoint': 'invalid'
+ })
+ self.assert400(response)
+ error_details = loads(response.data)['details']['json']
+ self.assertRegex(error_details['role'][0], 'coordinator, participant')
+ self.assertRegex(error_details['name_prefix'][0], 'minimum length 5')
+ self.assertRegex(error_details['project_name'][0], 'minimum length 1')
+ self.assertRegex(error_details['e2e_image_uri'][0], 'Invalid value.')
+ self.assertRegex(error_details['fedlearner_image_uri'][0], 'Invalid value.')
+ self.assertRegex(error_details['platform_endpoint'][0], 'Not a valid URL')
+
+ response = self.post_helper(
+ '/api/v2/e2e_jobs:initiate', {
+ 'role': 'coordinator',
+ 'name_prefix': 'test_me',
+ 'project_name': 'project',
+ 'e2e_image_uri': 'fedlearner_e2e:hey',
+ 'fedlearner_image_uri': 'fedlearner:hey',
+ 'platform_endpoint': 'hey-hello:80/index.html'
+ })
+ self.assert400(response)
+ error_details = loads(response.data)['details']['json']
+ self.assertIsNone(error_details.get('role'))
+ self.assertIsNone(error_details.get('name_prefix'))
+ self.assertIsNone(error_details.get('project_name'))
+ self.assertIsNone(error_details.get('e2e_image_uri'))
+ self.assertIsNone(error_details.get('fedlearner_image_uri'))
+ self.assertRegex(error_details['platform_endpoint'][0], 'Not a valid URL')
+
+ response = self.post_helper(
+ '/api/v2/e2e_jobs:initiate', {
+ 'role': 'coordinator',
+ 'name_prefix': 'test_me',
+ 'project_name': 'project',
+ 'e2e_image_uri': 'fedlearner_e2e:hey',
+ 'fedlearner_image_uri': 'fedlearner:hey',
+ 'platform_endpoint': 'http://hey-hello:80/index.html'
+ })
+ self.assert400(response)
+ error_details = loads(response.data)['details']
+ self.assertRegex(error_details, 'failed to find project')
+
+ with db.session_scope() as session:
+ session.add(Project(id=1000, name='project'))
+ session.commit()
+
+ response = self.post_helper(
+ '/api/v2/e2e_jobs:initiate', {
+ 'role': 'coordinator',
+ 'name_prefix': 'test_me',
+ 'project_name': 'project',
+ 'e2e_image_uri': 'fedlearner_e2e:hey',
+ 'fedlearner_image_uri': 'fedlearner:hey',
+ 'platform_endpoint': 'http://hey-hello:80/index.html'
+ })
+ self.assert400(response)
+ error_details = loads(response.data)['details']
+ self.assertRegex(error_details, r'job with job_name=[\w-]* exists')
+
+ with patch('fedlearner_webconsole.e2e.apis.initiate_all_tests') as mock_initiate_all_tests:
+ mock_initiate_all_tests.return_value = [{}]
+ response = self.post_helper(
+ '/api/v2/e2e_jobs:initiate', {
+ 'role': 'coordinator',
+ 'name_prefix': 'test_me',
+ 'project_name': 'project',
+ 'e2e_image_uri': 'fedlearner_e2e:hey',
+ 'fedlearner_image_uri': 'fedlearner:hey',
+ 'platform_endpoint': 'http://hey-hello:80/index.html'
+ })
+ self.assert200(response)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/web_console_v2/api/fedlearner_webconsole/e2e/controllers.py b/web_console_v2/api/fedlearner_webconsole/e2e/controllers.py
new file mode 100644
index 000000000..99f0e12a1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/e2e/controllers.py
@@ -0,0 +1,78 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, List, Tuple
+
+from kubernetes.client import ApiException
+
+from envs import Envs
+from fedlearner_webconsole.e2e.utils import e2e_job_to_dict
+from fedlearner_webconsole.proto.e2e_pb2 import E2eJob, InitiateE2eJobsParameter
+from fedlearner_webconsole.k8s.k8s_client import k8s_client
+
+COORDINATOR_TESTS = {
+ 'fed-workflow': 'scripts/auto_e2e/fed_workflow/test_coordinator.py',
+ 'vertical-dataset-model-serving': 'scripts/auto_e2e/vertical_dataset_model_serving/test_coordinator.py'
+}
+
+PARTICIPANT_TESTS = {
+ 'fed-workflow': 'scripts/auto_e2e/fed_workflow/test_participant.py',
+ 'vertical-dataset-model-serving': 'scripts/auto_e2e/vertical_dataset_model_serving/test_participant.py'
+}
+
+ROLES_MAPPING: Dict[str, Dict] = {'coordinator': COORDINATOR_TESTS, 'participant': PARTICIPANT_TESTS}
+
+
+def start_job(e2e_job: E2eJob):
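+    # get_job raises ApiException when no job with this name exists;
+    # only in that case is it safe to create a new one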
+ try:
+ get_job(e2e_job.job_name)
+ raise ValueError(f'failed to start {e2e_job.job_name}; job with job_name={e2e_job.job_name} exists')
+ except ApiException:
+ pass
+ k8s_client.create_app(e2e_job_to_dict(e2e_job), group='batch', version='v1', plural='jobs')
+
+
+def get_job(job_name: str) -> dict:
+ return k8s_client.crds.get_namespaced_custom_object(group='batch',
+ version='v1',
+ namespace=Envs.K8S_NAMESPACE,
+ plural='jobs',
+ name=job_name)
+
+
+def get_job_logs(job_name: str) -> List[str]:
+ return k8s_client.get_pod_log(job_name, Envs.K8S_NAMESPACE, 30)
+
+
+def generate_job_list(params: InitiateE2eJobsParameter) -> List[Tuple[str, E2eJob]]:
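+    # build one (job_type, E2eJob) pair per test script of the requested role;
+    # job names follow the pattern auto-e2e-{name_prefix}-{job_type}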
+ jobs = []
+ fed_jobs = ROLES_MAPPING[params.role]
+ for job_type, script_path in fed_jobs.items():
+ jobs.append((job_type,
+ E2eJob(project_name=params.project_name,
+ script_path=script_path,
+ fedlearner_image_uri=params.fedlearner_image_uri,
+ e2e_image_uri=params.e2e_image_uri,
+ job_name=f'auto-e2e-{params.name_prefix}-{job_type}',
+ platform_endpoint=params.platform_endpoint,
+ name_prefix=f'auto-e2e-{params.name_prefix}')))
+ return jobs
+
+
+def initiate_all_tests(params: InitiateE2eJobsParameter) -> List[Dict[str, str]]:
+ jobs = generate_job_list(params)
+ for _, job in jobs:
+ start_job(job)
+ return [{'job_type': job_type, 'job_name': job.job_name} for job_type, job in jobs]
diff --git a/web_console_v2/api/fedlearner_webconsole/e2e/controllers_test.py b/web_console_v2/api/fedlearner_webconsole/e2e/controllers_test.py
new file mode 100644
index 000000000..52ed19c07
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/e2e/controllers_test.py
@@ -0,0 +1,104 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import Mock, patch, call
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.e2e.controllers import initiate_all_tests
+from fedlearner_webconsole.proto.e2e_pb2 import E2eJob, InitiateE2eJobsParameter
+
+
+class E2eControllerTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.e2e.controllers.start_job')
+ def test_initiate_participant_tests(self, start_job_mock: Mock):
+ start_job_mock.return_value = None
+ self.assertRaises(KeyError, initiate_all_tests,
+ InitiateE2eJobsParameter(role='invalid_role', platform_endpoint='some_uri.com'))
+ jobs = initiate_all_tests(
+ InitiateE2eJobsParameter(role='participant',
+ name_prefix='test',
+ project_name='hello',
+ e2e_image_uri='some_image',
+ fedlearner_image_uri='some_image',
+ platform_endpoint='some_uri.com'))
+ self.assertEqual([{
+ 'job_name': 'auto-e2e-test-fed-workflow',
+ 'job_type': 'fed-workflow'
+ }, {
+ 'job_name': 'auto-e2e-test-vertical-dataset-model-serving',
+ 'job_type': 'vertical-dataset-model-serving'
+ }], jobs)
+
+ self.assertEqual([
+ call(
+ E2eJob(project_name='hello',
+ script_path='scripts/auto_e2e/fed_workflow/test_participant.py',
+ fedlearner_image_uri='some_image',
+ e2e_image_uri='some_image',
+ job_name='auto-e2e-test-fed-workflow',
+ platform_endpoint='some_uri.com',
+ name_prefix='auto-e2e-test')),
+ call(
+ E2eJob(project_name='hello',
+ script_path='scripts/auto_e2e/vertical_dataset_model_serving/test_participant.py',
+ fedlearner_image_uri='some_image',
+ e2e_image_uri='some_image',
+ job_name='auto-e2e-test-vertical-dataset-model-serving',
+ platform_endpoint='some_uri.com',
+ name_prefix='auto-e2e-test')),
+ ], start_job_mock.call_args_list)
+
+ @patch('fedlearner_webconsole.e2e.controllers.start_job')
+ def test_initiate_coordinator_tests(self, start_job_mock: Mock):
+ start_job_mock.return_value = None
+ jobs = initiate_all_tests(
+ InitiateE2eJobsParameter(role='coordinator',
+ name_prefix='test',
+ project_name='hello',
+ e2e_image_uri='some_image',
+ fedlearner_image_uri='some_image',
+ platform_endpoint='some_uri.com'))
+ self.assertEqual([{
+ 'job_name': 'auto-e2e-test-fed-workflow',
+ 'job_type': 'fed-workflow'
+ }, {
+ 'job_name': 'auto-e2e-test-vertical-dataset-model-serving',
+ 'job_type': 'vertical-dataset-model-serving'
+ }], jobs)
+
+ self.assertEqual([
+ call(
+ E2eJob(project_name='hello',
+ script_path='scripts/auto_e2e/fed_workflow/test_coordinator.py',
+ fedlearner_image_uri='some_image',
+ e2e_image_uri='some_image',
+ job_name='auto-e2e-test-fed-workflow',
+ platform_endpoint='some_uri.com',
+ name_prefix='auto-e2e-test')),
+ call(
+ E2eJob(project_name='hello',
+ script_path='scripts/auto_e2e/vertical_dataset_model_serving/test_coordinator.py',
+ fedlearner_image_uri='some_image',
+ e2e_image_uri='some_image',
+ job_name='auto-e2e-test-vertical-dataset-model-serving',
+ platform_endpoint='some_uri.com',
+ name_prefix='auto-e2e-test')),
+ ], start_job_mock.call_args_list)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/e2e/utils.py b/web_console_v2/api/fedlearner_webconsole/e2e/utils.py
new file mode 100644
index 000000000..afc15b884
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/e2e/utils.py
@@ -0,0 +1,70 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from string import Template
+
+import yaml
+
+from fedlearner_webconsole.proto.e2e_pb2 import E2eJob
+
+
+def e2e_job_model_to_yaml(job: E2eJob) -> str:
+ return _E2E_JOB_TEMPLATE.substitute(
+ job_name=job.job_name,
+ name_prefix=job.name_prefix,
+ e2e_image_uri=job.e2e_image_uri,
+ project_name=job.project_name,
+ platform_endpoint=job.platform_endpoint,
+ fedlearner_image_uri=job.fedlearner_image_uri,
+ script_path=job.script_path,
+ )
+
+
+def e2e_job_to_dict(job: E2eJob) -> dict:
+ return yaml.load(e2e_job_model_to_yaml(job), Loader=yaml.Loader)
+
+
+_E2E_JOB_TEMPLATE = Template("""apiVersion: batch/v1
+kind: Job
+metadata:
+ name: $job_name
+ labels:
+ owner: wangsen.0914
+ psm: data.aml.fl
+spec:
+ template:
+ spec:
+ containers:
+ - name: $job_name
+ image: $e2e_image_uri
+ env:
+ - name: PYTHONPATH
+ value: /app
+ - name: PROJECT_NAME
+ value: $project_name
+ - name: PLATFORM_ENDPOINT
+ value: $platform_endpoint
+ - name: FEDLEARNER_IMAGE
+ value: $fedlearner_image_uri
+ - name: NAME_PREFIX
+ value: $name_prefix
+ command:
+ - python
+ - $script_path
+ imagePullSecrets:
+ - name: regcred
+ restartPolicy: Never
+ backoffLimit: 0
+""")
diff --git a/web_console_v2/api/fedlearner_webconsole/exceptions.py b/web_console_v2/api/fedlearner_webconsole/exceptions.py
index 3de880de0..1d1695639 100644
--- a/web_console_v2/api/fedlearner_webconsole/exceptions.py
+++ b/web_console_v2/api/fedlearner_webconsole/exceptions.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
class WebConsoleApiException(Exception):
+
def __init__(self, status_code, error_code, message, details=None):
Exception.__init__(self)
self.status_code = status_code
@@ -40,40 +41,52 @@ def to_dict(self):
class InvalidArgumentException(WebConsoleApiException):
+
+ def __init__(self, details):
+ WebConsoleApiException.__init__(self, HTTPStatus.BAD_REQUEST, 400, 'Invalid argument or payload.', details)
+
+
+class NetworkException(WebConsoleApiException):
+
def __init__(self, details):
- WebConsoleApiException.__init__(self, HTTPStatus.BAD_REQUEST, 400,
- 'Invalid argument or payload.', details)
+ WebConsoleApiException.__init__(self, HTTPStatus.BAD_REQUEST, 400, 'Network exception', details)
class NotFoundException(WebConsoleApiException):
+
def __init__(self, message=None):
- WebConsoleApiException.__init__(
- self, HTTPStatus.NOT_FOUND, 404,
- message if message else 'Resource not found.')
+ WebConsoleApiException.__init__(self, HTTPStatus.NOT_FOUND, 404, message if message else 'Resource not found.')
class UnauthorizedException(WebConsoleApiException):
+
def __init__(self, message):
- WebConsoleApiException.__init__(self, HTTPStatus.UNAUTHORIZED,
- 401, message)
+ WebConsoleApiException.__init__(self, HTTPStatus.UNAUTHORIZED, 401, message)
class NoAccessException(WebConsoleApiException):
+
def __init__(self, message):
- WebConsoleApiException.__init__(self, HTTPStatus.FORBIDDEN,
- 403, message)
+ WebConsoleApiException.__init__(self, HTTPStatus.FORBIDDEN, 403, message)
+
+
+class MethodNotAllowedException(WebConsoleApiException):
+
+ def __init__(self, message):
+ WebConsoleApiException.__init__(self, HTTPStatus.METHOD_NOT_ALLOWED, 405, message)
class ResourceConflictException(WebConsoleApiException):
+
def __init__(self, message):
WebConsoleApiException.__init__(self, HTTPStatus.CONFLICT, 409, message)
class InternalException(WebConsoleApiException):
+
def __init__(self, details=None):
- WebConsoleApiException.__init__(
- self, HTTPStatus.INTERNAL_SERVER_ERROR, 500,
- 'Internal Error met when handling the request', details)
+ WebConsoleApiException.__init__(self, HTTPStatus.INTERNAL_SERVER_ERROR, 500,
+ 'Internal Error met when handling the request', details)
def make_response(exception: WebConsoleApiException):
diff --git a/web_console_v2/api/fedlearner_webconsole/exceptions_test.py b/web_console_v2/api/fedlearner_webconsole/exceptions_test.py
new file mode 100644
index 000000000..ba866b29b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/exceptions_test.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+
+from http import HTTPStatus
+from fedlearner_webconsole.exceptions import (InvalidArgumentException, NotFoundException)
+
+
+class ExceptionsTest(unittest.TestCase):
+
+ def test_invalid_argument_exception(self):
+ """Checks if the information of the exception is correct."""
+ exception = InvalidArgumentException(['123', 'df'])
+ self.assertEqual(exception.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(exception.to_dict(), {
+ 'code': 400,
+ 'message': 'Invalid argument or payload.',
+ 'details': [
+ '123',
+ 'df',
+ ]
+ })
+
+ def test_not_found_exception(self):
+ exception1 = NotFoundException('User A not found.')
+ self.assertEqual(exception1.status_code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(exception1.to_dict(), {
+ 'code': 404,
+ 'message': 'User A not found.',
+ })
+ exception2 = NotFoundException()
+ self.assertEqual(exception2.status_code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(exception2.to_dict(), {
+ 'code': 404,
+ 'message': 'Resource not found.',
+ })
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/file/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/file/BUILD.bazel
new file mode 100644
index 000000000..34f8b8f8b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/file/BUILD.bazel
@@ -0,0 +1,38 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask//:pkg",
+ "@common_flask_restful//:pkg",
+ "@common_tensorflow//:pkg",
+ "@common_webargs//:pkg",
+ "@common_werkzeug//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "@common_werkzeug//:pkg",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/file/__init__.py b/web_console_v2/api/fedlearner_webconsole/file/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/file/apis.py b/web_console_v2/api/fedlearner_webconsole/file/apis.py
new file mode 100644
index 000000000..034192294
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/file/apis.py
@@ -0,0 +1,290 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import enum
+import os
+import tempfile
+import logging
+from datetime import datetime
+from typing import List, Optional
+from urllib.parse import unquote
+from werkzeug.formparser import FileStorage
+from werkzeug.utils import secure_filename
+from io import BytesIO
+from envs import Envs
+from flask import send_file
+from flask_restful import Resource, Api
+from google.protobuf.json_format import MessageToDict
+from webargs.flaskparser import use_kwargs
+from webargs import fields
+from tensorflow.io import gfile
+
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.exceptions import (InvalidArgumentException, NoAccessException, NotFoundException)
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.file_operator import FileOperator
+
+
+class FileType(enum.Enum):
+ FILE = 'file'
+ DATASET = 'dataset'
+
+
+# Files with these extensions will be displayed directly.
+DISPLAYABLE_EXTENTION = ['.txt', '.py']
+UPLOAD_FILE_PATH = f'upload_{FileType.FILE.value}'
+IMAGE_EXTENSION = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
+FILE_WHITELIST = (Envs.STORAGE_ROOT, 'hdfs://')
+
+
+def _is_path_accessible(path):
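+    # only paths under STORAGE_ROOT or on HDFS may be accessed through these APIs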
+ return path.startswith(FILE_WHITELIST)
+
+
+def _is_image_extension(filename):
+ return filename.lower().endswith(IMAGE_EXTENSION)
+
+
+class FileApi(Resource):
+
+ def __init__(self):
+ self._file_manager = FileManager()
+
+ @credentials_required
+ @use_kwargs({'path': fields.String(required=True, help='the filepath that you want to read')}, location='query')
+ def get(self, path: str):
+ """Get file content by filepath
+ ---
+ tags:
+ - file
+ description: >
+ Get file content by filepath.
+            Note that this API isn't designed for binary content.
+ parameters:
+ - in: query
+ name: path
+ schema:
+ type: string
+ responses:
+ 200:
+ description: content of the specified path
+ content:
+ application/json:
+ schema:
+ type: string
+ """
+ filepath = path
+ if not _is_path_accessible(filepath):
+ raise NoAccessException('access to this file or directory is not allowed ')
+ content = self._file_manager.read(filepath)
+ return {'data': content}
+
+
+class FilesApi(Resource):
+
+ def __init__(self):
+ self._storage_root = Envs.STORAGE_ROOT
+ self._file_manager = FileManager()
+ self._file_operator = FileOperator()
+ self._file_dir = os.path.join(self._storage_root, UPLOAD_FILE_PATH)
+ self._file_manager.mkdir(self._file_dir)
+ # keep align with original upload directory
+ self._dataset_dir = os.path.join(self._storage_root, 'upload')
+ self._file_manager.mkdir(self._dataset_dir)
+
+ @credentials_required
+ @use_kwargs({'directory': fields.String(required=False, load_default=None)}, location='query')
+ def get(self, directory: Optional[str]):
+ """Get files and directories under some directory
+ ---
+ tags:
+ - file
+ description: Get files and directories under some directory
+ parameters:
+ - in: query
+ name: directory
+ schema:
+ type: string
+ responses:
+ 200:
+ description: files and directories
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: object
+ properties:
+ path:
+ type: string
+ size:
+ type: integer
+ mtime:
+ type: integer
+ is_directory:
+ type: boolean
+ """
+ if directory is None:
+ directory = os.path.join(self._storage_root, 'upload')
+ if not _is_path_accessible(directory):
+ raise NoAccessException('access to this file or directory is not allowed ')
+ if not self._file_manager.isdir(directory):
+            raise NotFoundException('directory does not exist')
+ files = self._file_manager.ls(directory, include_directory=True)
+ return {'data': [dict(file._asdict()) for file in files]}
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'kind':
+ fields.String(required=False, load_default=FileType.FILE.value, help='file type'),
+ 'id':
+ fields.String(required=False,
+ load_default='',
+ help='id to locate the file upload location. '
+ 'For example, use jobs/job_id for algorithm '
+ 'upload for a certain job.'),
+ 'extract':
+ fields.String(
+ required=False, load_default='False', help='If it is necessary to '
+                    required=False, load_default='False', help='Whether to extract the uploaded file.'),
+ location='form')
+ @use_kwargs({'file': fields.List(fields.Field(required=True))}, location='files')
+ def post(self, kind: str, id: str, extract: str, file: List[FileStorage]): # pylint: disable=redefined-builtin
+ """Post one or a set of files for upload
+ ---
+ tags:
+ - file
+ description: Post one or a set of files for upload
+ parameters:
+ - in: form
+ name: kind
+ schema:
+ type: string
+ - in: form
+ name: id
+ schema:
+ type: string
+ - in: form
+ name: extract
+ schema:
+ type: string
+ - in: form
+ name: file
+ schema:
+ type: array
+ items:
+ type: string
+ format: binary
+ description: list of files in binary format
+ responses:
+ 200:
+ description: information of uploaded files
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.UploadedFiles'
+ """
+ location_id = id
+ extract = bool(extract.lower() == 'true')
+
+ upload_files = file
+ if extract:
+ if len(upload_files) != 1:
+ raise InvalidArgumentException('Extraction only allows 1 file each time.')
+
+ # file root dir: {storage_root}/upload_file/{location_id}/
+ root_dir = os.path.join(self._file_dir, location_id)
+ if kind == FileType.DATASET.value:
+ # TODO: clean dataset regularly
+ location_id = datetime.utcnow().strftime('%Y%m%d_%H%M%S%f')
+ root_dir = os.path.join(self._dataset_dir, location_id)
+
+ response = common_pb2.UploadedFiles()
+        # root dir is {storage_root}/upload_file/{location_id}/ or, for datasets, {storage_root}/upload/{datetime}/
+ self._file_manager.mkdir(root_dir)
+ for upload_file in upload_files:
+ file_content: bytes = upload_file.read()
+ if extract:
+ secure_tarfile_name = secure_filename(os.path.basename(upload_file.filename))
+ target_dir_path = os.path.join(root_dir, secure_tarfile_name.split('.')[0])
+ self._file_manager.mkdir(target_dir_path)
+ logging.info(f'target_dir_path:{target_dir_path}')
+ extension = '.' + secure_tarfile_name.split('.')[-1]
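+                # spool the archive into a temp file so FileOperator can extract it into target_dir_path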
+ with tempfile.NamedTemporaryFile(suffix=extension) as f:
+ f.write(file_content)
+ self._file_operator.extract_to(f.name, target_dir_path)
+ response.uploaded_files.append(
+ common_pb2.UploadedFile(display_file_name=secure_tarfile_name,
+ internal_path=target_dir_path,
+ internal_directory=target_dir_path))
+ else:
+ # copy the file to the target destination.
+ secure_file_name = secure_filename(os.path.basename(upload_file.filename))
+ response.uploaded_files.append(
+ self._save_secured_file(root_dir,
+ display_name=secure_file_name,
+ secure_file_name=secure_file_name,
+ content=file_content))
+ return {'data': MessageToDict(response, preserving_proto_field_name=True)}
+
+    def _save_secured_file(self, root_folder: str, display_name: str, secure_file_name: str,
+                           content: bytes) -> common_pb2.UploadedFile:
+        """Saves the file under root_folder and returns its UploadedFile record."""
+ self._file_manager.write(os.path.join(root_folder, secure_file_name), content)
+ return common_pb2.UploadedFile(display_file_name=display_name,
+ internal_path=os.path.join(root_folder, secure_file_name),
+ internal_directory=root_folder)
+
+
+class ImageApi(Resource):
+
+ def __init__(self):
+ self._file_manager = FileManager()
+
+ @use_kwargs({'name': fields.String(required=True, help='image name that you want')}, location='query')
+ def get(self, name: str):
+ """Get image content by image path
+ ---
+ tags:
+ - file
+ description: Get image content by image path
+ parameters:
+ - in: query
+ name: name
+ schema:
+ type: string
+ description: file path of image
+ responses:
+ 200:
+            description: the requested image content in binary
+ content:
+ image/jpeg:
+ type: string
+ format: binary
+ """
+ if not _is_path_accessible(name):
+ raise NoAccessException('access to this file or directory is not allowed ')
+ if not _is_image_extension(name):
+            raise InvalidArgumentException('only image files can be fetched through this API')
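+        # the path may be URL-encoded by the caller, so decode it before reading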
+ content = gfile.GFile(unquote(name), 'rb').read()
+ return send_file(BytesIO(content), mimetype='image/jpeg')
+
+
+def initialize_files_apis(api: Api):
+ api.add_resource(FilesApi, '/files')
+ api.add_resource(FileApi, '/file')
+ api.add_resource(ImageApi, '/image')
diff --git a/web_console_v2/api/fedlearner_webconsole/file/apis_test.py b/web_console_v2/api/fedlearner_webconsole/file/apis_test.py
new file mode 100644
index 000000000..d4c6f274b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/file/apis_test.py
@@ -0,0 +1,170 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tarfile
+import os
+import shutil
+import tempfile
+import unittest
+from io import BytesIO
+from unittest.mock import patch
+
+from envs import Envs
+from http import HTTPStatus
+from pathlib import Path
+from collections import namedtuple
+
+from werkzeug.utils import secure_filename
+from werkzeug.datastructures import FileStorage
+from testing.common import BaseTestCase
+
+from fedlearner_webconsole.file.apis import UPLOAD_FILE_PATH
+from fedlearner_webconsole.utils.file_manager import FileManager
+
+BASE_DIR = Envs.BASE_DIR
+FakeFileStatistics = namedtuple('FakeFileStatistics', ['length', 'mtime_nsec', 'is_directory'])
+
+_FAKE_STORAGE_ROOT = str(tempfile.gettempdir())
+
+
+@patch('fedlearner_webconsole.file.apis.Envs.STORAGE_ROOT', _FAKE_STORAGE_ROOT)
+@patch('fedlearner_webconsole.file.apis.FILE_WHITELIST', (_FAKE_STORAGE_ROOT,))
+class FilesApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+
+ self._file_manager = FileManager()
+ self._tempdir = os.path.join(_FAKE_STORAGE_ROOT, 'upload')
+ os.makedirs(self._tempdir, exist_ok=True)
+ subdir = Path(self._tempdir).joinpath('s')
+ subdir.mkdir(exist_ok=True)
+ Path(self._tempdir).joinpath('f1.txt').write_text('f1', encoding='utf-8')
+ Path(self._tempdir).joinpath('f2.txt').write_text('f2f2', encoding='utf-8')
+ subdir.joinpath('s3.txt').write_text('s3s3s3', encoding='utf-8')
+
+ def tearDown(self):
+ # Remove the directory after the test
+ shutil.rmtree(self._tempdir)
+
+ def _get_temp_path(self, file_path: str = None) -> str:
+ return str(Path(self._tempdir, file_path or '').absolute())
+
+ def test_get_storage_root(self):
+ get_response = self.get_helper('/api/v2/files')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ self.assertEqual(len(self.get_response_data(get_response)), 3)
+
+ def test_get_specified_illegal_directory(self):
+ get_response = self.get_helper('/api/v2/files?directory=/var/log')
+ self.assertEqual(get_response.status_code, HTTPStatus.FORBIDDEN)
+
+ def test_get_not_exist_directory(self):
+ fake_dir = os.path.join(_FAKE_STORAGE_ROOT, 'fake_dir')
+ get_response = self.get_helper(f'/api/v2/files?directory={fake_dir}')
+ self.assertEqual(get_response.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_upload_files(self):
+ data = {}
+ data['file'] = [(BytesIO(b'abcdef'), os.path.join(BASE_DIR, 'test.jpg')),
+ (BytesIO(b'aaabbb'), os.path.join(BASE_DIR, 'test.txt'))]
+ data['id'] = 'jobs/123'
+ upload_response = self.client.post('/api/v2/files',
+ data=data,
+ content_type='multipart/form-data',
+ headers=self._get_headers())
+ self.assertEqual(upload_response.status_code, HTTPStatus.OK)
+ uploaded_files = self.get_response_data(upload_response)
+ self.assertEqual(
+ {
+ 'uploaded_files': [{
+ 'display_file_name':
+ 'test.jpg',
+ 'internal_path':
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123', secure_filename('test.jpg')),
+ 'internal_directory':
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123'),
+ }, {
+ 'display_file_name':
+ 'test.txt',
+ 'internal_path':
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123', secure_filename('test.txt')),
+ 'internal_directory':
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123'),
+ }],
+ }, uploaded_files)
+
+ # Check the saved files.
+ self.assertEqual(
+ 'abcdef',
+ self._file_manager.read(
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123', secure_filename('test.jpg'))))
+ self.assertEqual(
+ 'aaabbb',
+ self._file_manager.read(
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123', secure_filename('test.txt'))))
+
+ # Delete the saved files
+ self._file_manager.remove(
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123', secure_filename('test.jpg')))
+ self._file_manager.remove(
+ os.path.join(_FAKE_STORAGE_ROOT, UPLOAD_FILE_PATH, 'jobs/123', secure_filename('test.txt')))
+
+
+@patch('fedlearner_webconsole.file.apis.Envs.STORAGE_ROOT', _FAKE_STORAGE_ROOT)
+@patch('fedlearner_webconsole.file.apis.FILE_WHITELIST', (_FAKE_STORAGE_ROOT,))
+class FileApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+
+ self._tempdir = _FAKE_STORAGE_ROOT
+ os.makedirs(self._tempdir, exist_ok=True)
+ Path(self._tempdir).joinpath('exists.txt').write_text('Hello World', encoding='utf-8')
+
+ def test_get_file_content_api(self):
+ get_response = self.get_helper(f'/api/v2/file?path={self._tempdir}/exists.txt')
+ self.assertEqual(self.get_response_data(get_response), 'Hello World')
+
+ get_response = self.get_helper('/api/v2/file?path=/system/fd.txt')
+ self.assertEqual(get_response.status_code, HTTPStatus.FORBIDDEN)
+
+
+@patch('fedlearner_webconsole.file.apis.Envs.STORAGE_ROOT', _FAKE_STORAGE_ROOT)
+@patch('fedlearner_webconsole.file.apis.FILE_WHITELIST', (_FAKE_STORAGE_ROOT,))
+class ImageApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.signout_helper()
+
+ self._tempdir = _FAKE_STORAGE_ROOT
+ os.makedirs(self._tempdir, exist_ok=True)
+        Path(self._tempdir).joinpath('fake_image.jpg').write_bytes(b'This is an image')
+
+ def test_get_image_content_api(self):
+ get_response = self.get_helper(f'/api/v2/image?name={self._tempdir}/fake_image.jpg')
+        self.assertEqual(get_response.data, b'This is an image')
+ self.assertEqual(get_response.mimetype, 'image/jpeg')
+
+ get_response = self.get_helper(f'/api/v2/image?name={self._tempdir}/fd.txt')
+ self.assertEqual(get_response.status_code, HTTPStatus.BAD_REQUEST)
+
+ get_response = self.get_helper('/api/v2/image?name=/system/fd.txt')
+ self.assertEqual(get_response.status_code, HTTPStatus.FORBIDDEN)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/flag/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/flag/BUILD.bazel
new file mode 100644
index 000000000..bb5f070c4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/flag/BUILD.bazel
@@ -0,0 +1,46 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "@common_flask_restful//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api:envs_lib"],
+)
+
+py_test(
+ name = "models_lib_test",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/flag/__init__.py b/web_console_v2/api/fedlearner_webconsole/flag/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/flag/apis.py b/web_console_v2/api/fedlearner_webconsole/flag/apis.py
new file mode 100644
index 000000000..017c3e47e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/flag/apis.py
@@ -0,0 +1,47 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+from http import HTTPStatus
+
+from flask_restful import Resource, Api
+from fedlearner_webconsole.flag.models import get_flags
+
+
+class FlagsApi(Resource):
+
+ def get(self):
+ """Get flags
+ ---
+ tags:
+ - flag
+ responses:
+ 200:
+ description: Flags are returned
+ content:
+ application/json:
+ schema:
+ type: object
+ additionalProperties: true
+ example:
+ FLAG_1: string_value
+ FLAG_2: true
+ FLAG_3: 1
+ """
+ return {'data': get_flags()}, HTTPStatus.OK
+
+
+def initialize_flags_apis(api: Api):
+ api.add_resource(FlagsApi, '/flags')
diff --git a/web_console_v2/api/fedlearner_webconsole/flag/apis_test.py b/web_console_v2/api/fedlearner_webconsole/flag/apis_test.py
new file mode 100644
index 000000000..7a6d81567
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/flag/apis_test.py
@@ -0,0 +1,35 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+from unittest.mock import patch
+from testing.common import BaseTestCase
+import unittest
+
+
+class FlagsApisTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.flag.apis.get_flags')
+ def test_get_flags(self, get_flags):
+ get_flags.return_value = {'first_flag': False, 'second_flag': 0}
+ response = self.get_helper('/api/v2/flags')
+ flags = self.get_response_data(response)
+
+ self.assertEqual(False, flags.get('first_flag'))
+ self.assertEqual(0, flags.get('second_flag'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/flag/models.py b/web_console_v2/api/fedlearner_webconsole/flag/models.py
new file mode 100644
index 000000000..d0aabf467
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/flag/models.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+import inspect
+import json
+import logging
+from envs import Envs
+
+
+class _Flag(object):
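+    # flag overrides are parsed once, at import time, from the FLAGS env var (a JSON object)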
+ FLAGS_DICT = json.loads(Envs.FLAGS)
+
+ def __init__(self, name: str, fallback_value):
+ self.name = name
+ self.value = fallback_value
+ self._merge()
+
+ def _merge(self):
+ """Merge fallback values with those ones set using env"""
+ value_from_env = self.FLAGS_DICT.get(self.name)
+
+        # override the fallback value only if the env-provided value exists and has the expected type
+ if value_from_env is not None:
+ if isinstance(value_from_env, type(self.value)):
+ self.value = value_from_env
+ logging.info(f'Setting flag {self.name} to {self.value}.')
+
+ else:
+ logging.warning(f"""
+            Flag {self.name} is set with the wrong type, falling back to {self.value}.
+ Expected: {type(self.value)}; Got: {type(value_from_env)}
+ """)
+
+
+class Flag(object):
+ WORKSPACE_ENABLED = _Flag('workspace_enabled', False)
+ USER_MANAGEMENT_ENABLED = _Flag('user_management_enabled', True)
+ PRESET_TEMPLATE_EDIT_ENABLED = _Flag('preset_template_edit_enabled', False)
+ BCS_SUPPORT_ENABLED = _Flag('bcs_support_enabled', False)
+ TRUSTED_COMPUTING_ENABLED = _Flag('trusted_computing_enabled', True)
+ TEE_MACHINE_DEPLOYED = _Flag('tee_machine_deployed', False)
+ DASHBOARD_ENABLED = _Flag('dashboard_enabled', False)
+ OT_PSI_ENABLED = _Flag('ot_psi_enabled', True)
+ DATASET_STATE_FIX_ENABLED = _Flag('dataset_state_fix_enabled', False)
+ HASH_DATA_JOIN_ENABLED = _Flag('hash_data_join_enabled', False)
+ HELP_DOC_URL = _Flag('help_doc_url', '')
+ MODEL_JOB_GLOBAL_CONFIG_ENABLED = _Flag('model_job_global_config_enabled', False)
+ REVIEW_CENTER_CONFIGURATION = _Flag('review_center_configuration', '{}')
+ # show dataset with auth status but auto authority
+ DATASET_AUTH_STATUS_ENABLED = _Flag('dataset_auth_status_enabled', True)
+ # decide whether to check auth status when create dataset_job
+ DATASET_AUTH_STATUS_CHECK_ENABLED = _Flag('dataset_auth_status_check_enabled', False)
+ # set true after we implement this rpc func
+ LIST_DATASETS_RPC_ENABLED = _Flag('list_datasets_rpc_enabled', True)
+ # set true after we implement this rpc func
+ PENDING_PROJECT_ENABLED = _Flag('pending_project_enabled', True)
+ DATA_BATCH_RERUN_ENABLED = _Flag('data_batch_rerun_enabled', True)
+
+
+def get_flags() -> dict:
+ """Construct a dictionary for flags"""
+ dct = {}
+
+ # Gets flags (members of Flag)
+ # Ref: https://stackoverflow.com/questions/9058305/getting-attributes-of-a-class
+ attributes = inspect.getmembers(Flag, lambda a: not inspect.isroutine(a))
+ flags = [a for a in attributes if not (a[0].startswith('__') and a[0].endswith('__'))]
+ for _, flag in flags:
+ dct[flag.name] = flag.value
+
+ return dct
diff --git a/web_console_v2/api/fedlearner_webconsole/flag/models_test.py b/web_console_v2/api/fedlearner_webconsole/flag/models_test.py
new file mode 100644
index 000000000..cf66c6854
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/flag/models_test.py
@@ -0,0 +1,56 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+from unittest.mock import patch
+import unittest
+from fedlearner_webconsole.flag.models import _Flag, get_flags
+
+
+class FlagMock(object):
+ FIRST_FLAG = _Flag('first_flag', False)
+ SECOND_FLAG = _Flag('second_flag', 0)
+
+
+MOCK_ENV_FLAGS = {'first_flag': True, 'second_flag': 1}
+
+
+class FlagsModelsTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.flag.models._Flag.FLAGS_DICT', MOCK_ENV_FLAGS)
+ def test_fallback(self):
+ # this instance will be modified to True
+ first_flag = _Flag('first_flag', False)
+
+ # this instance will fallback to False due to type error
+ second_flag = _Flag('second_flag', False)
+
+ # this instance will fallback to 0 due to the absence of its value in envs
+ third_flag = _Flag('third_flag', 0)
+
+ self.assertEqual(True, first_flag.value)
+ self.assertEqual(False, second_flag.value)
+ self.assertEqual(0, third_flag.value)
+
+ @patch('fedlearner_webconsole.flag.models.Flag', FlagMock)
+ def test_get_flags(self):
+ flags = get_flags()
+
+ self.assertEqual(False, flags.get('first_flag'))
+ self.assertEqual(0, flags.get('second_flag'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/iam/BUILD.bazel
new file mode 100644
index 000000000..9829b89bb
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/BUILD.bazel
@@ -0,0 +1,153 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "permission_lib",
+ srcs = ["permission.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ ],
+)
+
+py_test(
+ name = "permission_test",
+ size = "small",
+ srcs = [
+ "permission_test.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":permission_lib",
+ ":resource_lib",
+ ],
+)
+
+py_library(
+ name = "resource_lib",
+ srcs = ["resource.py"],
+ imports = ["../.."],
+)
+
+py_test(
+ name = "resource_test",
+ size = "small",
+ srcs = [
+ "resource_test.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":resource_lib",
+ ],
+)
+
+py_library(
+ name = "checker_lib",
+ srcs = [
+ "checker.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":permission_lib",
+ "//web_console_v2/api:envs_lib",
+ ],
+)
+
+py_library(
+ name = "client_lib",
+ srcs = [
+ "client.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":checker_lib",
+ ":permission_lib",
+ ":resource_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ ],
+)
+
+py_test(
+ name = "client_test",
+ size = "small",
+ srcs = [
+ "client_test.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":client_lib",
+ ":permission_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ ],
+)
+
+py_library(
+ name = "iam_required_lib",
+ srcs = [
+ "iam_required.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":client_lib",
+ ":permission_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "@common_flask//:pkg",
+ ],
+)
+
+py_test(
+ name = "iam_required_integration_test",
+ size = "medium",
+ srcs = [
+ "iam_required_integration_test.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_base64_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = [
+ "apis.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_test",
+ size = "small",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":permission_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/__init__.py b/web_console_v2/api/fedlearner_webconsole/iam/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/apis.py b/web_console_v2/api/fedlearner_webconsole/iam/apis.py
new file mode 100644
index 000000000..a7bd45418
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/apis.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+from flask_restful import Resource
+from marshmallow import fields
+
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.decorators.pp_flask import use_kwargs
+from fedlearner_webconsole.utils.flask_utils import get_current_user, make_flask_response
+from fedlearner_webconsole.iam.client import get_iams
+
+
+class CheckPermissionsApi(Resource):
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'resource': fields.String(required=False, load_default=None),
+ 'permission': fields.String(required=False, load_default=None),
+ },
+ location='query',
+ )
+ def get(self, resource: Optional[str], permission: Optional[str]):
+ """Gets all IAM policies.
+ ---
+ tags:
+ - iam
+ description: gets all IAM policies.
+ responses:
+ 200:
+ description: All IAM policies of the current user
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ iams:
+ description: list of policies
+ type: array
+ items:
+ type: string
+ """
+ user = get_current_user()
+ result = get_iams(user, resource, permission)
+ return make_flask_response({'iams': result})
+
+
+def initialize_iams_apis(api):
+ api.add_resource(CheckPermissionsApi, '/iams')
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/apis_test.py b/web_console_v2/api/fedlearner_webconsole/iam/apis_test.py
new file mode 100644
index 000000000..e7556fcf3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/apis_test.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from testing.common import BaseTestCase
+from fedlearner_webconsole.auth.models import Role
+from fedlearner_webconsole.iam.permission import _DEFAULT_PERMISSIONS
+
+
+class IamApisTest(BaseTestCase):
+
+ def test_workflow_with_iam(self):
+ resp = self.get_helper('/api/v2/iams')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data['iams']), len(_DEFAULT_PERMISSIONS[Role.USER]))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/checker.py b/web_console_v2/api/fedlearner_webconsole/iam/checker.py
new file mode 100644
index 000000000..f931a2b98
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/checker.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from abc import ABCMeta, abstractmethod
+from typing import List, Tuple, Optional
+
+from envs import Envs
+from fedlearner_webconsole.iam.permission import Permission
+
+
+class IamChecker(metaclass=ABCMeta):
+
+ @abstractmethod
+ def check(self, identity: str, resource: str, permission: Permission) -> bool:
+ pass
+
+ @abstractmethod
+ def create(self, identity: str, resource: str, permissions: List[Permission]):
+ pass
+
+ @abstractmethod
+ def get(self, identity: str, resource: Optional[str],
+ permission: Optional[Permission]) -> List[Tuple[str, str, Permission]]:
+ pass
+
+
+class ThirdPartyChecker(IamChecker):
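+ """Placeholder checker that is meant to delegate decisions to an external IAM service."""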
+
+ def check(self, identity: str, resource: str, permission: Permission) -> bool:
+ # Calls API according to the configuration
+ return True
+
+ def create(self, identity: str, resource: str, permissions: List[Permission]):
+ # Calls API according to the configuration
+ return
+
+ def get(self, identity: str, resource: Optional[str],
+ permission: Optional[Permission]) -> List[Tuple[str, str, Permission]]:
+ # Calls API according to the configuration
+ return []
+
+
+class TempChecker(IamChecker):
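+ """In-memory checker backed by a plain list; check() always allows access when FLASK_ENV is 'production'."""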
+
+ def __init__(self):
+ self.iams = []
+
+ def check(self, identity: str, resource: str, permission: Permission) -> bool:
+ # In production every check passes; otherwise consult the in-memory store
+ if Envs.FLASK_ENV == 'production':
+ return True
+ if (identity, resource, permission) in self.iams:
+ return True
+ return False
+
+ def create(self, identity: str, resource: str, permissions: List[Permission]):
+ # Records the bindings in the in-memory store
+ for permission in permissions:
+ self.iams.append((identity, resource, permission))
+ self.iams = list(set(self.iams))
+
+ def get(self, identity: str, resource: Optional[str],
+ permission: Optional[Permission]) -> List[Tuple[str, str, Permission]]:
+ # A None resource/permission filter matches everything for the given identity.
+ return [
+ item for item in self.iams
+ if item[0] == identity and (resource is None or item[1] == resource) and
+ (permission is None or item[2] == permission)
+ ]
+
+
+checker: IamChecker = TempChecker()
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/client.py b/web_console_v2/api/fedlearner_webconsole/iam/client.py
new file mode 100644
index 000000000..8e981bc77
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/client.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import List, Union, Optional, Tuple
+
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.iam.checker import checker
+from fedlearner_webconsole.iam.permission import Permission, is_valid_binding
+from fedlearner_webconsole.iam.resource import parse_resource_name, Resource, ResourceType
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.iam.permission import get_valid_permissions, \
+ get_role_default_permissions
+
+
+def check(username: str, resource_name: str, permission: Permission) -> bool:
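+ """Checks whether the user holds the permission on the resource or on any of its ancestors,
+ walking from the root '/' down to the full resource name; raises ValueError for invalid bindings."""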
+ resources: List[Resource] = parse_resource_name(resource_name)
+ # Checks bindings
+ for resource in resources:
+ if not is_valid_binding(resource.type, permission):
+ raise ValueError(f'Invalid binding: {resource.type}-{permission}')
+ for resource in resources:
+ if checker.check(username, resource.name, permission):
+ return True
+ return False
+
+
+def create_iams_for_resource(resource: Union[Project, Workflow], user: User):
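+ """Grants the user the default permissions on a newly created resource; currently only Project is handled."""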
+ # Should not be used in grpc server.
+ if isinstance(resource, Project):
+ resource = f'/projects/{resource.id}'
+ permissions = get_valid_permissions(ResourceType.PROJECT)
+ else:
+ return
+ checker.create(user.username, resource, permissions)
+
+
+def create_iams_for_user(user: User):
+ checker.create(user.username, '/', get_role_default_permissions(user.role))
+
+
+def get_iams(user: User, resource: Optional[str], permission: Optional[Permission]) -> List[Tuple[str, str, str]]:
+ return [(item[0], item[1], item[2].value) for item in checker.get(user.username, resource, permission)]
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/client_test.py b/web_console_v2/api/fedlearner_webconsole/iam/client_test.py
new file mode 100644
index 000000000..5d765568e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/client_test.py
@@ -0,0 +1,81 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, call
+
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.iam.client import check, create_iams_for_resource, create_iams_for_user
+from fedlearner_webconsole.iam.permission import Permission
+from fedlearner_webconsole.project.models import Project
+# these imports are needed so that all DB models are registered
+# pylint: disable=unused-import
+from fedlearner_webconsole.participant.models import ProjectParticipant, Participant
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.auth.models import Role
+
+
+class ClientTest(unittest.TestCase):
+
+ def test_check_invalid_binding(self):
+ with self.assertRaises(ValueError) as cm:
+ check('xiangyuxuan.prs', '/projects/123/workflows/3', Permission.DATASETS_POST)
+ self.assertIn('Invalid binding', str(cm.exception))
+
+ @patch('fedlearner_webconsole.iam.client.checker.check')
+ def test_check_false(self, mock_checker):
+ mock_checker.return_value = False
+ self.assertFalse(check('xprs', '/projects/123/workflows/3', Permission.WORKFLOW_PUT))
+ calls = [
+ call('xprs', '/', Permission.WORKFLOW_PUT),
+ call('xprs', '/projects/123', Permission.WORKFLOW_PUT),
+ call('xprs', '/projects/123/workflows/3', Permission.WORKFLOW_PUT),
+ ]
+ mock_checker.assert_has_calls(calls)
+
+ @patch('fedlearner_webconsole.iam.client.checker.check')
+ def test_check_true(self, mock_checker):
+ mock_checker.side_effect = [False, True]
+ self.assertTrue(check('prs', '/projects/123/workflows/3', Permission.WORKFLOW_PUT))
+ calls = [
+ call('prs', '/', Permission.WORKFLOW_PUT),
+ call('prs', '/projects/123', Permission.WORKFLOW_PUT),
+ ]
+ mock_checker.assert_has_calls(calls)
+
+ def test_create_iams_for_resource(self):
+ username = 'testu'
+ project_id = 1111
+ self.assertFalse(check(username, f'/projects/{project_id}', Permission.PROJECT_PATCH))
+ create_iams_for_resource(Project(id=project_id, name='test'), User(username=username))
+ self.assertTrue(check(username, f'/projects/{project_id}', Permission.PROJECT_PATCH))
+ workflow_id = 3333
+ self.assertTrue(check(username, f'/projects/{project_id}/workflows/{workflow_id}', Permission.WORKFLOW_PATCH))
+ self.assertFalse(check(username, f'/projects/{project_id+1}/workflows/{workflow_id}',
+ Permission.WORKFLOW_PATCH))
+
+ def test_create_iams_for_user(self):
+ admin = User(username='test_admin', role=Role.ADMIN)
+ user = User(username='test_user', role=Role.USER)
+ create_iams_for_user(admin)
+ create_iams_for_user(user)
+ project_id = 1
+ self.assertTrue(check(admin.username, f'/projects/{project_id}/workflows/123123', Permission.WORKFLOW_PATCH))
+ self.assertFalse(check(user.username, f'/projects/{project_id}/workflows/123123', Permission.WORKFLOW_PATCH))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/iam_required.py b/web_console_v2/api/fedlearner_webconsole/iam/iam_required.py
new file mode 100644
index 000000000..7395d612d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/iam_required.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+from functools import wraps
+
+from flask import request
+
+from fedlearner_webconsole.exceptions import NoAccessException
+from fedlearner_webconsole.utils.const import API_VERSION
+from fedlearner_webconsole.utils.flask_utils import get_current_user
+from fedlearner_webconsole.iam.permission import Permission
+from fedlearner_webconsole.iam.client import check
+
+
+def iam_required(permission: Permission):
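+ """Decorator that derives the resource name from the request path and raises
+ NoAccessException when the current user lacks the given permission (None skips the check)."""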
+
+ def decorator(fn):
+
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ if permission is None:
+ return fn(*args, **kwargs)
+ # strip the URL prefix (e.g. /api/v2) to get the resource name
+ resource_name = request.path.rpartition(API_VERSION)[-1]
+ user = get_current_user()
+ try:
+ if not check(user.username, resource_name, permission):
+ raise NoAccessException('No permission.')
+ except Exception as e:
+ # defensive programming for internal errors.
+ logging.error(f'Check permission failed: {user.username} {resource_name} {permission}: {str(e)}')
+ raise NoAccessException('No permission.') from e
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/iam_required_integration_test.py b/web_console_v2/api/fedlearner_webconsole/iam/iam_required_integration_test.py
new file mode 100644
index 000000000..b977642d0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/iam_required_integration_test.py
@@ -0,0 +1,103 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from http import HTTPStatus
+
+from testing.common import BaseTestCase
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.utils.pp_base64 import base64encode
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.db import db
+
+
+class IamRequiredTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ new_user = {
+ 'username': 'test_user',
+ 'password': base64encode('test_user12312'),
+ 'email': 'hello@bytedance.com',
+ 'role': 'USER',
+ 'name': 'codemonkey',
+ }
+ self.signin_as_admin()
+ resp = self.post_helper('/api/v2/auth/users', data=new_user)
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ no_permission_one = User(id=5, username='no_permission_one')
+ no_permission_one.set_password('no_permission_one')
+ session.add(no_permission_one)
+ session.commit()
+
+ def test_workflow_with_iam(self):
+ project_id = 123
+ workflow = Workflow(
+ name='test-workflow',
+ project_id=project_id,
+ config=WorkflowDefinition().SerializeToString(),
+ forkable=False,
+ state=WorkflowState.READY,
+ )
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.commit()
+ self.signin_helper()
+ response = self.patch_helper(f'/api/v2/projects/{project_id}/workflows/{workflow.id}',
+ data={'target_state': 'RUNNING'})
+ self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN)
+ self.signin_as_admin()
+ response = self.patch_helper(f'/api/v2/projects/{project_id}/workflows/{workflow.id}',
+ data={'target_state': 'RUNNING'})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+
+ # test project create hook
+ self.signin_helper()
+ data = {
+ 'name': 'test1',
+ 'config': {
+ 'variables': [{
+ 'name': 'test-post',
+ 'value': 'test'
+ }]
+ },
+ 'participant_ids': [2]
+ }
+ resp = self.post_helper('/api/v2/projects', data=data)
+ pro_id = self.get_response_data(resp)['id']
+ workflow = Workflow(
+ name='test-workflow-2',
+ project_id=pro_id,
+ config=WorkflowDefinition().SerializeToString(),
+ forkable=False,
+ state=WorkflowState.READY,
+ )
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.commit()
+ response = self.patch_helper(f'/api/v2/projects/{pro_id}/workflows/{workflow.id}',
+ data={'target_state': 'RUNNING'})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+
+ self.signin_helper('test_user', 'test_user12312')
+ response = self.patch_helper(f'/api/v2/projects/{pro_id}/workflows/{workflow.id}',
+ data={'target_state': 'RUNNING'})
+ self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/permission.py b/web_console_v2/api/fedlearner_webconsole/iam/permission.py
new file mode 100644
index 000000000..3175f2be4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/permission.py
@@ -0,0 +1,92 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import enum
+from typing import List
+from fedlearner_webconsole.iam.resource import ResourceType
+from fedlearner_webconsole.auth.models import Role
+
+
+class Permission(enum.Enum):
+ # Create project
+ PROJECTS_POST = 'projects.post'
+ PROJECT_GET = 'project.get'
+ # Manage project
+ PROJECT_PATCH = 'project.patch'
+ # Create dataset
+ DATASETS_POST = 'datasets.post'
+ DATASET_DELETE = 'dataset.delete'
+ # Create workflow
+ WORKFLOWS_POST = 'workflows.post'
+ # Configure workflow
+ WORKFLOW_PUT = 'workflow.put'
+ # Update workflow
+ WORKFLOW_PATCH = 'workflow.patch'
+
+
+# Valid bindings between resources and permissions
+_VALID_BINDINGS = [
+ (ResourceType.APPLICATION, Permission.PROJECTS_POST),
+ (ResourceType.APPLICATION, Permission.PROJECT_GET),
+ (ResourceType.APPLICATION, Permission.PROJECT_PATCH),
+ (ResourceType.APPLICATION, Permission.DATASETS_POST),
+ (ResourceType.APPLICATION, Permission.DATASET_DELETE),
+ (ResourceType.APPLICATION, Permission.WORKFLOWS_POST),
+ (ResourceType.APPLICATION, Permission.WORKFLOW_PUT),
+ (ResourceType.APPLICATION, Permission.WORKFLOW_PATCH),
+ (ResourceType.PROJECT, Permission.PROJECT_GET),
+ (ResourceType.PROJECT, Permission.PROJECT_PATCH),
+ (ResourceType.PROJECT, Permission.DATASETS_POST),
+ (ResourceType.PROJECT, Permission.DATASET_DELETE),
+ (ResourceType.PROJECT, Permission.WORKFLOWS_POST),
+ (ResourceType.PROJECT, Permission.WORKFLOW_PUT),
+ (ResourceType.PROJECT, Permission.WORKFLOW_PATCH),
+ (ResourceType.DATASET, Permission.DATASET_DELETE),
+ (ResourceType.WORKFLOW, Permission.WORKFLOW_PUT),
+ (ResourceType.WORKFLOW, Permission.WORKFLOW_PATCH),
+]
+
+_DEFAULT_PERMISSIONS = {
+ Role.ADMIN: [
+ Permission.PROJECTS_POST,
+ Permission.PROJECT_GET,
+ Permission.PROJECT_PATCH,
+ Permission.DATASETS_POST,
+ Permission.DATASET_DELETE,
+ Permission.WORKFLOWS_POST,
+ Permission.WORKFLOW_PUT,
+ Permission.WORKFLOW_PATCH,
+ ],
+ Role.USER: [
+ Permission.PROJECTS_POST,
+ Permission.PROJECT_GET,
+ Permission.DATASETS_POST,
+ Permission.WORKFLOWS_POST,
+ Permission.WORKFLOW_PUT,
+ ]
+}
+
+
+def is_valid_binding(resource_type: ResourceType, permission: Permission) -> bool:
+ return (resource_type, permission) in _VALID_BINDINGS
+
+
+def get_valid_permissions(resource_type: ResourceType) -> List[Permission]:
+ return [item[1] for item in _VALID_BINDINGS if item[0] == resource_type]
+
+
+def get_role_default_permissions(user_role: Role) -> List[Permission]:
+ # The role enum stored via SQLAlchemy could be an arbitrary string, so fall back to USER defensively.
+ return _DEFAULT_PERMISSIONS.get(user_role, _DEFAULT_PERMISSIONS[Role.USER])
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/permission_test.py b/web_console_v2/api/fedlearner_webconsole/iam/permission_test.py
new file mode 100644
index 000000000..b067d9e63
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/permission_test.py
@@ -0,0 +1,31 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.iam.permission import is_valid_binding, Permission
+from fedlearner_webconsole.iam.resource import ResourceType
+
+
+class PermissionTest(unittest.TestCase):
+
+ def test_is_valid_binding(self):
+ self.assertTrue(is_valid_binding(ResourceType.APPLICATION, Permission.PROJECTS_POST))
+ self.assertTrue(is_valid_binding(ResourceType.PROJECT, Permission.DATASETS_POST))
+ self.assertFalse(is_valid_binding(ResourceType.DATASET, Permission.WORKFLOW_PUT))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/resource.py b/web_console_v2/api/fedlearner_webconsole/iam/resource.py
new file mode 100644
index 000000000..0c0382889
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/resource.py
@@ -0,0 +1,85 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+# pylint: disable=redefined-builtin
+import enum
+import logging
+import re
+from typing import List
+
+
+class ResourceType(enum.Enum):
+ # Application level
+ APPLICATION = 'application'
+ PROJECT = 'projects'
+ DATASET = 'datasets'
+ WORKFLOW = 'workflows'
+
+
+# yapf: disable
+# Resource type hierarchies
+_HIERARCHIES = [
+ (ResourceType.APPLICATION, ResourceType.PROJECT),
+ (ResourceType.PROJECT, ResourceType.DATASET),
+ (ResourceType.PROJECT, ResourceType.WORKFLOW)
+]
+# yapf: enable
+
+
+def is_valid_hierarchy(parent: ResourceType, child: ResourceType) -> bool:
+ return (parent, child) in _HIERARCHIES
+
+
+class Resource(object):
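+ """A node in the resource hierarchy: a type, an id and the full resource name, e.g. '/projects/123'."""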
+
+ def __init__(self, type: ResourceType, id: str, name: str):
+ self.type = type
+ self.id = id
+ # Resource name, example: /projects/123/workflows/234
+ self.name = name
+
+
+_RESOURCE_PATTERN = re.compile(r'/([a-z]+)/([0-9]+)')
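+# Matches one '/<type>/<id>' segment at a time, e.g. '/projects/123' or '/workflows/3'.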
+
+
+def parse_resource_name(name: str) -> List[Resource]:
+ """Parses resource names to a list of resources.
+
+ Why not use repeated groups in the regex?
+ Python does not support this yet, so we iterate over the resources in the name one by one.
+ """
+ resources = [Resource(ResourceType.APPLICATION, '', '/')]
+ if name == '/':
+ return resources
+ last_match = 0
+ normalized_name = ''
+ for match in _RESOURCE_PATTERN.finditer(name):
+ if match.start(0) != last_match:
+ raise ValueError('Invalid resource name')
+ last_match = match.end(0)
+ try:
+ r_type = ResourceType(match.group(1))
+ except ValueError as e:
+ logging.error(f'Unexpected resource type: {match.group(1)}')
+ raise ValueError('Invalid resource name') from e
+ id = match.group(2)
+ normalized_name = f'{normalized_name}/{r_type.value}/{id}'
+ resources.append(Resource(type=r_type, id=id, name=normalized_name))
+ # Suffixes such as /peer_workflows are ignored, so the last match
+ # may not end at len(name).
+ for i in range(1, len(resources)):
+ if not is_valid_hierarchy(resources[i - 1].type, resources[i].type):
+ raise ValueError('Invalid resource hierarchy')
+ return resources
diff --git a/web_console_v2/api/fedlearner_webconsole/iam/resource_test.py b/web_console_v2/api/fedlearner_webconsole/iam/resource_test.py
new file mode 100644
index 000000000..0a0516eb0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/iam/resource_test.py
@@ -0,0 +1,66 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from fedlearner_webconsole.iam.resource import ResourceType, is_valid_hierarchy, parse_resource_name
+
+
+class ResourceTest(unittest.TestCase):
+
+ def test_is_valid_hierarchy(self):
+ self.assertTrue(is_valid_hierarchy(ResourceType.APPLICATION, ResourceType.PROJECT))
+ self.assertTrue(is_valid_hierarchy(ResourceType.PROJECT, ResourceType.WORKFLOW))
+ self.assertFalse(is_valid_hierarchy(ResourceType.DATASET, ResourceType.WORKFLOW))
+
+ def test_parse_resource_name_correctly(self):
+ resources = parse_resource_name('/')
+ self.assertEqual(len(resources), 1)
+ self.assertEqual(resources[0].type, ResourceType.APPLICATION)
+ self.assertEqual(resources[0].name, '/')
+ resources = parse_resource_name('/projects/234234')
+ self.assertEqual(len(resources), 2)
+ self.assertEqual(resources[0].type, ResourceType.APPLICATION)
+ self.assertEqual(resources[0].name, '/')
+ self.assertEqual(resources[1].type, ResourceType.PROJECT)
+ self.assertEqual(resources[1].name, '/projects/234234')
+ self.assertEqual(resources[1].id, '234234')
+ resources = parse_resource_name('/projects/123/workflows/333')
+ self.assertEqual(len(resources), 3)
+ self.assertEqual(resources[0].type, ResourceType.APPLICATION)
+ self.assertEqual(resources[0].name, '/')
+ self.assertEqual(resources[1].type, ResourceType.PROJECT)
+ self.assertEqual(resources[1].name, '/projects/123')
+ self.assertEqual(resources[1].id, '123')
+ self.assertEqual(resources[2].type, ResourceType.WORKFLOW)
+ self.assertEqual(resources[2].name, '/projects/123/workflows/333')
+ self.assertEqual(resources[2].id, '333')
+ resources = parse_resource_name('/projects/123/workflows')
+ self.assertEqual(len(resources), 2)
+ resources = parse_resource_name('/projects/123/workflows/2/peer_workflows')
+ self.assertEqual(len(resources), 3)
+
+ def test_parse_resource_name_invalid_hierarchy(self):
+ with self.assertRaises(ValueError) as cm:
+ parse_resource_name('/datasets/123/workflows/234')
+ self.assertEqual(str(cm.exception), 'Invalid resource hierarchy')
+
+ def test_parse_resource_name_invalid_string(self):
+ with self.assertRaises(ValueError) as cm:
+ parse_resource_name('/project/123')
+ self.assertEqual(str(cm.exception), 'Invalid resource name')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/initial_db.py b/web_console_v2/api/fedlearner_webconsole/initial_db.py
index da9997ba0..a5e80487d 100644
--- a/web_console_v2/api/fedlearner_webconsole/initial_db.py
+++ b/web_console_v2/api/fedlearner_webconsole/initial_db.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,30 +11,251 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
+import json
+import os
-from fedlearner_webconsole.auth.models import User, Role, State
-from fedlearner_webconsole.db import db_handler as db
+from pathlib import Path
+
+from sqlalchemy.orm import Session
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.auth.models import Role, State, User
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput
+from fedlearner_webconsole.proto.setting_pb2 import SystemVariables
+from fedlearner_webconsole.setting.models import Setting
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate, WorkflowTemplateKind
+from fedlearner_webconsole.proto.workflow_definition_pb2 import (WorkflowDefinition, WorkflowTemplateEditorInfo)
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.flag.models import Flag
+
+SettingTuple = collections.namedtuple('SettingTuple', ['key', 'value'])
INITIAL_USER_INFO = [{
'username': 'ada',
- 'password': 'fl@123.',
+ 'password': 'fl@12345.',
'name': 'ada',
'email': 'ada@fedlearner.com',
'role': Role.USER,
'state': State.ACTIVE,
}, {
'username': 'admin',
- 'password': 'fl@123.',
+ 'password': 'fl@12345.',
'name': 'admin',
'email': 'admin@fedlearner.com',
'role': Role.ADMIN,
'state': State.ACTIVE,
+}, {
+ 'username': 'robot',
+ 'password': 'fl@12345.',
+ 'name': 'robot',
+ 'email': 'robot@fedlearner.com',
+ 'role': Role.ADMIN,
+ 'state': State.ACTIVE,
}]
+INITIAL_SYSTEM_VARIABLES = ParseDict(
+ {
+ 'variables': [{
+ 'name': 'labels',
+ 'value': {},
+ 'value_type': 'OBJECT',
+ 'fixed': True
+ }, {
+ 'name': 'volume_mounts_list',
+ 'value': [{
+ 'mountPath': '/data',
+ 'name': 'data'
+ }],
+ 'value_type': 'LIST',
+ 'fixed': True
+ }, {
+ 'name': 'volumes_list',
+ 'value': [{
+ 'persistentVolumeClaim': {
+ 'claimName': 'pvc-fedlearner-default'
+ },
+ 'name': 'data'
+ }],
+ 'value_type': 'LIST',
+ 'fixed': True
+ }, {
+ 'name': 'envs_list',
+ 'value': [{
+ 'name': 'HADOOP_HOME',
+ 'value': ''
+ }, {
+ 'name': 'MANUFACTURER',
+ 'value': 'dm9sY2VuZ2luZQ=='
+ }],
+ 'value_type': 'LIST',
+ 'fixed': True
+ }, {
+ 'name': 'namespace',
+ 'value': 'default',
+ 'value_type': 'STRING',
+ 'fixed': True
+ }, {
+ 'name': 'serving_image',
+ 'value': 'artifact.bytedance.com/fedlearner/'
+ 'privacy_perserving_computing_serving:7359b10685e1646450dfda389d228066',
+ 'value_type': 'STRING',
+ 'fixed': True
+ }, {
+ 'name': 'spark_image',
+ 'value': 'artifact.bytedance.com/fedlearner/pp_data_inspection:2.2.4.1',
+ 'value_type': 'STRING',
+ 'fixed': True
+ }, {
+ 'name': 'image_repo',
+ 'value': 'artifact.bytedance.com/fedlearner',
+ 'value_type': 'STRING',
+ 'fixed': False
+ }]
+ }, SystemVariables())
+
+INITIAL_EMAIL_GROUP = SettingTuple(key='sys_email_group', value='privacy_computing@bytedance.com')
+
+
+def _insert_setting_if_not_exists(session: Session, st: SettingTuple):
+ if session.query(Setting).filter_by(uniq_key=st.key).first() is None:
+ setting = Setting(uniq_key=st.key, value=st.value)
+ session.add(setting)
+
+
+def migrate_system_variables(session: Session, initial_vars: SystemVariables):
+ setting_service = SettingService(session)
+ origin_sys_vars = setting_service.get_system_variables()
+ result = merge_system_variables(initial_vars, origin_sys_vars)
+ setting_service.set_system_variables(result)
+
+
+def merge_system_variables(extend: SystemVariables, origin: SystemVariables) -> SystemVariables:
+ """Merge two Systemvariables, when two SystemVariable has the same name, use origin's value."""
+ key_map = {var.name: var for var in extend.variables}
+ for var in origin.variables:
+ key_map[var.name] = var
+ return SystemVariables(variables=[key_map[key] for key in key_map])
+
+
+def _insert_or_update_templates(session: Session):
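+ """Inserts or updates the preset workflow templates defined by the JSON files under sys_preset_templates/."""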
+ path = Path(__file__, '../sys_preset_templates/').resolve()
+ template_files = path.rglob('*.json')
+ for template_file in template_files:
+ with open(os.path.join(path, template_file), encoding='utf-8') as f:
+ data = json.load(f)
+ template_proto = ParseDict(data['config'], WorkflowDefinition(), ignore_unknown_fields=True)
+ editor_info_proto = ParseDict(data['editor_info'], WorkflowTemplateEditorInfo(), ignore_unknown_fields=True)
+ template = session.query(WorkflowTemplate).filter_by(name=data['name']).first()
+ if template is None:
+ template = WorkflowTemplate(name=data['name'])
+ template.comment = data['comment']
+ template.group_alias = template_proto.group_alias
+ template.kind = WorkflowTemplateKind.PRESET.value
+ template.set_config(template_proto)
+ template.set_editor_info(editor_info_proto)
+ session.add(template)
+
+
+def _insert_schedule_workflow_item(session):
+ composer_service = ComposerService(session)
+ # Finishes the old one
+ composer_service.finish('workflow_scheduler')
+ composer_service.collect_v2(
+ 'workflow_scheduler_v2',
+ items=[(ItemType.SCHEDULE_WORKFLOW, RunnerInput())],
+ # cron job at every 1 minute, specific time to avoid congestion.
+ cron_config='* * * * * 45')
+ composer_service.collect_v2(
+ 'job_scheduler_v2',
+ items=[(ItemType.SCHEDULE_JOB, RunnerInput())],
+ # cron job at every 1 minute, specific time to avoid congestion.
+ cron_config='* * * * * 15')
+
+
+def _insert_dataset_job_scheduler_item(session):
+ composer_service = ComposerService(session)
+ # finish the old scheduler
+ composer_service.finish('dataset_job_scheduler')
+ composer_service.finish('dataset_cron_job_scheduler')
+ # insert new scheduler
+ composer_service.collect_v2(
+ 'dataset_short_period_scheduler',
+ items=[(ItemType.DATASET_SHORT_PERIOD_SCHEDULER, RunnerInput())],
+ # cron job at every 30 seconds
+ cron_config='* * * * * */30')
+ composer_service.collect_v2(
+ 'dataset_long_period_scheduler',
+ items=[(ItemType.DATASET_LONG_PERIOD_SCHEDULER, RunnerInput())],
+ # cron job at every 30 min
+ cron_config='*/30 * * * *')
+
+
+def _insert_cleanup_cronjob_item(session):
+ composer_service = ComposerService(session)
+ composer_service.collect_v2(
+ 'cleanup_cron_job',
+ items=[(ItemType.CLEANUP_CRON_JOB, RunnerInput())],
+ # cron job at every 30 min
+ cron_config='*/30 * * * *')
+
+
+def _insert_tee_runner_item(session):
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ return
+ composer_service = ComposerService(session)
+ composer_service.collect_v2(
+ 'tee_create_runner',
+ items=[(ItemType.TEE_CREATE_RUNNER, RunnerInput())],
+ # cron job at every 30 seconds
+ cron_config='* * * * * */30')
+ composer_service.collect_v2(
+ 'tee_resource_check_runner',
+ items=[(ItemType.TEE_RESOURCE_CHECK_RUNNER, RunnerInput())],
+ # cron job at every 30 min
+ cron_config='*/30 * * * *')
+
+
+def _insert_project_runner_item(session):
+ if not Flag.PENDING_PROJECT_ENABLED.value:
+ return
+ composer_service = ComposerService(session)
+ composer_service.collect_v2(
+ 'project_scheduler_v2',
+ items=[(ItemType.SCHEDULE_PROJECT, RunnerInput())],
+ # cron job at every 1 minute, specific time to avoid congestion.
+ cron_config='* * * * * 30')
+
+
+def _insert_model_job_scheduler_runner_item(session: Session):
+ if not Flag.MODEL_JOB_GLOBAL_CONFIG_ENABLED.value:
+ return
+ composer_service = ComposerService(session)
+ composer_service.collect_v2('model_job_scheduler_runner',
+ items=[(ItemType.SCHEDULE_MODEL_JOB, RunnerInput())],
+ cron_config='* * * * * */30')
+
+
+def _insert_model_job_group_scheduler_runner_item(session: Session):
+ if not Flag.MODEL_JOB_GLOBAL_CONFIG_ENABLED.value:
+ return
+ composer_service = ComposerService(session)
+ composer_service.collect_v2('model_job_group_scheduler_runner',
+ items=[(ItemType.SCHEDULE_MODEL_JOB_GROUP, RunnerInput())],
+ cron_config='* * * * * */30')
+ composer_service.collect_v2(
+ 'model_job_group_long_period_scheduler_runner',
+ items=[(ItemType.SCHEDULE_LONG_PERIOD_MODEL_JOB_GROUP, RunnerInput())],
+ # cron job at every 30 min
+ cron_config='*/30 * * * *')
+
def initial_db():
with db.session_scope() as session:
- # initial user info first
+ # Initializes user info first
for u_info in INITIAL_USER_INFO:
username = u_info['username']
password = u_info['password']
@@ -42,13 +263,19 @@ def initial_db():
email = u_info['email']
role = u_info['role']
state = u_info['state']
- if session.query(User).filter_by(
- username=username).first() is None:
- user = User(username=username,
- name=name,
- email=email,
- role=role,
- state=state)
+ if session.query(User).filter_by(username=username).first() is None:
+ user = User(username=username, name=name, email=email, role=role, state=state)
user.set_password(password=password)
session.add(user)
+ # Initializes settings
+ _insert_setting_if_not_exists(session, INITIAL_EMAIL_GROUP)
+ migrate_system_variables(session, INITIAL_SYSTEM_VARIABLES)
+ _insert_or_update_templates(session)
+ _insert_schedule_workflow_item(session)
+ _insert_dataset_job_scheduler_item(session)
+ _insert_cleanup_cronjob_item(session)
+ _insert_tee_runner_item(session)
+ _insert_project_runner_item(session)
+ _insert_model_job_scheduler_runner_item(session)
+ _insert_model_job_group_scheduler_runner_item(session)
session.commit()
diff --git a/web_console_v2/api/fedlearner_webconsole/initial_db_test.py b/web_console_v2/api/fedlearner_webconsole/initial_db_test.py
new file mode 100644
index 000000000..1bd071fb2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/initial_db_test.py
@@ -0,0 +1,52 @@
+import unittest
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.initial_db import (_insert_or_update_templates, initial_db, migrate_system_variables,
+ INITIAL_SYSTEM_VARIABLES)
+from fedlearner_webconsole.proto.setting_pb2 import SystemVariables
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class InitialDbTest(NoWebServerTestCase):
+
+ def test_initial_db(self):
+ initial_db()
+ with db.session_scope() as session:
+ self.assertEqual(SettingService(session).get_system_variables_dict()['namespace'], 'default')
+
+ def test_merge_system_variables(self):
+ with db.session_scope() as session:
+ migrate_system_variables(session, INITIAL_SYSTEM_VARIABLES)
+ session.commit()
+
+ with db.session_scope() as session:
+ migrate_system_variables(
+ session,
+ ParseDict(
+ {
+ 'variables': [{
+ 'name': 'namespace',
+ 'value': 'not_default'
+ }, {
+ 'name': 'unknown',
+ 'value': 'test'
+ }]
+ }, SystemVariables()))
+ self.assertEqual(SettingService(session).get_system_variables_dict()['namespace'], 'default')
+ self.assertEqual(SettingService(session).get_system_variables_dict()['unknown'], 'test')
+ session.commit()
+
+ def test_insert_syspreset_template(self):
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ session.commit()
+
+ with db.session_scope() as session:
+ self.assertEqual(session.query(WorkflowTemplate).count(), 18)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/job/BUILD.bazel
new file mode 100644
index 000000000..f8561bb52
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/BUILD.bazel
@@ -0,0 +1,317 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "controller_lib",
+ srcs = ["controller.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ ":utils_lib",
+ ":yaml_formatter_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "controller_lib_test",
+ size = "small",
+ srcs = [
+ "controller_test.py",
+ ],
+ imports = ["../.."],
+ main = "controller_test.py",
+ deps = [
+ ":controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:yaml_formatter_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "crd_lib",
+ srcs = ["crd.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_cache_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ ],
+)
+
+py_library(
+ name = "metrics_lib",
+ srcs = ["metrics.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:es_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:job_metrics_lib",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_matplotlib//:pkg",
+ "@common_mpld3//:pkg",
+ ],
+)
+
+py_test(
+ name = "metrics_lib_test",
+ size = "medium",
+ srcs = [
+ "metrics_test.py",
+ ],
+ imports = ["../.."],
+ main = "metrics_test.py",
+ deps = [
+ ":metrics_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing/test_data:test_data_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ ":crd_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:mixins_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "model_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "service_lib",
+ srcs = ["service.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_yaml_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "service_lib_test",
+ size = "small",
+ srcs = [
+ "service_test.py",
+ ],
+ imports = ["../.."],
+ main = "service_test.py",
+ deps = [
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_yaml_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib"],
+)
+
+py_test(
+ name = "utils_test",
+ size = "small",
+ srcs = [
+ "utils_test.py",
+ ],
+ imports = ["../.."],
+ main = "utils_test.py",
+ deps = [
+ ":utils_lib",
+ ],
+)
+
+py_library(
+ name = "yaml_formatter_lib",
+ srcs = ["yaml_formatter.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/k8s:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_yaml_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "yaml_formatter_lib_test",
+ size = "small",
+ srcs = [
+ "yaml_formatter_test.py",
+ ],
+ imports = ["../.."],
+ main = "yaml_formatter_test.py",
+ deps = [
+ ":yaml_formatter_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_yaml_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "scheduler_lib",
+ srcs = [
+ "scheduler.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":controller_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "scheduler_lib_test",
+ size = "small",
+ srcs = [
+ "scheduler_test.py",
+ ],
+ imports = ["../.."],
+ main = "scheduler_test.py",
+ deps = [
+ ":scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":metrics_lib",
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:es_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:kibana_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_sqlalchemy//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "event_listener_lib",
+ srcs = ["event_listener.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:event_listener_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_cache_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_job_controller_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/job/apis.py b/web_console_v2/api/fedlearner_webconsole/job/apis.py
index d9a073dbe..83208d251 100644
--- a/web_console_v2/api/fedlearner_webconsole/job/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/job/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,331 +15,545 @@
import json
import logging
import time
+from typing import Optional
-from flask_restful import Resource, reqparse, abort
+from flask_restful import Resource, reqparse
from google.protobuf.json_format import MessageToDict
+from webargs.flaskparser import use_kwargs
+from marshmallow import fields
+from sqlalchemy.orm.session import Session
from envs import Envs
-from fedlearner_webconsole.exceptions import (
- NotFoundException, InternalException
-)
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import (NotFoundException, InternalException, InvalidArgumentException)
from fedlearner_webconsole.job.metrics import JobMetricsBuilder
from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.job.service import JobService
+from fedlearner_webconsole.participant.models import Participant
from fedlearner_webconsole.proto import common_pb2
from fedlearner_webconsole.rpc.client import RpcClient
-from fedlearner_webconsole.utils.decorators import jwt_required
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
from fedlearner_webconsole.utils.es import es
from fedlearner_webconsole.utils.kibana import Kibana
from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
-def _get_job(job_id):
- result = Job.query.filter_by(id=job_id).first()
+def _get_job(job_id, session: Session):
+ result = session.query(Job).filter_by(id=job_id).first()
if result is None:
raise NotFoundException(f'Failed to find job_id: {job_id}')
return result
class JobApi(Resource):
- @jwt_required()
- def get(self, job_id):
- job = _get_job(job_id)
- return {'data': job.to_dict()}
- # TODO: manual start jobs
+ @credentials_required
+ def get(self, job_id):
+ """Get job details.
+ ---
+ tags:
+ - job
+ description: Get job details.
+ parameters:
+ - in: path
+ name: job_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: Detail of job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.JobPb'
+ """
+ with db.session_scope() as session:
+ job = _get_job(job_id, session)
+ result = job.to_proto()
+ result.pods.extend(JobService.get_pods(job))
+ result.snapshot = JobService.get_job_yaml(job)
+ return make_flask_response(result)
class PodLogApi(Resource):
- @jwt_required()
- def get(self, job_id, pod_name):
- parser = reqparse.RequestParser()
- parser.add_argument('start_time', type=int, location='args',
- required=False,
- help='start_time must be timestamp')
- parser.add_argument('max_lines', type=int, location='args',
- required=True,
- help='max_lines is required')
- data = parser.parse_args()
- start_time = data['start_time']
- max_lines = data['max_lines']
- job = _get_job(job_id)
- if start_time is None:
- start_time = job.workflow.start_at
- return {'data': es.query_log(Envs.ES_INDEX, '', pod_name,
- start_time * 1000,
- int(time.time() * 1000))[:max_lines][::-1]}
+
+ @credentials_required
+ @use_kwargs({
+ 'start_time': fields.Int(required=False, load_default=None),
+ 'max_lines': fields.Int(required=True)
+ },
+ location='query')
+ def get(self, start_time: Optional[int], max_lines: int, job_id: int, pod_name: str):
+ """Get pod logs.
+ ---
+ tags:
+ - job
+ description: Get pod logs.
+ parameters:
+ - in: path
+ name: job_id
+ schema:
+ type: integer
+ - in: path
+ name: pod_name
+ schema:
+ type: string
+ - in: query
+ description: timestamp in seconds
+ name: start_time
+ schema:
+ type: integer
+ - in: query
+ name: max_lines
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: List of pod logs
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+
+ """
+ with db.session_scope() as session:
+ job = _get_job(job_id, session)
+ if start_time is None and job.workflow:
+ start_time = job.workflow.start_at
+ return make_flask_response(
+ es.query_log(Envs.ES_INDEX, '', pod_name, (start_time or 0) * 1000)[:max_lines][::-1])
class JobLogApi(Resource):
- @jwt_required()
- def get(self, job_id):
- parser = reqparse.RequestParser()
- parser.add_argument('start_time', type=int, location='args',
- required=False,
- help='project_id must be timestamp')
- parser.add_argument('max_lines', type=int, location='args',
- required=True,
- help='max_lines is required')
- data = parser.parse_args()
- start_time = data['start_time']
- max_lines = data['max_lines']
- job = _get_job(job_id)
- if start_time is None:
- start_time = job.workflow.start_at
- return {
- 'data': es.query_log(
- Envs.ES_INDEX, job.name,
- 'fedlearner-operator',
- start_time * 1000,
- int(time.time() * 1000),
- Envs.OPERATOR_LOG_MATCH_PHRASE)[:max_lines][::-1]
- }
+
+ @credentials_required
+ @use_kwargs({
+ 'start_time': fields.Int(required=False, load_default=None),
+ 'max_lines': fields.Int(required=True)
+ },
+ location='query')
+ def get(self, start_time: Optional[int], max_lines: int, job_id: int):
+ """Get job logs.
+ ---
+ tags:
+ - job
+ description: Get job logs.
+ parameters:
+ - in: path
+ name: job_id
+ schema:
+ type: integer
+ - in: query
+ description: timestamp in seconds
+ name: start_time
+ schema:
+ type: integer
+ - in: query
+ name: max_lines
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: List of job logs
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+ """
+ with db.session_scope() as session:
+ job = _get_job(job_id, session)
+ if start_time is None and job.workflow:
+ start_time = job.workflow.start_at
+ return make_flask_response(
+ es.query_log(Envs.ES_INDEX,
+ job.name,
+ 'fedlearner-operator', (start_time or 0) * 1000,
+ match_phrase=Envs.OPERATOR_LOG_MATCH_PHRASE)[:max_lines][::-1])
class JobMetricsApi(Resource):
- @jwt_required()
- def get(self, job_id):
- job = _get_job(job_id)
- try:
- metrics = JobMetricsBuilder(job).plot_metrics()
- # Metrics is a list of dict. Each dict can be rendered by frontend
- # with mpld3.draw_figure('figure1', json)
- return {'data': metrics}
- except Exception as e: # pylint: disable=broad-except
- logging.warning('Error building metrics: %s', repr(e))
- abort(400, message=repr(e))
+
+ @credentials_required
+ @use_kwargs({
+ 'raw': fields.Bool(required=False, load_default=False),
+ }, location='query')
+ def get(self, job_id: int, raw: bool):
+ """Get job Metrics.
+ ---
+ tags:
+ - job
+ description: Get job metrics.
+ parameters:
+ - in: path
+ name: job_id
+ schema:
+ type: integer
+ - in: query
+ name: raw
+ schema:
+ type: boolean
+ responses:
+ 200:
+ description: List of job metrics
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: object
+ """
+ with db.session_scope() as session:
+ job = _get_job(job_id, session)
+ try:
+ builder = JobMetricsBuilder(job)
+ if raw:
+ return make_flask_response(data=builder.query_metrics())
+ # Metrics is a list of dict. Each dict can be rendered by frontend
+ # with mpld3.draw_figure('figure1', json)
+ return make_flask_response(data=builder.plot_metrics())
+ except Exception as e: # pylint: disable=broad-except
+ logging.warning('Error building metrics: %s', repr(e))
+ raise InvalidArgumentException(details=repr(e)) from e
class PeerJobMetricsApi(Resource):
- @jwt_required()
- def get(self, workflow_uuid, participant_id, job_name):
- workflow = Workflow.query.filter_by(uuid=workflow_uuid).first()
- if workflow is None:
- raise NotFoundException(
- f'Failed to find workflow: {workflow_uuid}')
- project_config = workflow.project.get_config()
- party = project_config.participants[participant_id]
- client = RpcClient(project_config, party)
- resp = client.get_job_metrics(job_name)
- if resp.status.code != common_pb2.STATUS_SUCCESS:
- raise InternalException(resp.status.msg)
- metrics = json.loads(resp.metrics)
+ @credentials_required
+ def get(self, workflow_uuid: str, participant_id: int, job_name: str):
+ """Get peer job metrics.
+ ---
+ tags:
+ - job
+ description: Get peer job metrics.
+ parameters:
+ - in: path
+ name: workflow_uuid
+ schema:
+ type: string
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: path
+ name: job_name
+ schema:
+ type: string
+ responses:
+ 200:
+ description: List of job metrics
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: object
+ """
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).filter_by(uuid=workflow_uuid).first()
+ if workflow is None:
+ raise NotFoundException(f'Failed to find workflow: {workflow_uuid}')
+ participant = session.query(Participant).filter_by(id=participant_id).first()
+ client = RpcClient.from_project_and_participant(workflow.project.name, workflow.project.token,
+ participant.domain_name)
+ resp = client.get_job_metrics(job_name)
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ raise InternalException(resp.status.msg)
+
+ metrics = json.loads(resp.metrics)
- # Metrics is a list of dict. Each dict can be rendered by frontend with
- # mpld3.draw_figure('figure1', json)
- return {'data': metrics}
+ # Metrics is a list of dict. Each dict can be rendered by frontend with
+ # mpld3.draw_figure('figure1', json)
+ return make_flask_response(metrics)
class JobEventApi(Resource):
# TODO(xiangyuxuan): need test
- @jwt_required()
- def get(self, job_id):
- parser = reqparse.RequestParser()
- parser.add_argument('start_time', type=int, location='args',
- required=False,
- help='start_time must be timestamp')
- parser.add_argument('max_lines', type=int, location='args',
- required=True,
- help='max_lines is required')
- data = parser.parse_args()
- start_time = data['start_time']
- max_lines = data['max_lines']
- job = _get_job(job_id)
- if start_time is None:
- start_time = job.workflow.start_at
- return {'data': es.query_events(Envs.ES_INDEX, job.name,
- 'fedlearner-operator',
- start_time,
- int(time.time() * 1000
- ),
- Envs.OPERATOR_LOG_MATCH_PHRASE
- )[:max_lines][::-1]}
+ @credentials_required
+ @use_kwargs({
+ 'start_time': fields.Int(required=False, load_default=None),
+ 'max_lines': fields.Int(required=True)
+ },
+ location='query')
+ def get(self, start_time: Optional[int], max_lines: int, job_id: int):
+ """Get job events.
+ ---
+ tags:
+ - job
+ description: Get job events.
+ parameters:
+ - in: path
+ name: job_id
+ schema:
+ type: integer
+ - in: query
+ description: timestamp in seconds
+ name: start_time
+ schema:
+ type: integer
+ - in: query
+ name: max_lines
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: List of job events
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+ """
+ with db.session_scope() as session:
+ job = _get_job(job_id, session)
+ if start_time is None and job.workflow:
+ start_time = job.workflow.start_at
+ return make_flask_response(
+ es.query_events(Envs.ES_INDEX, job.name, 'fedlearner-operator', start_time, int(time.time() * 1000),
+ Envs.OPERATOR_LOG_MATCH_PHRASE)[:max_lines][::-1])
class PeerJobEventsApi(Resource):
- @jwt_required()
- def get(self, workflow_uuid, participant_id, job_name):
- parser = reqparse.RequestParser()
- parser.add_argument('start_time', type=int, location='args',
- required=False,
- help='project_id must be timestamp')
- parser.add_argument('max_lines', type=int, location='args',
- required=True,
- help='max_lines is required')
- data = parser.parse_args()
- start_time = data['start_time']
- max_lines = data['max_lines']
- workflow = Workflow.query.filter_by(uuid=workflow_uuid).first()
- if workflow is None:
- raise NotFoundException(
- f'Failed to find workflow: {workflow_uuid}')
- if start_time is None:
- start_time = workflow.start_at
- project_config = workflow.project.get_config()
- party = project_config.participants[participant_id]
- client = RpcClient(project_config, party)
- resp = client.get_job_events(job_name=job_name,
- start_time=start_time,
- max_lines=max_lines)
- if resp.status.code != common_pb2.STATUS_SUCCESS:
- raise InternalException(resp.status.msg)
- peer_events = MessageToDict(
- resp,
- preserving_proto_field_name=True,
- including_default_value_fields=True)['logs']
- return {'data': peer_events}
+
+ @credentials_required
+ @use_kwargs({
+ 'start_time': fields.Int(required=False, load_default=None),
+ 'max_lines': fields.Int(required=True)
+ },
+ location='query')
+ def get(self, start_time: Optional[int], max_lines: int, workflow_uuid: str, participant_id: int, job_name: str):
+ """Get peer job events.
+ ---
+ tags:
+ - job
+ description: Get peer job events.
+ parameters:
+ - in: path
+ name: workflow_uuid
+ schema:
+ type: string
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ - in: path
+ name: job_name
+ schema:
+ type: string
+ responses:
+ 200:
+ description: List of peer job events
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+ """
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).filter_by(uuid=workflow_uuid).first()
+ if workflow is None:
+ raise NotFoundException(f'Failed to find workflow: {workflow_uuid}')
+ if start_time is None:
+ start_time = workflow.start_at
+ participant = session.query(Participant).filter_by(id=participant_id).first()
+ client = RpcClient.from_project_and_participant(workflow.project.name, workflow.project.token,
+ participant.domain_name)
+ resp = client.get_job_events(job_name=job_name, start_time=start_time, max_lines=max_lines)
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ raise InternalException(resp.status.msg)
+ peer_events = MessageToDict(resp, preserving_proto_field_name=True,
+ including_default_value_fields=True)['logs']
+ return make_flask_response(peer_events)
class KibanaMetricsApi(Resource):
- @jwt_required()
+
+ @credentials_required
def get(self, job_id):
- job = _get_job(job_id)
parser = reqparse.RequestParser()
- parser.add_argument('type', type=str, location='args',
+ parser.add_argument('type',
+ type=str,
+ location='args',
required=True,
- choices=('Rate', 'Ratio', 'Numeric',
- 'Time', 'Timer'),
+ choices=('Rate', 'Ratio', 'Numeric', 'Time', 'Timer'),
help='Visualization type is required. Choices: '
- 'Rate, Ratio, Numeric, Time, Timer')
- parser.add_argument('interval', type=str, location='args',
+ 'Rate, Ratio, Numeric, Time, Timer')
+ parser.add_argument('interval',
+ type=str,
+ location='args',
default='',
help='Time bucket interval length, '
- 'defaults to be automated by Kibana.')
- parser.add_argument('x_axis_field', type=str, location='args',
+ 'defaults to be automated by Kibana.')
+ parser.add_argument('x_axis_field',
+ type=str,
+ location='args',
default='tags.event_time',
help='Time field (X axis) is required.')
- parser.add_argument('query', type=str, location='args',
- help='Additional query string to the graph.')
- parser.add_argument('start_time', type=int, location='args',
+ parser.add_argument('query', type=str, location='args', help='Additional query string to the graph.')
+ parser.add_argument('start_time',
+ type=int,
+ location='args',
default=-1,
help='Earliest time of data.'
- 'Unix timestamp in secs.')
- parser.add_argument('end_time', type=int, location='args',
+ 'Unix timestamp in secs.')
+ parser.add_argument('end_time',
+ type=int,
+ location='args',
default=-1,
help='Latest time of data.'
- 'Unix timestamp in secs.')
+ 'Unix timestamp in secs.')
# (Joined) Rate visualization is fixed and only interval, query and
# x_axis_field can be modified
# Ratio visualization
- parser.add_argument('numerator', type=str, location='args',
+ parser.add_argument('numerator',
+ type=str,
+ location='args',
help='Numerator is required in Ratio '
- 'visualization. '
- 'A query string similar to args::query.')
- parser.add_argument('denominator', type=str, location='args',
+ 'visualization. '
+ 'A query string similar to args::query.')
+ parser.add_argument('denominator',
+ type=str,
+ location='args',
help='Denominator is required in Ratio '
- 'visualization. '
- 'A query string similar to args::query.')
+ 'visualization. '
+ 'A query string similar to args::query.')
# Numeric visualization
- parser.add_argument('aggregator', type=str, location='args',
+ parser.add_argument('aggregator',
+ type=str,
+ location='args',
default='Average',
- choices=('Average', 'Sum', 'Max', 'Min', 'Variance',
- 'Std. Deviation', 'Sum of Squares'),
+ choices=('Average', 'Sum', 'Max', 'Min', 'Variance', 'Std. Deviation', 'Sum of Squares'),
help='Aggregator type is required in Numeric and '
- 'Timer visualization.')
- parser.add_argument('value_field', type=str, location='args',
+ 'Timer visualization.')
+ parser.add_argument('value_field',
+ type=str,
+ location='args',
help='The field to be aggregated on is required '
- 'in Numeric visualization.')
+ 'in Numeric visualization.')
# No additional arguments in Time visualization
#
# Timer visualization
- parser.add_argument('timer_names', type=str, location='args',
+ parser.add_argument('timer_names',
+ type=str,
+ location='args',
help='Names of timers is required in '
- 'Timer visualization.')
- parser.add_argument('split', type=int, location='args',
- default=0,
- help='Whether to plot timers individually.')
+ 'Timer visualization.')
+ parser.add_argument('split', type=int, location='args', default=0, help='Whether to plot timers individually.')
args = parser.parse_args()
- try:
- if args['type'] in Kibana.TSVB:
- return {'data': Kibana.create_tsvb(job, args)}
- if args['type'] in Kibana.TIMELION:
- return {'data': Kibana.create_timelion(job, args)}
- return {'data': []}
- except Exception as e: # pylint: disable=broad-except
- abort(400, message=repr(e))
+ with db.session_scope() as session:
+ job = _get_job(job_id, session)
+ try:
+ if args['type'] in Kibana.TSVB:
+ return {'data': Kibana.create_tsvb(job, args)}
+ if args['type'] in Kibana.TIMELION:
+ return {'data': Kibana.create_timelion(job, args)}
+ return {'data': []}
+ except Exception as e: # pylint: disable=broad-except
+ raise InvalidArgumentException(details=repr(e)) from e
class PeerKibanaMetricsApi(Resource):
- @jwt_required()
+
+ @credentials_required
def get(self, workflow_uuid, participant_id, job_name):
parser = reqparse.RequestParser()
- parser.add_argument('type', type=str, location='args',
+ parser.add_argument('type',
+ type=str,
+ location='args',
required=True,
choices=('Ratio', 'Numeric'),
help='Visualization type is required. Choices: '
- 'Rate, Ratio, Numeric, Time, Timer')
- parser.add_argument('interval', type=str, location='args',
+ 'Rate, Ratio, Numeric, Time, Timer')
+ parser.add_argument('interval',
+ type=str,
+ location='args',
default='',
help='Time bucket interval length, '
- 'defaults to be automated by Kibana.')
- parser.add_argument('x_axis_field', type=str, location='args',
+ 'defaults to be automated by Kibana.')
+ parser.add_argument('x_axis_field',
+ type=str,
+ location='args',
default='tags.event_time',
help='Time field (X axis) is required.')
- parser.add_argument('query', type=str, location='args',
- help='Additional query string to the graph.')
- parser.add_argument('start_time', type=int, location='args',
+ parser.add_argument('query', type=str, location='args', help='Additional query string to the graph.')
+ parser.add_argument('start_time',
+ type=int,
+ location='args',
default=-1,
help='Earliest time of data.'
- 'Unix timestamp in secs.')
- parser.add_argument('end_time', type=int, location='args',
+ 'Unix timestamp in secs.')
+ parser.add_argument('end_time',
+ type=int,
+ location='args',
default=-1,
help='Latest time of data.'
- 'Unix timestamp in secs.')
+ 'Unix timestamp in secs.')
# Ratio visualization
- parser.add_argument('numerator', type=str, location='args',
+ parser.add_argument('numerator',
+ type=str,
+ location='args',
help='Numerator is required in Ratio '
- 'visualization. '
- 'A query string similar to args::query.')
- parser.add_argument('denominator', type=str, location='args',
+ 'visualization. '
+ 'A query string similar to args::query.')
+ parser.add_argument('denominator',
+ type=str,
+ location='args',
help='Denominator is required in Ratio '
- 'visualization. '
- 'A query string similar to args::query.')
+ 'visualization. '
+ 'A query string similar to args::query.')
# Numeric visualization
- parser.add_argument('aggregator', type=str, location='args',
+ parser.add_argument('aggregator',
+ type=str,
+ location='args',
default='Average',
- choices=('Average', 'Sum', 'Max', 'Min', 'Variance',
- 'Std. Deviation', 'Sum of Squares'),
+ choices=('Average', 'Sum', 'Max', 'Min', 'Variance', 'Std. Deviation', 'Sum of Squares'),
help='Aggregator type is required in Numeric and '
- 'Timer visualization.')
- parser.add_argument('value_field', type=str, location='args',
+ 'Timer visualization.')
+ parser.add_argument('value_field',
+ type=str,
+ location='args',
help='The field to be aggregated on is required '
- 'in Numeric visualization.')
+ 'in Numeric visualization.')
args = parser.parse_args()
- workflow = Workflow.query.filter_by(uuid=workflow_uuid).first()
- if workflow is None:
- raise NotFoundException(
- f'Failed to find workflow: {workflow_uuid}')
- project_config = workflow.project.get_config()
- party = project_config.participants[participant_id]
- client = RpcClient(project_config, party)
- resp = client.get_job_kibana(job_name, json.dumps(args))
- if resp.status.code != common_pb2.STATUS_SUCCESS:
- raise InternalException(resp.status.msg)
- metrics = json.loads(resp.metrics)
- # metrics is a list of 2-element lists,
- # each 2-element list is a [x, y] pair.
- return {'data': metrics}
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).filter_by(uuid=workflow_uuid).first()
+ if workflow is None:
+ raise NotFoundException(f'Failed to find workflow: {workflow_uuid}')
+ participant = session.query(Participant).filter_by(id=participant_id).first()
+ client = RpcClient.from_project_and_participant(workflow.project.name, workflow.project.token,
+ participant.domain_name)
+ resp = client.get_job_kibana(job_name, json.dumps(args))
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ raise InternalException(resp.status.msg)
+ metrics = json.loads(resp.metrics)
+ # metrics is a list of 2-element lists,
+ # each 2-element list is a [x, y] pair.
+ return {'data': metrics}
def initialize_job_apis(api):
api.add_resource(JobApi, '/jobs/<int:job_id>')
- api.add_resource(PodLogApi,
- '/jobs/<int:job_id>/pods/<string:pod_name>/log')
- api.add_resource(JobLogApi,
- '/jobs/<int:job_id>/log')
- api.add_resource(JobMetricsApi,
- '/jobs/<int:job_id>/metrics')
- api.add_resource(KibanaMetricsApi,
- '/jobs/<int:job_id>/kibana_metrics')
- api.add_resource(PeerJobMetricsApi,
- '/workflows/<string:workflow_uuid>/peer_workflows'
- '/<int:participant_id>/jobs/<string:job_name>/metrics')
- api.add_resource(PeerKibanaMetricsApi,
- '/workflows/<string:workflow_uuid>/peer_workflows'
- '/<int:participant_id>/jobs/<string:job_name>'
- '/kibana_metrics')
+ api.add_resource(PodLogApi, '/jobs/<int:job_id>/pods/<string:pod_name>/log')
+ api.add_resource(JobLogApi, '/jobs/<int:job_id>/log')
+ api.add_resource(JobMetricsApi, '/jobs/<int:job_id>/metrics')
+ api.add_resource(KibanaMetricsApi, '/jobs/<int:job_id>/kibana_metrics')
+ api.add_resource(
+ PeerJobMetricsApi, '/workflows/<string:workflow_uuid>/peer_workflows'
+ '/<int:participant_id>/jobs/<string:job_name>/metrics')
+ api.add_resource(
+ PeerKibanaMetricsApi, '/workflows/<string:workflow_uuid>/peer_workflows'
+ '/<int:participant_id>/jobs/<string:job_name>'
+ '/kibana_metrics')
api.add_resource(JobEventApi, '/jobs/<int:job_id>/events')
- api.add_resource(PeerJobEventsApi,
- '/workflows/<string:workflow_uuid>/peer_workflows'
- '/<int:participant_id>/jobs/<string:job_name>/events')
+ api.add_resource(
+ PeerJobEventsApi, '/workflows/<string:workflow_uuid>/peer_workflows'
+ '/<int:participant_id>/jobs/<string:job_name>/events')
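
For reference, the refactored job endpoints can be exercised over plain HTTP as sketched below; this is not part of the change. The base URL, job id and auth header are placeholders; only the paths and the `raw` query parameter come from the routes and handlers above.

```python
# Hypothetical smoke test for the job APIs; base URL, job id and token are assumptions.
import requests

BASE = 'http://localhost:8080/api/v2'          # assumed dev-server address
HEADERS = {'Authorization': 'Bearer <token>'}  # whatever credentials_required expects

# JobApi: GET /jobs/<int:job_id>
detail = requests.get(f'{BASE}/jobs/1', headers=HEADERS).json()
print(detail['data']['state'])

# JobMetricsApi: GET /jobs/<int:job_id>/metrics?raw=true returns the un-plotted metrics
metrics = requests.get(f'{BASE}/jobs/1/metrics', params={'raw': 'true'}, headers=HEADERS).json()
print(metrics['data'])
```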
diff --git a/web_console_v2/api/fedlearner_webconsole/job/apis_test.py b/web_console_v2/api/fedlearner_webconsole/job/apis_test.py
new file mode 100644
index 000000000..103290ad7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/apis_test.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.proto.job_pb2 import PodPb
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from testing.common import BaseTestCase
+
+
+class JobApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ @patch('fedlearner_webconsole.job.apis.JobService.get_pods')
+ def test_get_job(self, mock_get_pods):
+ mock_get_pods.return_value = [PodPb(name='test', pod_type='a')]
+ created_at = datetime(2021, 10, 1, 8, 8, 8, tzinfo=timezone.utc)
+ with db.session_scope() as session:
+ job = Job(id=1,
+ name='test',
+ job_type=JobType.DATA_JOIN,
+ state=JobState.COMPLETED,
+ workflow_id=1,
+ project_id=1,
+ created_at=created_at,
+ updated_at=created_at)
+ session.add(job)
+ session.commit()
+ resp = self.get_helper('/api/v2/jobs/1')
+ data = self.get_response_data(resp)
+ self.assertEqual(
+ data, {
+ 'complete_at': 0,
+ 'start_at': 0,
+ 'crd_kind': '',
+ 'crd_meta': {
+ 'api_version': ''
+ },
+ 'created_at': to_timestamp(created_at),
+ 'id': 1,
+ 'is_disabled': False,
+ 'job_type': 'DATA_JOIN',
+ 'name': 'test',
+ 'pods': [{
+ 'creation_timestamp': 0,
+ 'message': '',
+ 'name': 'test',
+ 'pod_ip': '',
+ 'pod_type': 'a',
+ 'state': ''
+ }],
+ 'project_id': 1,
+ 'snapshot': '',
+ 'state': 'COMPLETED',
+ 'updated_at': to_timestamp(created_at),
+ 'workflow_id': 1,
+ 'error_message': {
+ 'app': '',
+ 'pods': {}
+ }
+ })
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/controller.py b/web_console_v2/api/fedlearner_webconsole/job/controller.py
new file mode 100644
index 000000000..da8c646d9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/controller.py
@@ -0,0 +1,148 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple, Optional
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.job.service import JobService
+from fedlearner_webconsole.job.yaml_formatter import YamlFormatterService
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+from fedlearner_webconsole.job.utils import DurationState, emit_job_duration_store
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.workflow import build_job_name
+
+
+def _are_peers_ready(session: Session, project: Project, job_name: str) -> bool:
+ service = ParticipantService(session)
+ participants = service.get_platform_participants_by_project(project.id)
+ for participant in participants:
+ client = RpcClient.from_project_and_participant(project.name, project.token, participant.domain_name)
+ resp = client.check_job_ready(job_name)
+ # Fallback solution: we think peer is ready if rpc fails
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ emit_store('job.controller.check_peer_ready_failed', 1)
+ continue
+ if not resp.is_ready:
+ return False
+ return True
+
+
+def schedule_job(unused_session: Session, job: Job):
+ del unused_session
+ if job.is_disabled:
+ # No action
+ return
+ # COMPLETED/FAILED Job State can be scheduled since stop action will
+ # not change the state of completed or failed job
+ assert job.state in [JobState.NEW, JobState.STOPPED, JobState.COMPLETED, JobState.FAILED]
+ job.snapshot = None
+ # Marks the job to be scheduled
+ job.state = JobState.WAITING
+ job.error_message = None
+
+
+def start_job_if_ready(session: Session, job: Job) -> Tuple[bool, Optional[str]]:
+ """Schedules a job for execution.
+
+ Returns:
+ Job readiness and the related message.
+ """
+ if job.state != JobState.WAITING:
+ return False, f'Invalid job state: {job.id} {job.state}'
+
+ # Checks readiness locally
+ if not JobService(session).is_ready(job):
+ return False, None
+ config = job.get_config()
+ if config.is_federated:
+ # Checks peers' readiness for federated job
+ if not _are_peers_ready(session, job.project, job.name):
+ return False, None
+
+ _start_job(session, job)
+ return True, job.error_message
+
+
+def _start_job(session: Session, job: Job):
+ """Starts a job locally."""
+ try:
+ assert job.state == JobState.WAITING, 'Job state should be WAITING'
+ # Builds yaml by template and submits it to k8s
+ yaml = YamlFormatterService(session).generate_job_run_yaml(job)
+ job.build_crd_service().create_app(yaml)
+ # Updates job status if submitting successfully
+ job.state = JobState.STARTED
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'Start job {job.id} has error msg: {e.args}')
+ job.error_message = str(e)
+
+
+def stop_job(unused_session: Session, job: Job):
+ del unused_session # Unused for now, this argument is to let invoker commit after this function
+ if job.state not in [JobState.WAITING, JobState.STARTED, JobState.COMPLETED, JobState.FAILED]:
+ logging.warning('illegal job state, name: %s, state: %s', job.name, job.state)
+ return
+ # state change:
+ # WAITING -> NEW
+ # STARTED -> STOPPED
+ # COMPLETED/FAILED unchanged
+ if job.state == JobState.STARTED:
+ JobService.set_status_to_snapshot(job)
+ job.build_crd_service().delete_app()
+ job.state = JobState.STOPPED
+ emit_job_duration_store(to_timestamp(now()) - to_timestamp(job.created_at),
+ job_name=job.name,
+ state=DurationState.STOPPED)
+ if job.state == JobState.WAITING:
+ # This transition makes sure stopping has no effect on waiting jobs
+ job.state = JobState.NEW
+
+
+def create_job_without_workflow(session: Session,
+ job_def: JobDefinition,
+ project_id: int,
+ name: Optional[str] = None,
+ uuid: Optional[str] = None) -> Optional[Job]:
+ """Create a job without workflow.
+ Args:
+ session: db session, must be committed after this function return.
+ job_def: JobDefinition. job_def.yaml_template should not use any variables of workflow.
+ project_id: int indicate a project.
+ name: the unique name of the job overriding the default name {uuid}-{job_def.name}
+ uuid: {uuid}-{job_def.name} will be the unique name of the job. When job_def.is_federated is True,
+ participants in the project must have a job with the same name.
+ Returns:
+ Optional[Job]
+ """
+ if name is None:
+ if uuid is None:
+ uuid = resource_uuid()
+ name = build_job_name(uuid, job_def.name)
+ job = session.query(Job).filter_by(name=name).first()
+ if job is not None:
+ return None
+ job = Job(name=name, job_type=JobType(job_def.job_type), workflow_id=0, project_id=project_id, state=JobState.NEW)
+ JobService.set_config_and_crd_info(job, job_def)
+ session.add(job)
+ return job
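
A minimal sketch of how the new controller helpers are meant to be driven (the scheduler and workflow controller follow this pattern in the real code paths); the job id and commit points below are illustrative, not taken from the patch.

```python
# Illustrative driver for the job lifecycle helpers above; job_id is a made-up example.
from fedlearner_webconsole.db import db
from fedlearner_webconsole.job.controller import schedule_job, start_job_if_ready, stop_job
from fedlearner_webconsole.job.models import Job


def run_job_once(job_id: int):
    with db.session_scope() as session:
        job = session.query(Job).get(job_id)
        schedule_job(session, job)            # NEW/STOPPED/COMPLETED/FAILED -> WAITING
        started, message = start_job_if_ready(session, job)
        if not started and message:
            print(f'job {job_id} not started: {message}')
        session.commit()


def stop_job_once(job_id: int):
    with db.session_scope() as session:
        job = session.query(Job).get(job_id)
        stop_job(session, job)                # STARTED -> STOPPED, WAITING -> NEW
        session.commit()
```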
diff --git a/web_console_v2/api/fedlearner_webconsole/job/controller_test.py b/web_console_v2/api/fedlearner_webconsole/job/controller_test.py
new file mode 100644
index 000000000..5c3d94573
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/controller_test.py
@@ -0,0 +1,275 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch, Mock, call
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.controller import _start_job, schedule_job, stop_job, _are_peers_ready, \
+ start_job_if_ready, create_job_without_workflow
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.job.yaml_formatter import YamlFormatterService
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.service_pb2 import CheckJobReadyResponse
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.utils.const import DEFAULT_OWNER_FOR_JOB_WITHOUT_WORKFLOW
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.workflow.models import Workflow # pylint: disable=unused-import
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ScheduleJobTest(NoWebServerTestCase):
+
+ def test_schedule_job_disabled(self):
+ with db.session_scope() as session:
+ job = Job(id=1, is_disabled=True, state=JobState.NEW)
+ schedule_job(session, job)
+ # No change
+ self.assertEqual(job.state, JobState.NEW)
+
+ def test_schedule_job_invalid_state(self):
+ with db.session_scope() as session:
+ job = Job(id=1, state=JobState.STARTED)
+ self.assertRaises(AssertionError, lambda: schedule_job(session, job))
+
+ def test_schedule_job_successfully(self):
+ with db.session_scope() as session:
+ job = Job(id=1, state=JobState.NEW, snapshot='test snapshot')
+ job.set_config(JobDefinition())
+ schedule_job(session, job)
+ self.assertIsNone(job.snapshot)
+ self.assertEqual(job.state, JobState.WAITING)
+
+ @patch('fedlearner_webconsole.job.controller.RpcClient.from_project_and_participant')
+ def test_are_peers_ready(self, mock_rpc_client_factory: Mock):
+ project_id = 1
+ with db.session_scope() as session:
+ participant_1 = Participant(id=1, name='participant 1', domain_name='p1.fedlearner.net')
+ participant_2 = Participant(id=2, name='participant 2', domain_name='p2.fedlearner.net')
+ project = Project(id=project_id, name='project 1')
+ session.add_all([
+ participant_1, participant_2, project,
+ ProjectParticipant(project_id=1, participant_id=1),
+ ProjectParticipant(project_id=1, participant_id=2)
+ ])
+ session.commit()
+
+ mock_check_job_ready = MagicMock()
+ mock_rpc_client_factory.return_value = MagicMock(check_job_ready=mock_check_job_ready)
+
+ job_name = 'fake_job_name'
+ with db.session_scope() as session:
+ project = session.query(Project).get(project_id)
+ # gRPC error
+ mock_check_job_ready.side_effect = [
+ CheckJobReadyResponse(status=common_pb2.Status(code=common_pb2.STATUS_UNKNOWN_ERROR)),
+ CheckJobReadyResponse(is_ready=True)
+ ]
+ self.assertTrue(_are_peers_ready(session, project, job_name))
+ mock_check_job_ready.assert_has_calls([call(job_name), call(job_name)])
+ # Not ready
+ mock_check_job_ready.side_effect = [
+ CheckJobReadyResponse(is_ready=False),
+ CheckJobReadyResponse(is_ready=True)
+ ]
+ self.assertFalse(_are_peers_ready(session, project, job_name))
+ # Ready
+ mock_check_job_ready.side_effect = [
+ CheckJobReadyResponse(is_ready=True),
+ CheckJobReadyResponse(is_ready=True)
+ ]
+ self.assertTrue(_are_peers_ready(session, project, job_name))
+
+
+class StartJobTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.job.controller.YamlFormatterService', spec=YamlFormatterService)
+ @patch('fedlearner_webconsole.job.controller.Job.build_crd_service')
+ def test_start_job_successfully(self, mock_crd_service, mock_formatter_class):
+ mock_formatter = mock_formatter_class.return_value
+ mock_formatter.generate_job_run_yaml.return_value = 'fake job yaml'
+ mock_crd_service.return_value = MagicMock(create_app=MagicMock(return_value=None))
+ with db.session_scope() as session:
+ job = Job(id=123,
+ name='test job',
+ job_type=JobType.RAW_DATA,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ session.add(job)
+ session.commit()
+ _start_job(session, job)
+ session.commit()
+ # Checks result
+ mock_crd_service.return_value.create_app.assert_called_with('fake job yaml')
+ with db.session_scope() as session:
+ job = session.query(Job).get(123)
+ self.assertEqual(job.state, JobState.STARTED)
+ self.assertIsNone(job.error_message)
+
+ @patch('fedlearner_webconsole.job.controller.YamlFormatterService', spec=YamlFormatterService)
+ @patch('fedlearner_webconsole.job.controller.Job.build_crd_service')
+ def test_start_job_exception(self, mock_crd_service, mock_formatter_class):
+ mock_formatter = mock_formatter_class.return_value
+ mock_formatter.generate_job_run_yaml.return_value = 'fake job yaml'
+ mock_crd_service.return_value = MagicMock(create_app=MagicMock(return_value=None))
+ mock_crd_service.return_value.create_app.side_effect = RuntimeError('some errors in k8s')
+ with db.session_scope() as session:
+ job = Job(id=123,
+ name='test job',
+ job_type=JobType.RAW_DATA,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ session.add(job)
+ session.commit()
+ _start_job(session, job)
+ session.commit()
+ # Checks result
+ mock_crd_service.return_value.create_app.assert_called_with('fake job yaml')
+ with db.session_scope() as session:
+ job = session.query(Job).get(123)
+ self.assertEqual(job.state, JobState.WAITING)
+ self.assertEqual(job.error_message, 'some errors in k8s')
+
+
+class StopJobTest(NoWebServerTestCase):
+
+ def test_stop_job_invalid_state(self):
+ with db.session_scope() as session:
+ job = Job(id=1, state=JobState.NEW)
+ stop_job(session, job)
+ # No change
+ self.assertEqual(job.state, JobState.NEW)
+
+ @patch('fedlearner_webconsole.job.controller.Job.build_crd_service')
+ @patch('fedlearner_webconsole.job.controller.JobService.set_status_to_snapshot')
+ def test_stop_job_started(self, mock_set_status_to_snapshot: Mock, mock_build_crd_service: Mock):
+ mock_delete_app = MagicMock()
+ mock_build_crd_service.return_value = MagicMock(delete_app=mock_delete_app)
+
+ with db.session_scope() as session:
+ job = Job(id=1, name='test-job', state=JobState.STARTED, created_at=now())
+ stop_job(session, job)
+ mock_set_status_to_snapshot.assert_called_once_with(job)
+ mock_delete_app.assert_called_once()
+ self.assertEqual(job.state, JobState.STOPPED)
+
+ def test_stop_job_waiting(self):
+ with db.session_scope() as session:
+ job = Job(id=1, name='test-job', state=JobState.WAITING, created_at=now())
+ stop_job(session, job)
+ self.assertEqual(job.state, JobState.NEW)
+
+ def test_stop_job_completed(self):
+ with db.session_scope() as session:
+ job = Job(id=1, name='test-job', state=JobState.COMPLETED, created_at=now())
+ stop_job(session, job)
+ # No change
+ self.assertEqual(job.state, JobState.COMPLETED)
+
+ @patch('fedlearner_webconsole.job.controller._start_job')
+ @patch('fedlearner_webconsole.job.controller._are_peers_ready')
+ @patch('fedlearner_webconsole.job.controller.JobService.is_ready')
+ def test_start_job_if_ready(self, mock_is_ready: Mock, mock_are_peers_ready: Mock, mock_start_job: Mock):
+ with db.session_scope() as session:
+ not_ready_job = Job(id=2,
+ name='not_ready_job',
+ job_type=JobType.RAW_DATA,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ mock_is_ready.return_value = False
+ res = start_job_if_ready(session, not_ready_job)
+ self.assertEqual(res, (False, None))
+ peers_not_ready_job = Job(id=3,
+ name='peers_not_ready_job',
+ job_type=JobType.PSI_DATA_JOIN,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ mock_is_ready.return_value = True
+ mock_are_peers_ready.return_value = False
+ peers_not_ready_job.set_config(JobDefinition(is_federated=True))
+ res = start_job_if_ready(session, peers_not_ready_job)
+ self.assertEqual(res, (False, None))
+ peers_ready_job = Job(id=4,
+ name='peers_ready_job',
+ job_type=JobType.PSI_DATA_JOIN,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ mock_are_peers_ready.return_value = True
+ peers_ready_job.set_config(JobDefinition(is_federated=True))
+ res = start_job_if_ready(session, peers_ready_job)
+ self.assertEqual(res, (True, None))
+ running_job = Job(id=3002,
+ name='running_job',
+ job_type=JobType.RAW_DATA,
+ state=JobState.STARTED,
+ workflow_id=1,
+ project_id=1)
+ running_job.set_config(JobDefinition(is_federated=True))
+ res = start_job_if_ready(session, running_job)
+ self.assertEqual(res, (False, 'Invalid job state: 3002 JobState.STARTED'))
+ start_job_calls = [call[0][1].id for call in mock_start_job.call_args_list]
+ self.assertCountEqual(start_job_calls, [peers_ready_job.id])
+
+ def test_create_job_without_workflow(self):
+ with db.session_scope() as session:
+ project = Project(id=1, name='project 1')
+ session.add(project)
+ job_def = JobDefinition(name='lonely_job', job_type=JobDefinition.ANALYZER)
+ job_def.yaml_template = """
+ {
+ "apiVersion": "sparkoperator.k8s.io/v1beta2",
+ "kind": "SparkApplication",
+ "metadata": {
+ "name": self.name,
+ },
+ }
+ """
+ job = create_job_without_workflow(
+ session,
+ job_def=job_def,
+ project_id=1,
+ )
+ session.commit()
+ self.assertEqual(job.crd_kind, 'SparkApplication')
+ yaml = YamlFormatterService(session).generate_job_run_yaml(job)
+ self.assertEqual(yaml['metadata']['labels']['owner'], DEFAULT_OWNER_FOR_JOB_WITHOUT_WORKFLOW)
+ with db.session_scope() as session:
+ job_def.yaml_template = """
+ {
+ "apiVersion": "sparkoperator.k8s.io/v1beta2",
+ "kind": "SparkApplication",
+ "metadata": {
+ "name": self.name,
+ "namespace": workflow.name
+ },
+
+ }
+ """
+ job = create_job_without_workflow(session, job_def=job_def, project_id=1)
+ session.commit()
+ self.assertEqual(job.crd_kind, 'SparkApplication')
+ with self.assertRaisesRegex(ValueError, 'Invalid python dict placeholder error msg: workflow.name'):
+ YamlFormatterService(session).generate_job_run_yaml(job)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/crd.py b/web_console_v2/api/fedlearner_webconsole/job/crd.py
new file mode 100644
index 000000000..e0e4837e7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/crd.py
@@ -0,0 +1,64 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Optional
+
+from fedlearner_webconsole.k8s.k8s_cache import k8s_cache
+from fedlearner_webconsole.k8s.models import CrdKind, SparkApp, FlApp, FedApp, UnknownCrd
+from fedlearner_webconsole.k8s.k8s_client import k8s_client
+from fedlearner_webconsole.utils.metrics import emit_store
+
+CRD_CLASS_MAP = {
+ CrdKind.FLAPP: FlApp,
+ CrdKind.SPARKAPPLICATION: SparkApp,
+ CrdKind.FEDAPP: FedApp,
+ CrdKind.UNKNOWN: UnknownCrd
+}
+
+
+class CrdService(object):
+
+ def __init__(self, kind: str, api_version: str, app_name: str):
+ self.kind = kind
+ # Only known (non-UNKNOWN) CRD kinds support completed/failed states and pod details;
+ # UNKNOWN kinds only support create and delete.
+ self.supported_kind = CrdKind.from_value(kind)
+ self.api_version = api_version
+ self.plural = f'{kind.lower()}s'
+ self.group, _, self.version = api_version.partition('/')
+ self.app_name = app_name
+
+ def get_k8s_app(self, snapshot: Optional[dict]):
+ if snapshot is None:
+ snapshot = self.get_k8s_app_cache()
+ return CRD_CLASS_MAP[self.supported_kind].from_json(snapshot)
+
+ def delete_app(self):
+ emit_store('job.crd_service.deletion', value=1, tags={'name': self.app_name, 'plural': self.plural})
+ k8s_client.delete_app(self.app_name, self.group, self.version, self.plural)
+
+ def create_app(self, yaml: dict):
+ emit_store('job.crd_service.submission', value=1, tags={'name': self.app_name, 'plural': self.plural})
+ k8s_client.create_app(yaml, self.group, self.version, self.plural)
+
+ def get_k8s_app_cache(self):
+ if self.supported_kind == CrdKind.UNKNOWN:
+ try:
+ return {'app': k8s_client.get_custom_object(self.app_name, self.group, self.version, self.plural)}
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'Get app detail failed: {str(e)}')
+ return {'app': {}}
+ return k8s_cache.get_cache(self.app_name)
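
To make the naming scheme concrete, a small illustration of `CrdService` follows (the values are made up; in the patch, jobs normally obtain an instance via `Job.build_crd_service()`).

```python
# Illustration only; kind/api_version/app_name are example values.
from fedlearner_webconsole.job.crd import CrdService

service = CrdService(kind='SparkApplication',
                     api_version='sparkoperator.k8s.io/v1beta2',
                     app_name='u1234-analyzer')
# group/version/plural are derived from the constructor arguments:
assert (service.group, service.version, service.plural) == \
       ('sparkoperator.k8s.io', 'v1beta2', 'sparkapplications')

# With a cluster available one could then submit, inspect and delete the CR:
# service.create_app(yaml_dict)
# app = service.get_k8s_app(snapshot=None)   # falls back to get_k8s_app_cache()
# service.delete_app()
```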
diff --git a/web_console_v2/api/fedlearner_webconsole/job/event_listener.py b/web_console_v2/api/fedlearner_webconsole/job/event_listener.py
new file mode 100644
index 000000000..87e424404
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/event_listener.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.models import Job, JobState
+from fedlearner_webconsole.job.service import JobService
+from fedlearner_webconsole.k8s.event_listener import EventListener
+from fedlearner_webconsole.k8s.k8s_cache import Event, ObjectType
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.workflow.service import WorkflowService
+from fedlearner_webconsole.workflow.workflow_job_controller import stop_workflow
+
+
+class JobEventListener(EventListener):
+
+ def update(self, event: Event):
+ # TODO(xiangyuxuan.prs): recompose the JobEventListener
+ valid_obj_type = [ObjectType.FLAPP, ObjectType.SPARKAPP, ObjectType.FEDAPP]
+ if event.obj_type not in valid_obj_type:
+ return
+ logging.debug('[k8s_watcher][job_event_listener]receive event %s', event.app_name)
+
+ with db.session_scope() as session:
+ job = session.query(Job).filter_by(name=event.app_name).first()
+ if job is None:
+ return
+ old_state = job.state
+ result_state = JobService(session).update_running_state(event.app_name)
+ wid = job.workflow_id
+ session.commit()
+
+ # trigger workflow state change
+ if old_state != result_state and result_state in [JobState.COMPLETED, JobState.FAILED]:
+ with db.session_scope() as session:
+ w = session.query(Workflow).get(wid)
+ logging.info(f'[JobEventListener] {w.uuid} should be stopped.')
+ if WorkflowService(session).should_auto_stop(w):
+ stop_workflow(wid)
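
The listener only reads `event.obj_type` and `event.app_name`, so it can be exercised with a stand-in object; the snippet below is a hypothetical driver and not how `k8s_watcher` actually constructs `Event` instances.

```python
# Hypothetical driver; a SimpleNamespace stands in for the real Event object.
from types import SimpleNamespace

from fedlearner_webconsole.job.event_listener import JobEventListener
from fedlearner_webconsole.k8s.k8s_cache import ObjectType

fake_event = SimpleNamespace(obj_type=ObjectType.FLAPP, app_name='u1234-raw-data-job')
# Updates the matching Job's running state and, if the job just finished,
# may stop the owning workflow; effectively a no-op when no Job row matches app_name.
JobEventListener().update(fake_event)
```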
diff --git a/web_console_v2/api/fedlearner_webconsole/job/metrics.py b/web_console_v2/api/fedlearner_webconsole/job/metrics.py
index cda672b54..c9a8d1565 100644
--- a/web_console_v2/api/fedlearner_webconsole/job/metrics.py
+++ b/web_console_v2/api/fedlearner_webconsole/job/metrics.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,43 +13,60 @@
# limitations under the License.
# coding: utf-8
-from datetime import datetime
-
import mpld3
+from datetime import datetime
+from typing import List
from matplotlib.figure import Figure
-
-from fedlearner_webconsole.job.models import JobType
from fedlearner_webconsole.utils.es import es
+from fedlearner_webconsole.job.models import Job, JobType
+from fedlearner_webconsole.utils.job_metrics import get_feature_importance
+from fedlearner_webconsole.proto.metrics_pb2 import ModelJobMetrics, Metric
+
+_CONF_METRIC_LIST = ['tp', 'tn', 'fp', 'fn']
+_TREE_METRIC_LIST = ['acc', 'auc', 'precision', 'recall', 'f1', 'ks', 'mse', 'msre', 'abs'] + _CONF_METRIC_LIST
+_NN_METRIC_LIST = ['acc', 'auc', 'loss', 'mse', 'abs']
class JobMetricsBuilder(object):
- def __init__(self, job):
+
+ def __init__(self, job: Job):
self._job = job
def _to_datetime(self, timestamp):
if timestamp is None:
return None
- return datetime.fromtimestamp(timestamp/1000.0)
+ return datetime.fromtimestamp(timestamp / 1000.0)
+
+ def _is_nn_job(self):
+ return self._job.job_type in [JobType.NN_MODEL_TRANINING, JobType.NN_MODEL_EVALUATION]
+
+ def _is_tree_job(self):
+ return self._job.job_type in [JobType.TREE_MODEL_TRAINING, JobType.TREE_MODEL_EVALUATION]
+
+ def query_metrics(self):
+ if self._is_tree_job():
+ return self.query_tree_metrics(need_feature_importance=True)
+ if self._is_nn_job():
+ return self.query_nn_metrics()
+ return []
def plot_metrics(self, num_buckets=30):
+ figs = []
if self._job.job_type == JobType.DATA_JOIN:
- metrics = self.plot_data_join_metrics(num_buckets)
- elif self._job.job_type in [
- JobType.NN_MODEL_TRANINING, JobType.NN_MODEL_EVALUATION]:
- metrics = self.plot_nn_metrics(num_buckets)
- elif self._job.job_type in [JobType.TREE_MODEL_TRAINING,
- JobType.TREE_MODEL_EVALUATION]:
- metrics = self.plot_tree_metrics()
+ figs = self.plot_data_join_metrics(num_buckets)
+ elif self._is_nn_job():
+ metrics = self.query_nn_metrics(num_buckets)
+ figs = self.plot_nn_metrics(metrics)
+ elif self._is_tree_job():
+ metrics = self.query_tree_metrics(False)
+ figs = self.plot_tree_metrics(metrics)
elif self._job.job_type == JobType.RAW_DATA:
- metrics = self.plot_raw_data_metrics(num_buckets)
- else:
- metrics = []
- return metrics
+ figs = self.plot_raw_data_metrics(num_buckets)
+ return figs
def plot_data_join_metrics(self, num_buckets=30):
res = es.query_data_join_metrics(self._job.name, num_buckets)
- time_res = es.query_time_metrics(self._job.name, num_buckets,
- index='data_join*')
+ time_res = es.query_time_metrics(self._job.name, num_buckets, index='data_join*')
metrics = []
if not res['aggregations']['OVERALL']['buckets']:
return metrics
@@ -57,9 +74,7 @@ def plot_data_join_metrics(self, num_buckets=30):
# plot pie chart for overall join rate
overall = res['aggregations']['OVERALL']['buckets'][0]
labels = ['joined', 'fake', 'unjoined']
- sizes = [
- overall['JOINED']['doc_count'], overall['FAKE']['doc_count'],
- overall['UNJOINED']['doc_count']]
+ sizes = [overall['JOINED']['doc_count'], overall['FAKE']['doc_count'], overall['UNJOINED']['doc_count']]
fig = Figure()
ax = fig.add_subplot(111)
ax.pie(sizes, labels=labels, autopct='%1.1f%%')
@@ -73,16 +88,14 @@ def plot_data_join_metrics(self, num_buckets=30):
et_unjoined = [buck['UNJOINED']['doc_count'] for buck in by_et]
fig = Figure()
ax = fig.add_subplot(111)
- ax.stackplot(
- et_index, et_joined, et_faked, et_unjoined, labels=labels)
+ ax.stackplot(et_index, et_joined, et_faked, et_unjoined, labels=labels)
twin_ax = ax.twinx()
twin_ax.patch.set_alpha(0.0)
et_rate = [buck['JOIN_RATE']['value'] for buck in by_et]
et_rate_fake = [buck['JOIN_RATE_WITH_FAKE']['value'] for buck in by_et]
twin_ax.plot(et_index, et_rate, label='join rate', color='black')
- twin_ax.plot(et_index, et_rate_fake,
- label='join rate w/ fake', color='#8f8f8f') # grey color
+ twin_ax.plot(et_index, et_rate_fake, label='join rate w/ fake', color='#8f8f8f') # grey color
ax.xaxis_date()
ax.legend()
@@ -94,53 +107,123 @@ def plot_data_join_metrics(self, num_buckets=30):
return metrics
- def plot_nn_metrics(self, num_buckets=30):
- res = es.query_nn_metrics(self._job.name, num_buckets)
- metrics = []
- if not res['aggregations']['PROCESS_TIME']['buckets']:
- return metrics
-
- buckets = res['aggregations']['PROCESS_TIME']['buckets']
- time = [self._to_datetime(buck['key']) for buck in buckets]
-
- # plot auc curve
- auc = [buck['AUC']['value'] for buck in buckets]
- fig = Figure()
- ax = fig.add_subplot(111)
- ax.plot(time, auc, label='auc')
- ax.legend()
- metrics.append(mpld3.fig_to_dict(fig))
-
+ def query_nn_metrics(self, num_buckets: int = 30) -> ModelJobMetrics:
+ res = es.query_nn_metrics(job_name=self._job.name, metric_list=_NN_METRIC_LIST, num_buckets=num_buckets)
+ metrics = ModelJobMetrics()
+ aggregations = res['aggregations']
+ for metric in _NN_METRIC_LIST:
+ buckets = aggregations[metric]['PROCESS_TIME']['buckets']
+ if len(buckets) == 0:
+ continue
+ times = [buck['key'] for buck in buckets]
+ values = [buck['VALUE']['value'] for buck in buckets]
+ # filter out None values in times and values
+ time_values = [(t, v) for t, v in zip(times, values) if t is not None and v is not None]
+ if len(time_values) == 0:
+ continue
+ times, values = zip(*time_values)
+ metrics.train[metric].steps.extend(times)
+ metrics.train[metric].values.extend(values)
+ metrics.eval[metric].steps.extend(times)
+ metrics.eval[metric].values.extend(values)
return metrics
- def plot_tree_metrics(self):
- metric_list = ['acc', 'auc', 'precision', 'recall',
- 'f1', 'ks', 'mse', 'msre', 'abs']
- metrics = []
- aggregations = es.query_tree_metrics(self._job.name, metric_list)
- for name in metric_list:
+ def plot_nn_metrics(self, metrics: ModelJobMetrics):
+ figs = []
+ for name in metrics.train:
+ fig = Figure()
+ ax = fig.add_subplot(111)
+ timestamp = [self._to_datetime(t) for t in metrics.train[name].steps]
+ values = metrics.train[name].values
+ ax.plot(timestamp, values, label=name)
+ ax.legend()
+ figs.append(mpld3.fig_to_dict(fig))
+ return figs
+
+ @staticmethod
+ def _average_value_by_iteration(metrics: [List[int], List[int]]) -> [List[int], List[int]]:
+ iter_to_value = {}
+ for iteration, value in zip(*metrics):
+ if iteration not in iter_to_value:
+ iter_to_value[iteration] = []
+ iter_to_value[iteration].append(value)
+ iterations = []
+ values = []
+ for key, value_list in iter_to_value.items():
+ iterations.append(key)
+ values.append(sum(value_list) / len(value_list))
+ return [iterations, values]
+
+ def _get_iter_val(self, records: dict) -> Metric:
+ iterations = [item['_source']['tags']['iteration'] for item in records]
+ values = [item['_source']['value'] for item in records]
+ iterations, values = self._average_value_by_iteration([iterations, values])
+ return Metric(steps=iterations, values=values)
+
+ @staticmethod
+ def _set_confusion_metric(metrics: ModelJobMetrics):
+
+ def _is_training() -> bool:
+ iter_vals = metrics.train.get('tp')
+ if iter_vals is not None and len(iter_vals.values) > 0:
+ return True
+ return False
+
+ def _get_last_values(name: str, is_training: bool) -> int:
+ if is_training:
+ iter_vals = metrics.train.get(name)
+ else:
+ iter_vals = metrics.eval.get(name)
+ if iter_vals is not None and len(iter_vals.values) > 0:
+ return int(iter_vals.values[-1])
+ return 0
+
+ _is_training = _is_training()
+ metrics.confusion_matrix.tp = _get_last_values('tp', _is_training)
+ metrics.confusion_matrix.tn = _get_last_values('tn', _is_training)
+ metrics.confusion_matrix.fp = _get_last_values('fp', _is_training)
+ metrics.confusion_matrix.fn = _get_last_values('fn', _is_training)
+ # remove confusion relevant metrics from train metrics
+ for key in _CONF_METRIC_LIST:
+ metrics.train.pop(key)
+ metrics.eval.pop(key)
+
+ def query_tree_metrics(self, need_feature_importance=False) -> ModelJobMetrics:
+ job_name = self._job.name
+ aggregations = es.query_tree_metrics(job_name, _TREE_METRIC_LIST)['aggregations']
+ metrics = ModelJobMetrics()
+ for name in _TREE_METRIC_LIST:
train_ = aggregations[name.upper()]['TRAIN']['TOP']['hits']['hits']
eval_ = aggregations[name.upper()]['EVAL']['TOP']['hits']['hits']
- if len(train_) == 0 and len(eval_) == 0:
+ if len(train_) > 0:
+ metrics.train[name].MergeFrom(self._get_iter_val(train_))
+ if len(eval_) > 0:
+ metrics.eval[name].MergeFrom(self._get_iter_val(eval_))
+ self._set_confusion_metric(metrics)
+ if need_feature_importance:
+ metrics.feature_importance.update(get_feature_importance(self._job))
+ return metrics
+
+ def plot_tree_metrics(self, metrics: ModelJobMetrics):
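+ # Renders one mpld3 figure per metric, drawing the train curve in blue and the eval
+ # curve in red when the corresponding series exists.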
+ metric_list = set.union(set(metrics.train.keys()), set(metrics.eval.keys()))
+ figs = []
+ for name in metric_list:
+ train_metric = metrics.train.get(name)
+ eval_metric = metrics.eval.get(name)
+ if train_metric is None and eval_metric is None:
continue
fig = Figure()
ax = fig.add_subplot(111)
- if len(train_) > 0:
- train_metric = [(item['_source']['tags']['iteration'],
- item['_source']['value'])
- for item in train_]
- ax.plot(*zip(*train_metric), label='train', color='blue')
- if len(eval_) > 0:
- eval_metric = [(item['_source']['tags']['iteration'],
- item['_source']['value'])
- for item in eval_]
- ax.plot(*zip(*eval_metric), label='eval', color='red')
+ if train_metric is not None:
+ ax.plot(train_metric.steps, train_metric.values, label='train', color='blue')
+ if eval_metric is not None:
+ ax.plot(eval_metric.steps, eval_metric.values, label='eval', color='red')
ax.legend()
ax.set_title(name)
ax.set_xlabel('iteration')
ax.set_ylabel('value')
- metrics.append(mpld3.fig_to_dict(fig))
- return metrics
+ figs.append(mpld3.fig_to_dict(fig))
+ return figs
def plot_raw_data_metrics(self, num_buckets=30):
res = es.query_time_metrics(self._job.name, num_buckets)
@@ -158,19 +241,9 @@ def _plot_pt_vs_et(self, res):
for buck in by_pt]
fig = Figure()
ax = fig.add_subplot(111)
- pt_index = [
- idx for idx, time in zip(pt_index, pt_min) if time is not None
- ]
- ax.plot(
- pt_index,
- list(filter(lambda x: x is not None, pt_min)),
- label='min event time'
- )
- ax.plot(
- pt_index,
- list(filter(lambda x: x is not None, pt_max)),
- label='max event time'
- )
+ pt_index = [idx for idx, time in zip(pt_index, pt_min) if time is not None]
+ ax.plot(pt_index, list(filter(lambda x: x is not None, pt_min)), label='min event time')
+ ax.plot(pt_index, list(filter(lambda x: x is not None, pt_max)), label='max event time')
ax.xaxis_date()
ax.yaxis_date()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/metrics_test.py b/web_console_v2/api/fedlearner_webconsole/job/metrics_test.py
new file mode 100644
index 000000000..b446b008e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/metrics_test.py
@@ -0,0 +1,279 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import time
+import unittest
+from unittest.mock import patch
+from http import HTTPStatus
+
+from testing.common import BaseTestCase, TestAppProcess
+from testing.test_data import es_query_result
+from fedlearner_webconsole.proto import workflow_definition_pb2
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.job.models import Job, JobType
+from fedlearner_webconsole.job.metrics import JobMetricsBuilder
+from fedlearner_webconsole.utils.proto import to_dict
+
+
+@unittest.skip('require es client')
+class SkippedJobMetricsBuilderTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ ES_HOST = ''
+ ES_PORT = 80
+
+ class FollowerConfig(Config):
+ GRPC_LISTEN_PORT = 4990
+
+ def test_data_join_metrics(self):
+ job = Job(name='multi-indices-test27', job_type=JobType.DATA_JOIN)
+ import json # pylint: disable=import-outside-toplevel
+ print(json.dumps(JobMetricsBuilder(job).plot_metrics()))
+
+ def test_nn_metrics(self):
+ job = Job(name='automl-2782410011', job_type=JobType.NN_MODEL_TRANINING)
+ print(JobMetricsBuilder(job).plot_metrics())
+
+ def test_peer_metrics(self):
+ proc = TestAppProcess(SkippedJobMetricsBuilderTest, 'follower_test_peer_metrics', SkippedJobMetricsBuilderTest.FollowerConfig)
+ proc.start()
+ self.leader_test_peer_metrics()
+ proc.terminate()
+
+ def leader_test_peer_metrics(self):
+ self.setup_project('leader', SkippedJobMetricsBuilderTest.FollowerConfig.GRPC_LISTEN_PORT)
+ workflow = Workflow(name='test-workflow', project_id=1)
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.commit()
+
+ while True:
+ resp = self.get_helper('/api/v2/workflows/1/peer_workflows/0/jobs/test-job/metrics')
+ if resp.status_code == HTTPStatus.OK:
+ break
+ time.sleep(1)
+
+ def follower_test_peer_metrics(self):
+ self.setup_project('follower', SkippedJobMetricsBuilderTest.Config.GRPC_LISTEN_PORT)
+ with db.session_scope() as session:
+ workflow = Workflow(name='test-workflow', project_id=1, metric_is_public=True)
+ workflow.set_job_ids([1])
+ session.add(workflow)
+ job = Job(name='automl-2782410011',
+ job_type=JobType.NN_MODEL_TRANINING,
+ workflow_id=1,
+ project_id=1,
+ config=workflow_definition_pb2.JobDefinition(name='test-job').SerializeToString())
+ session.add(job)
+ session.commit()
+
+ while True:
+ time.sleep(1)
+
+
+_EXPECTED_TREE_METRICS_RESULT = {
+ 'train': {
+ 'ks': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.47770564314760644, 0.5349813321918623, 0.5469192171410906, 0.5596894247461416, 0.5992009702504102,
+ 0.6175715202967825, 0.6366317091151221, 0.6989964566835509, 0.7088535349932226, 0.7418848541057288
+ ]
+ },
+ 'recall': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.40186915887850466, 0.4252336448598131, 0.45794392523364486, 0.46261682242990654, 0.5233644859813084,
+ 0.514018691588785, 0.5093457943925234, 0.5373831775700935, 0.5467289719626168, 0.5654205607476636
+ ]
+ },
+ 'acc': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [0.857, 0.862, 0.868, 0.872, 0.886, 0.883, 0.884, 0.895, 0.896, 0.902]
+ },
+ 'auc': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.8011640626857863, 0.8377684240565029, 0.8533328577203871, 0.860663242253454, 0.8797977455946351,
+ 0.8921428741290338, 0.9041610187629308, 0.9179270409740553, 0.928827495184419, 0.9439282062257736
+ ]
+ },
+ 'precision': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.8514851485148515, 0.8584905660377359, 0.8596491228070176, 0.8839285714285714, 0.9032258064516129,
+ 0.8943089430894309, 0.9083333333333333, 0.9504132231404959, 0.9435483870967742, 0.9603174603174603
+ ]
+ },
+ 'f1': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.546031746031746, 0.56875, 0.5975609756097561, 0.607361963190184, 0.6627218934911242,
+ 0.6528189910979227, 0.6526946107784432, 0.6865671641791044, 0.6923076923076923, 0.711764705882353
+ ]
+ }
+ },
+ 'confusion_matrix': {
+ 'tp': 121,
+ 'tn': 781,
+ 'fp': 5,
+ 'fn': 93
+ },
+ 'feature_importance': {
+ 'x': 0.3
+ },
+ 'eval': {}
+}
+
+
+class JobMetricsBuilderTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.job.metrics.es.query_nn_metrics')
+ def test_query_and_plot_nn_metrics(self, mock_es_query):
+ mock_es_query.return_value = es_query_result.fake_es_query_nn_metrics_result
+ job = Job(name='test-job', job_type=JobType.NN_MODEL_TRANINING)
+ metrics = JobMetricsBuilder(job).query_nn_metrics()
+ self.assertEqual(
+ to_dict(metrics), {
+ 'train': {
+ 'loss': {
+ 'steps': [
+ 1645093650000.0, 1645093655000.0, 1645093660000.0, 1645093665000.0, 1645093670000.0,
+ 1645093675000.0, 1645093680000.0, 1645093685000.0, 1645093690000.0, 1645093695000.0
+ ],
+ 'values': [
+ 1.8112774487783219, 0.8499700573859391, 0.5077963560819626, 0.4255857397157412,
+ 0.3902850116000456, 0.3689204063266516, 0.34096595416776837, 0.3247630867641419,
+ 0.3146447554727395, 0.3103061146461047
+ ]
+ },
+ 'acc': {
+ 'steps': [
+ 1645093650000.0, 1645093655000.0, 1645093660000.0, 1645093665000.0, 1645093670000.0,
+ 1645093675000.0, 1645093680000.0, 1645093685000.0, 1645093690000.0, 1645093695000.0
+ ],
+ 'values': [
+ 0.37631335140332667, 0.6482393520849722, 0.749889914331765, 0.7920331122783514,
+ 0.8848890877571427, 0.8932028951744239, 0.8983024559915066, 0.9003030106425285,
+ 0.9026716228326161, 0.9047519653053074
+ ]
+ }
+ },
+ 'eval': {
+ 'loss': {
+ 'steps': [
+ 1645093650000.0, 1645093655000.0, 1645093660000.0, 1645093665000.0, 1645093670000.0,
+ 1645093675000.0, 1645093680000.0, 1645093685000.0, 1645093690000.0, 1645093695000.0
+ ],
+ 'values': [
+ 1.8112774487783219, 0.8499700573859391, 0.5077963560819626, 0.4255857397157412,
+ 0.3902850116000456, 0.3689204063266516, 0.34096595416776837, 0.3247630867641419,
+ 0.3146447554727395, 0.3103061146461047
+ ]
+ },
+ 'acc': {
+ 'steps': [
+ 1645093650000.0, 1645093655000.0, 1645093660000.0, 1645093665000.0, 1645093670000.0,
+ 1645093675000.0, 1645093680000.0, 1645093685000.0, 1645093690000.0, 1645093695000.0
+ ],
+ 'values': [
+ 0.37631335140332667, 0.6482393520849722, 0.749889914331765, 0.7920331122783514,
+ 0.8848890877571427, 0.8932028951744239, 0.8983024559915066, 0.9003030106425285,
+ 0.9026716228326161, 0.9047519653053074
+ ]
+ }
+ },
+ 'feature_importance': {}
+ })
+ figs = JobMetricsBuilder(job).plot_nn_metrics(metrics)
+ self.assertEqual(len(figs), 2)
+
+ @patch('fedlearner_webconsole.job.metrics.get_feature_importance')
+ @patch('fedlearner_webconsole.job.metrics.es.query_tree_metrics')
+ def test_query_and_plot_tree_metrics(self, mock_es_query, mock_get_importance):
+ mock_es_query.return_value = es_query_result.fake_es_query_tree_metrics_result
+ mock_get_importance.return_value = {'x': 0.3}
+ job = Job(name='test-job', job_type=JobType.TREE_MODEL_TRAINING)
+ metrics = JobMetricsBuilder(job).query_tree_metrics(need_feature_importance=True)
+ self.assertEqual(to_dict(metrics), _EXPECTED_TREE_METRICS_RESULT)
+ figs = JobMetricsBuilder(job).plot_tree_metrics(metrics=metrics)
+ self.assertEqual(len(figs), 6)
+
+ @patch('fedlearner_webconsole.job.metrics.JobMetricsBuilder.query_nn_metrics')
+ @patch('fedlearner_webconsole.job.metrics.JobMetricsBuilder.query_tree_metrics')
+ def test_query_metrics(self, mock_tree_metrics, mock_nn_metrics):
+ mock_tree_metrics.return_value = {'data': 'tree_metrics'}
+ mock_nn_metrics.return_value = {'data': 'nn_metrics'}
+ treejob = Job(name='test-tree-job', job_type=JobType.TREE_MODEL_TRAINING)
+ metrics = JobMetricsBuilder(treejob).query_metrics()
+ self.assertEqual(metrics, {'data': 'tree_metrics'})
+
+ nnjob = Job(name='test-nn-job', job_type=JobType.NN_MODEL_TRANINING)
+ metrics = JobMetricsBuilder(nnjob).query_metrics()
+ self.assertEqual(metrics, {'data': 'nn_metrics'})
+
+ @patch('fedlearner_webconsole.job.metrics.get_feature_importance')
+ @patch('fedlearner_webconsole.job.metrics.es.query_tree_metrics')
+ def test_query_and_plot_eval_tree_metrics(self, mock_es_query, mock_get_importance):
+ mock_es_query.return_value = es_query_result.fake_es_query_eval_tree_metrics_result
+ mock_get_importance.return_value = {'x': 0.3}
+ job = Job(name='test-job', job_type=JobType.TREE_MODEL_TRAINING)
+ metrics = JobMetricsBuilder(job).query_tree_metrics(need_feature_importance=True)
+ self.assertEqual(
+ to_dict(metrics), {
+ 'eval': {
+ 'auc': {
+ 'steps': [10.0],
+ 'values': [0.7513349869345765]
+ },
+ 'recall': {
+ 'steps': [10.0],
+ 'values': [0.2176754973809691]
+ },
+ 'f1': {
+ 'steps': [10.0],
+ 'values': [0.327016797789616]
+ },
+ 'ks': {
+ 'steps': [10.0],
+ 'values': [0.375900675399236]
+ },
+ 'acc': {
+ 'steps': [10.0],
+ 'values': [0.8019642162921606]
+ },
+ 'precision': {
+ 'steps': [10.0],
+ 'values': [0.6587757792451808]
+ }
+ },
+ 'confusion_matrix': {
+ 'tp': 179,
+ 'tn': 2827,
+ 'fp': 93,
+ 'fn': 649
+ },
+ 'feature_importance': {
+ 'x': 0.3
+ },
+ 'train': {}
+ })
+ figs = JobMetricsBuilder(job).plot_tree_metrics(metrics=metrics)
+ self.assertEqual(len(figs), 6)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/model_test.py b/web_console_v2/api/fedlearner_webconsole/job/model_test.py
new file mode 100644
index 000000000..5cc6b2338
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/model_test.py
@@ -0,0 +1,89 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch, MagicMock
+
+from fedlearner_webconsole.k8s.models import Pod, PodState, ContainerState
+from fedlearner_webconsole.proto.job_pb2 import CrdMetaData, JobPb, JobErrorMessage
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.workflow.models import Workflow # pylint: disable=unused-import
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ModelTest(NoWebServerTestCase):
+
+ def test_is_training_job(self):
+ job = Job()
+ job.job_type = JobType.NN_MODEL_TRANINING
+ self.assertTrue(job.is_training_job())
+ job.job_type = JobType.TREE_MODEL_TRAINING
+ self.assertTrue(job.is_training_job())
+ job.job_type = JobType.TREE_MODEL_EVALUATION
+ self.assertFalse(job.is_training_job())
+
+ def test_get_job_crdmeta(self):
+ job = Job()
+ job.set_crd_meta(CrdMetaData(api_version='a/b'))
+ self.assertEqual(job.get_crd_meta(), CrdMetaData(api_version='a/b'))
+
+ def test_to_proto(self):
+ created_at = datetime(2021, 10, 1, 8, 8, 8, tzinfo=timezone.utc)
+ job = Job(id=1,
+ name='test',
+ job_type=JobType.DATA_JOIN,
+ state=JobState.COMPLETED,
+ workflow_id=1,
+ project_id=1,
+ created_at=created_at,
+ updated_at=created_at)
+ expected_job_proto = JobPb(id=1,
+ name='test',
+ job_type=JobDefinition.DATA_JOIN,
+ state='COMPLETED',
+ workflow_id=1,
+ project_id=1,
+ crd_meta=CrdMetaData(),
+ created_at=to_timestamp(created_at),
+ updated_at=to_timestamp(created_at),
+ error_message=JobErrorMessage())
+ self.assertEqual(job.to_proto(), expected_job_proto)
+
+ @patch('fedlearner_webconsole.job.models.Job.get_k8s_app')
+ def test_get_error_message_with_pods(self, mock_get_k8s_app):
+ fake_pods = [
+ Pod(name='pod0',
+ container_states=[ContainerState(state='terminated', message='00031003')],
+ state=PodState.FAILED),
+ Pod(name='pod1', container_states=[ContainerState(state='terminated')], state=PodState.FAILED),
+ Pod(name='pod2',
+ container_states=[ContainerState(state='terminated', message='Completed')],
+ state=PodState.SUCCEEDED)
+ ]
+ mock_get_k8s_app.return_value = MagicMock(pods=fake_pods)
+ job = Job(error_message='test', state=JobState.FAILED)
+ self.assertEqual(job.get_error_message_with_pods(),
+ JobErrorMessage(app='test', pods={'pod0': 'terminated:00031003'}))
+ job.error_message = None
+ self.assertEqual(job.get_error_message_with_pods(), JobErrorMessage(pods={'pod0': 'terminated:00031003'}))
+ mock_get_k8s_app.return_value = MagicMock(pods=[])
+ self.assertEqual(job.get_error_message_with_pods(), JobErrorMessage())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/models.py b/web_console_v2/api/fedlearner_webconsole/job/models.py
index c9b00aff6..d00479640 100644
--- a/web_console_v2/api/fedlearner_webconsole/job/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/job/models.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +13,22 @@
# limitations under the License.
# coding: utf-8
-import datetime
-import logging
import enum
import json
+from typing import Optional
+
+from google.protobuf import text_format
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import Index
+from fedlearner_webconsole.job.crd import CrdService
+from fedlearner_webconsole.k8s.models import K8sApp, PodState
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
from fedlearner_webconsole.utils.mixins import to_dict_mixin
from fedlearner_webconsole.db import db
-from fedlearner_webconsole.k8s.models import FlApp, Pod, FlAppState
-from fedlearner_webconsole.utils.k8s_client import k8s_client
from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.proto.job_pb2 import CrdMetaData, JobPb, JobErrorMessage
class JobState(enum.Enum):
@@ -35,13 +39,13 @@ class JobState(enum.Enum):
# 4. WAITING -> NEW: triggered by user, stop workflow
# 4. STARTED -> STOPPED: triggered by user, stop workflow
# 5. STARTED -> COMPLETED/FAILED: triggered by k8s_watcher
- INVALID = 0 # INVALID STATE
- STOPPED = 1 # STOPPED BY USER
- WAITING = 2 # SCHEDULED, WAITING FOR RUNNING
- STARTED = 3 # RUNNING
- NEW = 4 # BEFORE SCHEDULE
- COMPLETED = 5 # SUCCEEDED JOB
- FAILED = 6 # FAILED JOB
+ INVALID = 0 # INVALID STATE
+ STOPPED = 1 # STOPPED BY USER
+ WAITING = 2 # SCHEDULED, WAITING FOR RUNNING
+ STARTED = 3 # RUNNING
+ NEW = 4 # BEFORE SCHEDULE
+ COMPLETED = 5 # SUCCEEDED JOB
+ FAILED = 6 # FAILED JOB
# must be consistent with JobType in proto
@@ -54,22 +58,12 @@ class JobType(enum.Enum):
TREE_MODEL_TRAINING = 5
NN_MODEL_EVALUATION = 6
TREE_MODEL_EVALUATION = 7
+ TRANSFORMER = 8
+ ANALYZER = 9
+ CUSTOMIZED = 10
-def merge(x, y):
- """Given two dictionaries, merge them into a new dict as a shallow copy."""
- z = x.copy()
- z.update(y)
- return z
-
-
-@to_dict_mixin(
- extras={
- 'state': (lambda job: job.get_state_for_frontend()),
- 'pods': (lambda job: job.get_pods_for_frontend()),
- 'config': (lambda job: job.get_config()),
- 'complete_at': (lambda job: job.get_complete_at())
- })
+@to_dict_mixin(ignores=['config'], extras={'complete_at': (lambda job: job.get_complete_at())})
class Job(db.Model):
__tablename__ = 'job_v2'
__table_args__ = (Index('idx_workflow_id', 'workflow_id'), {
@@ -77,15 +71,12 @@ class Job(db.Model):
'mysql_engine': 'innodb',
'mysql_charset': 'utf8mb4',
})
- id = db.Column(db.Integer,
- primary_key=True,
- autoincrement=True,
- comment='id')
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
name = db.Column(db.String(255), unique=True, comment='name')
- job_type = db.Column(db.Enum(JobType, native_enum=False),
+ job_type = db.Column(db.Enum(JobType, native_enum=False, create_constraint=False),
nullable=False,
comment='job type')
- state = db.Column(db.Enum(JobState, native_enum=False),
+ state = db.Column(db.Enum(JobState, native_enum=False, create_constraint=False),
nullable=False,
default=JobState.INVALID,
comment='state')
@@ -95,157 +86,113 @@ class Job(db.Model):
workflow_id = db.Column(db.Integer, nullable=False, comment='workflow id')
project_id = db.Column(db.Integer, nullable=False, comment='project id')
- flapp_snapshot = db.Column(db.Text(16777215), comment='flapp snapshot')
- pods_snapshot = db.Column(db.Text(16777215), comment='pods snapshot')
+ flapp_snapshot = db.Column(db.Text(16777215), comment='flapp snapshot') # deprecated
+ sparkapp_snapshot = db.Column(db.Text(16777215), comment='sparkapp snapshot') # deprecated
+ # Format like {'app': app_status_dict, 'pods': {'items': pod_list}}.
+ snapshot = db.Column(db.Text(16777215), comment='snapshot')
error_message = db.Column(db.Text(), comment='error message')
+ crd_meta = db.Column(db.Text(), comment='metadata')
+ # Use a string rather than an enum so that any kind of CRD can be created and deleted;
+ # only FLApp, SparkApplication and FedApp support listing pods and finishing automatically.
+ crd_kind = db.Column(db.String(255), comment='kind')
- created_at = db.Column(db.DateTime(timezone=True),
- server_default=func.now(),
- comment='created at')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created at')
updated_at = db.Column(db.DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
comment='updated at')
deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')
- project = db.relationship('Project',
- primaryjoin='Project.id == '
- 'foreign(Job.project_id)')
- workflow = db.relationship('Workflow',
- primaryjoin='Workflow.id == '
- 'foreign(Job.workflow_id)')
+ project = db.relationship(Project.__name__, primaryjoin='Project.id == foreign(Job.project_id)')
+ workflow = db.relationship('Workflow', primaryjoin='Workflow.id == foreign(Job.workflow_id)')
- def get_config(self):
+ def get_config(self) -> Optional[JobDefinition]:
if self.config is not None:
proto = JobDefinition()
proto.ParseFromString(self.config)
return proto
return None
- def set_config(self, proto):
+ def set_config(self, proto: JobDefinition):
if proto is not None:
self.config = proto.SerializeToString()
else:
self.config = None
- def _set_snapshot_flapp(self):
- def default(o):
- if isinstance(o, (datetime.date, datetime.datetime)):
- return o.isoformat()
- return str(o)
-
- flapp = k8s_client.get_flapp(self.name)
- if flapp:
- self.flapp_snapshot = json.dumps(flapp, default=default)
- else:
- self.flapp_snapshot = None
-
- def get_flapp_details(self):
- if self.state == JobState.STARTED:
- flapp = k8s_client.get_flapp(self.name)
- elif self.flapp_snapshot is not None:
- flapp = json.loads(self.flapp_snapshot)
- # aims to support old job
- if 'flapp' not in flapp:
- flapp['flapp'] = None
- if 'pods' not in flapp and self.pods_snapshot:
- flapp['pods'] = json.loads(self.pods_snapshot)['pods']
- else:
- flapp = {'flapp': None, 'pods': {'items': []}}
- return flapp
-
- def get_pods_for_frontend(self, include_private_info=True):
- flapp_details = self.get_flapp_details()
- flapp = FlApp.from_json(flapp_details.get('flapp', None))
- pods_json = None
- if 'pods' in flapp_details:
- pods_json = flapp_details['pods'].get('items', None)
- pods = []
- if pods_json is not None:
- pods = [Pod.from_json(p) for p in pods_json]
-
- # deduplication pods both in pods and flapp
- result = {}
- for pod in flapp.pods:
- result[pod.name] = pod
- for pod in pods:
- result[pod.name] = pod
- return [pod.to_dict(include_private_info) for pod in result.values()]
-
- def get_state_for_frontend(self):
- return self.state.name
-
- def is_flapp_failed(self):
- # TODO: make the getter more efficient
- flapp = FlApp.from_json(self.get_flapp_details()['flapp'])
- return flapp.state in [FlAppState.FAILED, FlAppState.SHUTDOWN]
-
- def is_flapp_complete(self):
- # TODO: make the getter more efficient
- flapp = FlApp.from_json(self.get_flapp_details()['flapp'])
- return flapp.state == FlAppState.COMPLETED
-
- def get_complete_at(self):
- # TODO: make the getter more efficient
- flapp = FlApp.from_json(self.get_flapp_details()['flapp'])
- return flapp.completed_at
-
- def stop(self):
- if self.state not in [JobState.WAITING, JobState.STARTED,
- JobState.COMPLETED, JobState.FAILED]:
- logging.warning('illegal job state, name: %s, state: %s',
- self.name, self.state)
- return
- if self.state == JobState.STARTED:
- self._set_snapshot_flapp()
- k8s_client.delete_flapp(self.name)
- # state change:
- # WAITING -> NEW
- # STARTED -> STOPPED
- # COMPLETED/FAILED unchanged
- if self.state == JobState.STARTED:
- self.state = JobState.STOPPED
- if self.state == JobState.WAITING:
- self.state = JobState.NEW
-
- def schedule(self):
- # COMPLETED/FAILED Job State can be scheduled since stop action
- # will not change the state of completed or failed job
- assert self.state in [JobState.NEW, JobState.STOPPED,
- JobState.COMPLETED, JobState.FAILED]
- self.pods_snapshot = None
- self.flapp_snapshot = None
- self.state = JobState.WAITING
-
- def start(self):
- assert self.state == JobState.WAITING
- self.state = JobState.STARTED
-
- def complete(self):
- assert self.state == JobState.STARTED, 'Job State is not STARTED'
- self._set_snapshot_flapp()
- k8s_client.delete_flapp(self.name)
- self.state = JobState.COMPLETED
-
- def fail(self):
- assert self.state == JobState.STARTED, 'Job State is not STARTED'
- self._set_snapshot_flapp()
- k8s_client.delete_flapp(self.name)
- self.state = JobState.FAILED
+ # TODO(xiangyuxuan.prs): Move this func and get_complete_at from the model to the service.
+ def get_k8s_app(self) -> K8sApp:
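+ # For jobs that are no longer running, rebuild the k8s app view from the stored snapshot;
+ # for STARTED jobs the CrdService fetches the live state (snapshot stays None).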
+ snapshot = None
+ if self.state != JobState.STARTED:
+ snapshot = self.snapshot or '{}'
+ snapshot = json.loads(snapshot)
+ return self.build_crd_service().get_k8s_app(snapshot)
+
+ def build_crd_service(self) -> CrdService:
+ if self.crd_kind is not None:
+ return CrdService(self.crd_kind, self.get_crd_meta().api_version, self.name)
+ # TODO(xiangyuxuan.prs): Fallback for old data; remove in the future.
+ if self.job_type in [JobType.TRANSFORMER]:
+ return CrdService('SparkApplication', 'sparkoperator.k8s.io/v1beta2', self.name)
+ return CrdService('FLApp', 'fedlearner.k8s.io/v1alpha1', self.name)
+
+ def is_training_job(self):
+ return self.job_type in [JobType.NN_MODEL_TRANINING, JobType.TREE_MODEL_TRAINING]
+
+ def get_complete_at(self) -> Optional[int]:
+ crd_obj = self.get_k8s_app()
+ return crd_obj.completed_at
+
+ def get_start_at(self) -> int:
+ crd_obj = self.get_k8s_app()
+ return crd_obj.creation_timestamp
+
+ def get_crd_meta(self) -> CrdMetaData:
+ crd_meta_obj = CrdMetaData()
+ if self.crd_meta is not None:
+ return text_format.Parse(self.crd_meta, crd_meta_obj)
+ return crd_meta_obj
+
+ def set_crd_meta(self, crd_meta: Optional[CrdMetaData] = None):
+ if crd_meta is None:
+ crd_meta = CrdMetaData()
+ self.crd_meta = text_format.MessageToString(crd_meta)
+
+ def get_error_message_with_pods(self) -> JobErrorMessage:
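+ # Combines the job-level error message with a per-pod summary collected from failed pods.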
+ failed_pods_msg = {}
+ for pod in self.get_k8s_app().pods:
+ if pod.state != PodState.FAILED:
+ continue
+ pod_error_msg = pod.get_message(include_private_info=True).summary
+ if pod_error_msg:
+ failed_pods_msg[pod.name] = pod_error_msg
+ return JobErrorMessage(app=self.error_message, pods=failed_pods_msg)
+
+ def to_proto(self) -> JobPb:
+ return JobPb(id=self.id,
+ name=self.name,
+ job_type=self.job_type.value,
+ state=self.state.name,
+ is_disabled=self.is_disabled,
+ workflow_id=self.workflow_id,
+ project_id=self.project_id,
+ snapshot=self.snapshot,
+ error_message=self.get_error_message_with_pods(),
+ crd_meta=self.get_crd_meta(),
+ crd_kind=self.crd_kind,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ complete_at=self.get_complete_at(),
+ start_at=self.get_start_at())
class JobDependency(db.Model):
__tablename__ = 'job_dependency_v2'
- __table_args__ = (Index('idx_src_job_id', 'src_job_id'),
- Index('idx_dst_job_id', 'dst_job_id'), {
- 'comment': 'record job dependencies',
- 'mysql_engine': 'innodb',
- 'mysql_charset': 'utf8mb4',
- })
- id = db.Column(db.Integer,
- primary_key=True,
- autoincrement=True,
- comment='id')
+ __table_args__ = (Index('idx_src_job_id', 'src_job_id'), Index('idx_dst_job_id', 'dst_job_id'), {
+ 'comment': 'record job dependencies',
+ 'mysql_engine': 'innodb',
+ 'mysql_charset': 'utf8mb4',
+ })
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
src_job_id = db.Column(db.Integer, comment='src job id')
dst_job_id = db.Column(db.Integer, comment='dst job id')
dep_index = db.Column(db.Integer, comment='dep index')
diff --git a/web_console_v2/api/fedlearner_webconsole/job/scheduler.py b/web_console_v2/api/fedlearner_webconsole/job/scheduler.py
new file mode 100644
index 000000000..d18476e06
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/scheduler.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.controller import start_job_if_ready
+from fedlearner_webconsole.job.models import Job, JobState
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, JobSchedulerOutput
+
+
+class JobScheduler(IRunnerV2):
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
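+ # Collect the ids of schedulable jobs first, then re-query each job in its own session
+ # under a row lock so it cannot be changed or started concurrently.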
+ with db.session_scope() as session:
+ waiting_jobs = [
+ jid
+ for jid, *_ in session.query(Job.id).filter(Job.state == JobState.WAITING, Job.is_disabled.is_(False))
+ ]
+ if waiting_jobs:
+ logging.info(f'[JobScheduler] Scheduling jobs {waiting_jobs}')
+ output = JobSchedulerOutput()
+ for job_id in waiting_jobs:
+ with db.session_scope() as session:
+ # Row lock to prevent other changes
+ job = session.query(Job).with_for_update().get(job_id)
+ ready, message = start_job_if_ready(session, job)
+ if ready:
+ if message:
+ output.failed_to_start_jobs.append(job_id)
+ else:
+ output.started_jobs.append(job_id)
+ if message:
+ output.messages[job_id] = message
+ session.commit()
+ return RunnerStatus.DONE, RunnerOutput(job_scheduler_output=output)
diff --git a/web_console_v2/api/fedlearner_webconsole/job/scheduler_test.py b/web_console_v2/api/fedlearner_webconsole/job/scheduler_test.py
new file mode 100644
index 000000000..9e1936cc0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/scheduler_test.py
@@ -0,0 +1,83 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, Mock
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.job.scheduler import JobScheduler
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, RunnerOutput, JobSchedulerOutput
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.workflow.models import Workflow # pylint: disable=unused-import
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class SchedulerTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.job.scheduler.start_job_if_ready')
+ def test_run(self, mock_start_job_if_ready: Mock):
+ with db.session_scope() as session:
+ ready_job = Job(id=1,
+ name='ready_job',
+ job_type=JobType.RAW_DATA,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ ready_job.set_config(JobDefinition(is_federated=False))
+ not_ready_job = Job(id=2,
+ name='not_ready_job',
+ job_type=JobType.RAW_DATA,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ ready_job_start_failed = Job(id=3,
+ name='ready_failed_job',
+ job_type=JobType.RAW_DATA,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=1)
+ session.add_all([ready_job, not_ready_job, ready_job_start_failed])
+ session.commit()
+
+ def fake_start_job_if_ready(session, job):
+ if job.name == ready_job_start_failed.name:
+ job.error_message = 'Failed to start'
+ return True, job.error_message
+ if job.name == ready_job.name:
+ return True, None
+ if job.name == not_ready_job.name:
+ return False, None
+ raise RuntimeError(f'Unknown job {job.name}')
+
+ mock_start_job_if_ready.side_effect = fake_start_job_if_ready
+
+ runner = JobScheduler()
+ context = RunnerContext(0, RunnerInput())
+ status, output = runner.run(context)
+ self.assertEqual(status, RunnerStatus.DONE)
+ self.assertEqual(
+ output,
+ RunnerOutput(job_scheduler_output=JobSchedulerOutput(started_jobs=[ready_job.id],
+ failed_to_start_jobs=[ready_job_start_failed.id],
+ messages={
+ ready_job_start_failed.id: 'Failed to start',
+ })))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/service.py b/web_console_v2/api/fedlearner_webconsole/job/service.py
index fc015dfb6..cf351c7dd 100644
--- a/web_console_v2/api/fedlearner_webconsole/job/service.py
+++ b/web_console_v2/api/fedlearner_webconsole/job/service.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,13 +13,27 @@
# limitations under the License.
# coding: utf-8
-
+import datetime
+import json
import logging
+from typing import List
+
from sqlalchemy.orm.session import Session
-from fedlearner_webconsole.rpc.client import RpcClient
-from fedlearner_webconsole.job.models import Job, JobDependency, JobState
-from fedlearner_webconsole.proto import common_pb2
-from fedlearner_webconsole.utils.metrics import emit_counter
+
+from fedlearner_webconsole.proto.job_pb2 import CrdMetaData, PodPb
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.job.models import Job, JobDependency, \
+ JobState
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.pp_yaml import compile_yaml_template
+from fedlearner_webconsole.job.utils import DurationState, emit_job_duration_store
+
+
+def serialize_to_json(o):
+ if isinstance(o, (datetime.date, datetime.datetime)):
+ return o.isoformat()
+ return str(o)
class JobService:
@@ -28,42 +42,91 @@ def __init__(self, session: Session):
self._session = session
def is_ready(self, job: Job) -> bool:
- deps = self._session.query(JobDependency).filter_by(
- dst_job_id=job.id).all()
+ deps = self._session.query(JobDependency).filter_by(dst_job_id=job.id).all()
for dep in deps:
src_job = self._session.query(Job).get(dep.src_job_id)
- assert src_job is not None, 'Job {} not found'.format(
- dep.src_job_id)
+ assert src_job is not None, f'Job {dep.src_job_id} not found'
if not src_job.state == JobState.COMPLETED:
return False
return True
- @staticmethod
- def is_peer_ready(job: Job) -> bool:
- project_config = job.project.get_config()
- for party in project_config.participants:
- client = RpcClient(project_config, party)
- resp = client.check_job_ready(job.name)
- if resp.status.code != common_pb2.STATUS_SUCCESS:
- emit_counter('check_peer_ready_failed', 1)
- return True
- if not resp.is_ready:
- return False
- return True
-
- def update_running_state(self, job_name):
+ def update_running_state(self, job_name: str) -> JobState:
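+ # Transitions a STARTED job to COMPLETED or FAILED based on the k8s app state, and emits
+ # a metric when the job is missing or in an unexpected state.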
job = self._session.query(Job).filter_by(name=job_name).first()
if job is None:
- emit_counter('[JobService]job_not_found', 1)
- return
+ emit_store('job.service.update_running_state_error',
+ 1,
+ tags={
+ 'job_name': job_name,
+ 'reason': 'job_not_found'
+ })
+ return None
if not job.state == JobState.STARTED:
- emit_counter('[JobService]wrong_job_state', 1)
- return
- if job.is_flapp_complete():
- job.complete()
- logging.debug('[JobService]change job %s state to %s',
- job.name, JobState(job.state))
- elif job.is_flapp_failed():
- job.fail()
- logging.debug('[JobService]change job %s state to %s',
- job.name, JobState(job.state))
+ emit_store('job.service.update_running_state_error',
+ 1,
+ tags={
+ 'job_name': job_name,
+ 'reason': 'wrong_job_state'
+ })
+ return job.state
+ if job.get_k8s_app().is_completed:
+ self.complete(job)
+ logging.debug('[JobService]change job %s state to %s', job.name, JobState(job.state))
+ elif job.get_k8s_app().is_failed:
+ self.fail(job)
+ logging.debug('[JobService]change job %s state to %s', job.name, JobState(job.state))
+ return job.state
+
+ @staticmethod
+ def get_pods(job: Job, include_private_info=True) -> List[PodPb]:
+ crd_obj = job.get_k8s_app()
+ if crd_obj:
+ return [pod.to_proto(include_private_info) for pod in crd_obj.pods]
+ return []
+
+ @staticmethod
+ def set_config_and_crd_info(job: Job, proto: JobDefinition):
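+ # Compiles the yaml template once to discover the CRD kind and apiVersion; failures are
+ # tolerated so that old templates keep working with the FLApp default.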
+ job.set_config(proto)
+ yaml = {}
+ try:
+ yaml = compile_yaml_template(job.get_config().yaml_template, post_processors=[], ignore_variables=True)
+ except Exception as e: # pylint: disable=broad-except
+ # Don't raise an exception for old templates; the default None falls back to FLApp.
+ logging.error(
+ f'Failed to format yaml for job {job.name} when getting the kind and api_version. msg: {str(e)}')
+ kind = yaml.get('kind', None)
+ api_version = yaml.get('apiVersion', None)
+ job.crd_kind = kind
+ job.set_crd_meta(CrdMetaData(api_version=api_version))
+
+ @staticmethod
+ def complete(job: Job):
+ assert job.state == JobState.STARTED, 'Job State is not STARTED'
+ JobService.set_status_to_snapshot(job)
+ job.build_crd_service().delete_app()
+ job.state = JobState.COMPLETED
+ emit_job_duration_store(duration=job.get_complete_at() - to_timestamp(job.created_at),
+ job_name=job.name,
+ state=DurationState.COMPLETED)
+
+ @staticmethod
+ def fail(job: Job):
+ assert job.state == JobState.STARTED, 'Job State is not STARTED'
+ JobService.set_status_to_snapshot(job)
+ job.build_crd_service().delete_app()
+ job.state = JobState.FAILED
+ job.error_message = job.get_k8s_app().error_message
+ emit_job_duration_store(duration=job.get_complete_at() - to_timestamp(job.created_at),
+ job_name=job.name,
+ state=DurationState.FAILURE)
+
+ @staticmethod
+ def set_status_to_snapshot(job: Job):
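+ # Persist the current k8s app/pods status as JSON so it can still be served after the
+ # CRD is deleted.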
+ app = job.build_crd_service().get_k8s_app_cache()
+ job.snapshot = json.dumps(app, default=serialize_to_json)
+
+ @staticmethod
+ def get_job_yaml(job: Job) -> str:
+ # Can't query the k8s API server when the job is not started; serve the stored snapshot instead.
+ if job.state != JobState.STARTED:
+ return job.snapshot or ''
+ return json.dumps(job.build_crd_service().get_k8s_app_cache(), default=serialize_to_json)
diff --git a/web_console_v2/api/fedlearner_webconsole/job/service_test.py b/web_console_v2/api/fedlearner_webconsole/job/service_test.py
new file mode 100644
index 000000000..34a064b48
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/service_test.py
@@ -0,0 +1,307 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import patch, MagicMock
+from datetime import datetime
+
+from fedlearner_webconsole.proto.job_pb2 import PodPb
+
+from fedlearner_webconsole.proto import workflow_definition_pb2
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.job.models import Job, JobDependency, JobType, JobState
+from fedlearner_webconsole.job.service import JobService
+from fedlearner_webconsole.k8s.models import FlApp
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class JobServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ workflow_0 = Workflow(id=0, name='test-workflow-0', project_id=0)
+ workflow_1 = Workflow(id=1, name='test-workflow-1', project_id=0)
+
+ config = workflow_definition_pb2.JobDefinition(name='test-job').SerializeToString()
+ job_0 = Job(id=0,
+ name='raw_data_0',
+ job_type=JobType.RAW_DATA,
+ state=JobState.STARTED,
+ workflow_id=0,
+ project_id=0,
+ config=config)
+ job_1 = Job(id=1,
+ name='raw_data_1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.COMPLETED,
+ workflow_id=0,
+ project_id=0,
+ config=config)
+ job_2 = Job(id=2,
+ name='data_join_0',
+ job_type=JobType.DATA_JOIN,
+ state=JobState.WAITING,
+ workflow_id=0,
+ project_id=0,
+ config=config)
+ job_3 = Job(id=3,
+ name='data_join_1',
+ job_type=JobType.DATA_JOIN,
+ state=JobState.COMPLETED,
+ workflow_id=1,
+ project_id=0,
+ config=config)
+ job_4 = Job(id=4,
+ name='train_job_0',
+ job_type=JobType.NN_MODEL_TRANINING,
+ state=JobState.WAITING,
+ workflow_id=1,
+ project_id=0,
+ config=config)
+
+ job_dep_0 = JobDependency(src_job_id=job_0.id, dst_job_id=job_2.id, dep_index=0)
+ job_dep_1 = JobDependency(src_job_id=job_1.id, dst_job_id=job_2.id, dep_index=1)
+ job_dep_2 = JobDependency(src_job_id=job_3.id, dst_job_id=job_4.id, dep_index=0)
+
+ with db.session_scope() as session:
+ session.add_all([workflow_0, workflow_1])
+ session.add_all([job_0, job_1, job_2, job_3, job_4])
+ session.add_all([job_dep_0, job_dep_1, job_dep_2])
+ session.commit()
+
+ def test_is_ready(self):
+ with db.session_scope() as session:
+ job_0 = session.query(Job).get(0)
+ job_2 = session.query(Job).get(2)
+ job_4 = session.query(Job).get(4)
+ job_service = JobService(session)
+ self.assertTrue(job_service.is_ready(job_0))
+ self.assertFalse(job_service.is_ready(job_2))
+ self.assertTrue(job_service.is_ready(job_4))
+
+ @patch('fedlearner_webconsole.job.models.Job.get_k8s_app')
+ def test_update_running_state(self, mock_crd):
+ with db.session_scope() as session:
+ job_0 = session.query(Job).get(0)
+ job_2 = session.query(Job).get(2)
+ job_service = JobService(session)
+ job_service.update_running_state(job_0.name)
+ self.assertEqual(job_0.state, JobState.COMPLETED)
+ self.assertTrue(job_service.is_ready(job_2))
+ job_0.state = JobState.STARTED
+ mock_crd.return_value = MagicMock(is_completed=False, is_failed=True, error_message=None)
+ job_service.update_running_state(job_0.name)
+ self.assertEqual(job_0.state, JobState.FAILED)
+ session.commit()
+
+ def test_get_pods(self):
+ creation_timestamp = datetime.utcnow()
+ fake_pods = \
+ {
+ 'pods': {
+ 'items': [
+ {
+ 'status': {
+ 'phase': 'Running',
+ 'pod_ip': '172.0.0.1',
+ },
+ 'metadata': {
+ 'labels': {'fl-replica-type': 'master'},
+ 'name': 'name1',
+ 'creation_timestamp': creation_timestamp,
+ },
+ 'spec': {
+ 'containers': [
+ {
+ 'name': 'fake_pod',
+ 'resources': {
+ 'limits': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ },
+ 'requests': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ }
+ }
+ }
+ ]
+ }
+ },
+ {
+ 'status': {
+ 'phase': 'Pending',
+ 'pod_ip': '172.0.0.1',
+ },
+ 'metadata': {
+ 'labels': {'fl-replica-type': 'master'},
+ 'name': 'name3',
+ 'creation_timestamp': creation_timestamp,
+ },
+ 'spec': {
+ 'containers': [
+ {
+ 'name': 'fake_pod',
+ 'resources': {
+ 'limits': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ },
+ 'requests': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ }
+ }
+ }
+ ]
+ }
+ },
+ {
+ 'status': {
+ 'phase': 'Succeeded',
+ 'pod_ip': '172.0.0.2',
+ },
+ 'metadata': {
+ 'labels': {'fl-replica-type': 'worker'},
+ 'name': 'name2',
+ 'creation_timestamp': creation_timestamp,
+ },
+ 'spec': {
+ 'containers': [
+ {
+ 'name': 'fake_pod',
+ 'resources': {
+ 'limits': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ },
+ 'requests': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ }
+ }
+ }
+ ]
+ }
+ }, {
+ 'status': {
+ 'phase': 'Running',
+ 'pod_ip': '172.0.0.2',
+ },
+ 'metadata': {
+ 'labels': {'fl-replica-type': 'worker'},
+ 'name': 'running_one',
+ 'creation_timestamp': creation_timestamp,
+ },
+ 'spec': {
+ 'containers': [
+ {
+ 'name': 'fake_pod',
+ 'resources': {
+ 'limits': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ },
+ 'requests': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ }
+ }
+ }
+ ]
+ }
+ },
+ ]
+ },
+ 'app': {
+ 'status': {
+ 'appState': 'FLStateComplete',
+ 'flReplicaStatus': {
+ 'Master': {
+ 'active': {
+ },
+ 'failed': {},
+ 'succeeded': {
+ 'name1': {}
+ }
+ },
+ 'Worker': {
+ 'active': {
+ 'running_one': {}
+ },
+ 'failed': {},
+ 'succeeded': {
+ 'name2': {}
+ }
+ }
+ }
+ }
+ }
+ }
+
+ expected_pods = [
+ PodPb(
+ name='name1',
+ pod_type='MASTER',
+ state='SUCCEEDED_AND_FREED',
+ pod_ip='172.0.0.1',
+ message='',
+ creation_timestamp=to_timestamp(creation_timestamp),
+ ),
+ PodPb(creation_timestamp=to_timestamp(creation_timestamp),
+ message='',
+ name='name3',
+ pod_ip='172.0.0.1',
+ pod_type='MASTER',
+ state='PENDING'),
+ PodPb(
+ name='name2',
+ pod_type='WORKER',
+ state='SUCCEEDED',
+ pod_ip='172.0.0.2',
+ message='',
+ creation_timestamp=to_timestamp(creation_timestamp),
+ ),
+ PodPb(
+ name='running_one',
+ pod_type='WORKER',
+ state='RUNNING',
+ pod_ip='172.0.0.2',
+ message='',
+ creation_timestamp=to_timestamp(creation_timestamp),
+ )
+ ]
+ fake_job = MagicMock()
+ fake_job.is_sparkapp = MagicMock(return_value=False)
+ fake_job.get_k8s_app = MagicMock(return_value=FlApp.from_json(fake_pods))
+ pods = JobService.get_pods(fake_job)
+ self.assertEqual(pods, expected_pods)
+
+ def test_get_job_yaml(self):
+ fake_job = MagicMock()
+ fake_job.state = JobState.STOPPED
+ fake_job.snapshot = 'test'
+ self.assertEqual(JobService.get_job_yaml(fake_job), 'test')
+ fake_job.state = JobState.STARTED
+ test_time = datetime.now()
+ fake_job.build_crd_service = MagicMock(return_value=MagicMock(get_k8s_app_cache=MagicMock(
+ return_value={'a': test_time})))
+ self.assertEqual(JobService.get_job_yaml(fake_job), f'{{"a": "{test_time.isoformat()}"}}')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/utils.py b/web_console_v2/api/fedlearner_webconsole/job/utils.py
new file mode 100644
index 000000000..d205dd676
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/utils.py
@@ -0,0 +1,27 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+from fedlearner_webconsole.utils.metrics import emit_store
+
+
+class DurationState(enum.Enum):
+ STOPPED = 'STOPPED'
+ COMPLETED = 'COMPLETED'
+ FAILURE = 'FAILURE'
+
+
+def emit_job_duration_store(duration: int, job_name: str, state: DurationState):
+ emit_store('job.duration', duration, tags={'job_name': job_name, 'state': state.name})
diff --git a/web_console_v2/api/fedlearner_webconsole/job/utils_test.py b/web_console_v2/api/fedlearner_webconsole/job/utils_test.py
new file mode 100644
index 000000000..22e04bab2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/utils_test.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from fedlearner_webconsole.job.utils import DurationState, emit_job_duration_store
+
+
+class UtilsTest(unittest.TestCase):
+
+ def test_emit_job_duration_store(self):
+ with self.assertLogs() as cm:
+ emit_job_duration_store(10, 'u466-test-job', DurationState.COMPLETED)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(
+ logs,
+ ["""[Metric][Store] job.duration: 10, tags={'job_name': 'u466-test-job', 'state': 'COMPLETED'}"""])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/job/yaml_formatter.py b/web_console_v2/api/fedlearner_webconsole/job/yaml_formatter.py
index bac0d80cc..80a4dbed2 100644
--- a/web_console_v2/api/fedlearner_webconsole/job/yaml_formatter.py
+++ b/web_console_v2/api/fedlearner_webconsole/job/yaml_formatter.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,100 +13,88 @@
# limitations under the License.
# coding: utf-8
+import base64
import json
import tarfile
from io import BytesIO
-import base64
-from string import Template
-from flatten_dict import flatten
-from fedlearner_webconsole.utils.system_envs import get_system_envs
-from fedlearner_webconsole.proto import common_pb2
-
-
-class _YamlTemplate(Template):
- delimiter = '$'
- # Which placeholders in the template should be interpreted
- idpattern = r'[a-zA-Z_\-\[0-9\]]+(\.[a-zA-Z_\-\[0-9\]]+)*'
+from fedlearner_webconsole.k8s.models import CrdKind
+from fedlearner_webconsole.rpc.client import gen_egress_authority
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.utils.const import DEFAULT_OWNER_FOR_JOB_WITHOUT_WORKFLOW
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.utils.pp_yaml import compile_yaml_template, \
+ add_username_in_label, GenerateDictService
-def format_yaml(yaml, **kwargs):
- """Formats a yaml template.
-
- Example usage:
- format_yaml('{"abc": ${x.y}}', x={'y': 123})
- output should be '{"abc": 123}'
- """
- template = _YamlTemplate(yaml)
- try:
- return template.substitute(flatten(kwargs or {},
- reducer='dot'))
- except KeyError as e:
- raise RuntimeError(
- 'Unknown placeholder: {}'.format(e.args[0])) from e
+CODE_TAR_FOLDER = 'code_tar'
+CODE_TAR_FILE_NAME = 'code_tar.tar.gz'
def make_variables_dict(variables):
- var_dict = {
- var.name: (
- code_dict_encode(json.loads(var.value))
- if var.value_type == common_pb2.Variable.ValueType.CODE \
- else var.value)
- for var in variables
- }
- return var_dict
-
-
-def generate_system_dict():
- return {'basic_envs': get_system_envs()}
-
-
-def generate_project_dict(proj):
- project = proj.to_dict()
- project['variables'] = make_variables_dict(
- proj.get_config().variables)
- participants = project['config']['participants']
- for index, participant in enumerate(participants):
- project[f'participants[{index}]'] = {}
- project[f'participants[{index}]']['egress_domain'] = \
- participant['domain_name']
- project[f'participants[{index}]']['egress_host'] = \
- participant['grpc_spec']['authority']
- return project
-
-
-def generate_workflow_dict(wf):
- workflow = wf.to_dict()
- workflow['variables'] = make_variables_dict(
- wf.get_config().variables)
- workflow['jobs'] = {}
- for j in wf.get_jobs():
- variables = make_variables_dict(j.get_config().variables)
- j_dic = j.to_dict()
- j_dic['variables'] = variables
- workflow['jobs'][j.get_config().name] = j_dic
- return workflow
-
-
-def generate_self_dict(j):
- job = j.to_dict()
- job['variables'] = make_variables_dict(
- j.get_config().variables
- )
- return job
+ var_dict = {}
+ for var in variables:
+ typed_value = to_dict(var.typed_value)
+ if var.value_type == common_pb2.Variable.CODE:
+ # Don't use `or` here: a falsy typed_value such as {} would incorrectly fall back to var.value.
+ var_dict[var.name] = code_dict_encode(typed_value if typed_value is not None else json.loads(var.value))
+ else:
+ var_dict[var.name] = typed_value if typed_value is not None else var.value
+ return var_dict
-def generate_job_run_yaml(job):
- yaml = format_yaml(job.get_config().yaml_template,
- workflow=generate_workflow_dict(job.workflow),
- project=generate_project_dict(job.project),
- system=generate_system_dict(),
- self=generate_self_dict(job))
- try:
- loaded = json.loads(yaml)
- except Exception as e: # pylint: disable=broad-except
- raise ValueError(f'Invalid json {repr(e)}: {yaml}')
- return loaded
+class YamlFormatterService:
+
+ def __init__(self, session):
+ self._session = session
+
+ @staticmethod
+ def generate_project_dict(proj):
+ project = to_dict(proj.to_proto())
+ variables = proj.get_variables()
+ project['variables'] = make_variables_dict(variables)
+ project['participants'] = []
+ for index, participant in enumerate(proj.participants):
+ # TODO(xiangyuxuan.prs): remove keys such as participants[0] in the future.
+ project[f'participants[{index}]'] = {}
+ project[f'participants[{index}]']['egress_domain'] = \
+ participant.domain_name
+ project[f'participants[{index}]']['egress_host'] = gen_egress_authority(participant.domain_name)
+ project['participants'].append(project[f'participants[{index}]'])
+ return project
+
+ def generate_workflow_dict(self, wf: 'Workflow'):
+ workflow = wf.to_dict()
+ workflow['variables'] = make_variables_dict(wf.get_config().variables)
+ workflow['jobs'] = {}
+ jobs = wf.get_jobs(self._session)
+ for j in jobs:
+ variables = make_variables_dict(j.get_config().variables)
+ j_dic = j.to_dict()
+ j_dic['variables'] = variables
+ workflow['jobs'][j.get_config().name] = j_dic
+ return workflow
+
+ @staticmethod
+ def generate_self_dict(j: 'Job'):
+ job = j.to_dict()
+ job['variables'] = make_variables_dict(j.get_config().variables)
+ return job
+
+ def generate_job_run_yaml(self, job: 'Job') -> dict:
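+ # Compiles the job's yaml template into a dict, injecting workflow/project/system/self
+ # placeholders and tagging the resource with the creator's username via the post-processor.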
+ result_dict = compile_yaml_template(job.get_config().yaml_template,
+ use_old_formater=job.crd_kind is None or
+ job.crd_kind == CrdKind.FLAPP.value,
+ post_processors=[
+ lambda loaded_json: add_username_in_label(
+ loaded_json, job.workflow.creator
+ if job.workflow else DEFAULT_OWNER_FOR_JOB_WITHOUT_WORKFLOW)
+ ],
+ workflow=job.workflow and self.generate_workflow_dict(job.workflow),
+ project=self.generate_project_dict(job.project),
+ system=GenerateDictService(self._session).generate_system_dict(),
+ self=self.generate_self_dict(job))
+ return result_dict
def code_dict_encode(data_dict):
diff --git a/web_console_v2/api/fedlearner_webconsole/job/yaml_formatter_test.py b/web_console_v2/api/fedlearner_webconsole/job/yaml_formatter_test.py
new file mode 100644
index 000000000..16f41059c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/job/yaml_formatter_test.py
@@ -0,0 +1,188 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import base64
+import tarfile
+import unittest
+from unittest.mock import patch
+from envs import Envs
+from io import BytesIO
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.yaml_formatter import code_dict_encode, YamlFormatterService
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.setting_pb2 import SystemVariables
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.utils.pp_yaml import _format_yaml, GenerateDictService
+from fedlearner_webconsole.workflow.models import Workflow # pylint: disable=unused-import
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+BASE_DIR = Envs.BASE_DIR
+
+
+class YamlFormatterTest(NoWebServerTestCase):
+
+ def test_format_with_phs(self):
+ project = {'variables[0]': {'storage_root_dir': 'root_dir'}}
+ workflow = {'jobs': {'raw_data_job': {'name': 'raw_data123'}}}
+ yaml = _format_yaml("""
+ {
+ "name": "OUTPUT_BASE_DIR",
+ "value": "${project.variables[0].storage_root_dir}/raw_data/${workflow.jobs.raw_data_job.name}"
+ }
+ """,
+ project=project,
+ workflow=workflow)
+ self.assertEqual(
+ yaml, """
+ {
+ "name": "OUTPUT_BASE_DIR",
+ "value": "root_dir/raw_data/raw_data123"
+ }
+ """)
+
+ self.assertEqual(_format_yaml('$project.variables[0].storage_root_dir', project=project),
+ project['variables[0]']['storage_root_dir'])
+
+ def test_format_with_no_ph(self):
+ self.assertEqual(_format_yaml('{a: 123, b: 234}'), '{a: 123, b: 234}')
+
+ def test_format_yaml_unknown_ph(self):
+ x = {'y': 123}
+ with self.assertRaises(RuntimeError) as cm:
+ _format_yaml('$x.y is $i.j.k', x=x)
+ self.assertEqual(str(cm.exception), 'Unknown placeholder: i.j.k')
+ with self.assertRaises(RuntimeError) as cm:
+ _format_yaml('$x.y is ${i.j}', x=x)
+ self.assertEqual(str(cm.exception), 'Unknown placeholder: i.j')
+
+ def test_encode_code(self):
+ test_data = {'test/a.py': 'awefawefawefawefwaef', 'test1/b.py': 'asdfasd', 'c.py': '', 'test/d.py': 'asdf'}
+ code_base64 = code_dict_encode(test_data)
+ code_dict = {}
+ if code_base64.startswith('base64://'):
+ tar_binary = BytesIO(base64.b64decode(code_base64[9:]))
+ with tarfile.open(fileobj=tar_binary) as tar:
+ for file in tar.getmembers():
+ code_dict[file.name] = str(tar.extractfile(file).read(), encoding='utf-8')
+ self.assertEqual(code_dict, test_data)
+
+ def test_generate_self_dict(self):
+ config = {
+ 'variables': [{
+ 'name': 'namespace',
+ 'value': 'leader'
+ }, {
+ 'name': 'basic_envs',
+ 'value': '{}'
+ }, {
+ 'name': 'storage_root_dir',
+ 'value': '/'
+ }]
+ }
+ job = Job(name='aa', project_id=1, workflow_id=1, state=JobState.NEW)
+ job.set_config(ParseDict(config, JobDefinition()))
+ self.assertEqual(
+ YamlFormatterService.generate_self_dict(job), {
+ 'id': None,
+ 'crd_kind': None,
+ 'crd_meta': None,
+ 'name': 'aa',
+ 'job_type': None,
+ 'state': 'NEW',
+ 'is_disabled': None,
+ 'workflow_id': 1,
+ 'project_id': 1,
+ 'flapp_snapshot': None,
+ 'sparkapp_snapshot': None,
+ 'error_message': None,
+ 'created_at': None,
+ 'updated_at': None,
+ 'deleted_at': None,
+ 'complete_at': 0,
+ 'snapshot': None,
+ 'variables': {
+ 'namespace': 'leader',
+ 'basic_envs': '{}',
+ 'storage_root_dir': '/',
+ }
+ })
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_application_version')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_variables')
+ def test_generate_system_dict(self, mock_system_variables, mock_app_version):
+ data = ParseDict({'variables': [{'name': 'a', 'value': 'b'}]}, SystemVariables())
+ mock_system_variables.return_value = data
+ mock_app_version.return_value.version.version = '2.2.2.2'
+ with db.session_scope() as session:
+ system_dict = GenerateDictService(session).generate_system_dict()
+ self.assertTrue(isinstance(system_dict['basic_envs'], str))
+ self.assertEqual(system_dict['version'], '2.2.2.2')
+ self.assertEqual({'a': 'b'}, system_dict['variables'])
+
+ def test_generate_project_dict(self):
+ project = Project(name='project', comment='comment')
+ participant = Participant(name='test-participant', domain_name='fl-test.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ with db.session_scope() as session:
+ session.add(project)
+ session.add(participant)
+ session.add(relationship)
+ session.commit()
+ project_dict = YamlFormatterService.generate_project_dict(project)
+ result_dict = {'egress_domain': 'fl-test.com', 'egress_host': 'fl-test-client-auth.com'}
+ self.assertEqual(project_dict['participants[0]'], result_dict)
+ self.assertEqual(project_dict['participants'][0], result_dict)
+
+ def test_generate_job_run_yaml(self):
+ with db.session_scope() as session:
+ project = Project(id=1, name='project 1')
+ session.add(project)
+ session.flush()
+ job_def = JobDefinition(name='lonely_job', job_type=JobDefinition.ANALYZER)
+ job_def.yaml_template = """
+ {
+ "apiVersion": "sparkoperator.k8s.io/v1beta2",
+ "kind": "SparkApplication",
+ "metadata": {
+ "name": self.name,
+ },
+
+ }
+ """
+ job = Job(name='test', project_id=1, job_type=JobType(job_def.job_type), workflow_id=0)
+ job.set_config(job_def)
+ session.add(job)
+ session.commit()
+ result = YamlFormatterService(session).generate_job_run_yaml(job)
+ self.assertEqual(
+ result, {
+ 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
+ 'kind': 'SparkApplication',
+ 'metadata': {
+ 'name': 'test',
+ 'labels': {
+ 'owner': 'no___workflow'
+ }
+ }
+ })
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/k8s/BUILD.bazel
new file mode 100644
index 000000000..865f2ee58
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/BUILD.bazel
@@ -0,0 +1,126 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "event_listener_lib",
+ srcs = ["event_listener.py"],
+ imports = ["../.."],
+ deps = [":k8s_cache_lib"],
+)
+
+py_library(
+ name = "fake_k8s_client_lib",
+ srcs = ["fake_k8s_client.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/testing:helpers_lib",
+ "@common_kubernetes//:pkg",
+ ],
+)
+
+py_library(
+ name = "k8s_cache_lib",
+ srcs = ["k8s_cache.py"],
+ imports = ["../.."],
+ deps = [":models_lib"],
+)
+
+py_test(
+ name = "k8s_cache_lib_test",
+ size = "small",
+ srcs = [
+ "k8s_cache_test.py",
+ ],
+ imports = ["../.."],
+ main = "k8s_cache_test.py",
+ deps = [
+ ":k8s_cache_lib",
+ ],
+)
+
+py_library(
+ name = "k8s_client_lib",
+ srcs = ["k8s_client.py"],
+ imports = ["../.."],
+ deps = [
+ ":fake_k8s_client_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:es_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:hooks_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "@common_kubernetes//:pkg",
+ ],
+)
+
+py_test(
+ name = "k8s_client_lib_test",
+ size = "small",
+ srcs = [
+ "k8s_client_test.py",
+ ],
+ imports = ["../.."],
+ main = "k8s_client_test.py",
+ deps = [
+ ":k8s_client_lib",
+ ],
+)
+
+py_library(
+ name = "k8s_watcher_lib",
+ srcs = ["k8s_watcher.py"],
+ imports = ["../.."],
+ deps = [
+ ":k8s_cache_lib",
+ ":k8s_client_lib",
+ ":models_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:event_listener_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:event_listener_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "@common_kubernetes//:pkg",
+ ],
+)
+
+py_test(
+ name = "k8s_watcher_lib_test",
+ size = "medium",
+ srcs = [
+ "k8s_watcher_test.py",
+ ],
+ imports = ["../.."],
+ main = "k8s_watcher_test.py",
+ deps = [
+ ":k8s_watcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_cache_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_kubernetes//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/event_listener.py b/web_console_v2/api/fedlearner_webconsole/k8s/event_listener.py
new file mode 100644
index 000000000..ffd8a5eeb
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/event_listener.py
@@ -0,0 +1,24 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from abc import ABCMeta, abstractmethod
+from fedlearner_webconsole.k8s.k8s_cache import Event
+
+
+class EventListener(metaclass=ABCMeta):
+
+ @abstractmethod
+ def update(self, event: Event):
+ pass
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/fake_k8s_client.py b/web_console_v2/api/fedlearner_webconsole/k8s/fake_k8s_client.py
new file mode 100644
index 000000000..2f33aca1f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/fake_k8s_client.py
@@ -0,0 +1,287 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+# pylint: disable=logging-format-interpolation
+import logging
+import datetime
+import time
+
+from kubernetes import client
+
+from testing.helpers import to_simple_namespace
+
+_RAISE_EXCEPTION_KEY = 'raise_exception'
+
+
+class FakeResponse(object):
+
+ def read_chunked(self, *args, **kwargs):
+ return []
+
+ def close(self):
+ pass
+
+ def release_conn(self):
+ pass
+
+
+class FakeCoreApi(object):
+
+ def __init__(self, timeouts=10):
+ # timeout in second
+ self.timeouts = timeouts
+
+ def list_namespaced_pod(self, namespace, **kwargs):
+ time.sleep(self.timeouts)
+ return FakeResponse()
+
+
+class FakeCrdsApi(object):
+
+ def __init__(self, timeouts=10):
+ # timeout in second
+ self.timeouts = timeouts
+
+ def list_namespaced_custom_object(self, namespace, **kwargs):
+ time.sleep(self.timeouts)
+ return FakeResponse()
+
+ def get_namespaced_custom_object(self, *args, **kwargs):
+ return FakeResponse()
+
+
+class FakeK8sClient(object):
+ """A fake k8s client for development.
+
+ With this client we can decouple the dependency on a k8s cluster.
+ """
+
+ def __init__(self):
+ self.core = FakeCoreApi(60)
+ self.crds = FakeCrdsApi(60)
+
+ def close(self):
+ pass
+
+ def create_or_update_secret(self, data, metadata, secret_type, name, namespace='default'):
+ # Users may pass two types of data:
+ # 1. a dictionary
+ # 2. a K8s object
+ # Both are accepted by the real K8s client,
+ # but a K8s object is not iterable.
+ if isinstance(data, dict) and _RAISE_EXCEPTION_KEY in data:
+ raise RuntimeError('[500] Fake exception for save_secret')
+ # Otherwise succeeds
+ logging.info('======================')
+ logging.info(f'Saved a secret with: data: {data}, metadata: {metadata}, type: {secret_type}')
+
+ def delete_secret(self, name, namespace='default'):
+ logging.info('======================')
+ logging.info(f'Deleted a secret with: name: {name}')
+
+ def get_secret(self, name, namespace='default'):
+ return client.V1Secret(api_version='v1',
+ data={'test': 'test'},
+ kind='Secret',
+ metadata={
+ 'name': name,
+ 'namespace': namespace
+ },
+ type='Opaque')
+
+ def create_or_update_service(self, metadata, spec, name, namespace='default'):
+ logging.info('======================')
+ logging.info(f'Saved a service with: spec: {spec}, metadata: {metadata}')
+
+ def delete_service(self, name, namespace='default'):
+ logging.info('======================')
+ logging.info(f'Deleted a service with: name: {name}')
+
+ def get_service(self, name, namespace='default'):
+ return client.V1Service(api_version='v1',
+ kind='Service',
+ metadata=client.V1ObjectMeta(name=name, namespace=namespace),
+ spec=client.V1ServiceSpec(selector={'app': 'nginx'}))
+
+ def list_service(self, namespace='default'):
+ service_dict = {'items': [{'metadata': {'name': f'fl-{i * 3}'}} for i in 'ac']}
+ return to_simple_namespace(service_dict)
+
+ def create_or_update_ingress(self, metadata, spec, name, namespace='default'):
+ logging.info('======================')
+ logging.info(f'Saved an ingress with: spec: {spec}, metadata: {metadata}')
+
+ def delete_ingress(self, name, namespace='default'):
+ logging.info('======================')
+ logging.info(f'Deleted an ingress with: name: {name}')
+
+ def get_ingress(self, name, namespace='default'):
+ return client.NetworkingV1beta1Ingress(api_version='networking.k8s.io/v1beta1',
+ kind='Ingress',
+ metadata=client.V1ObjectMeta(name=name, namespace=namespace),
+ spec=client.NetworkingV1beta1IngressSpec())
+
+ def list_ingress(self, namespace='default'):
+ ingress_dict = {'items': [{'metadata': {'name': f'fl-{i * 3}-client-auth'}} for i in 'abc']}
+ return to_simple_namespace(ingress_dict)
+
+ def create_or_update_deployment(self, metadata, spec, name, namespace='default'):
+ logging.info('======================')
+ logging.info(f'Saved a deployment with: spec: {spec}, metadata: {metadata}')
+
+ def delete_deployment(self, name, namespace='default'):
+ logging.info('======================')
+ logging.info(f'Deleted a deployment with: name: {name}')
+
+ def get_deployment(self, name, namespace='default'):
+ return client.V1Deployment(
+ api_version='apps/v1',
+ kind='Deployment',
+ metadata=client.V1ObjectMeta(name=name, namespace=namespace),
+ spec=client.V1DeploymentSpec(
+ selector={'matchLabels': {
+ 'app': 'fedlearner-operator'
+ }},
+ template=client.V1PodTemplateSpec(spec=client.V1PodSpec(
+ containers=[client.V1Container(name='fedlearner-operator', args=['test'])]))))
+
+ def delete_app(self, app_name, group, version: str, plural: str, namespace: str = 'default'):
+ pass
+
+ def create_app(self, app_name, group, version: str, plural: str, namespace: str = 'default') -> dict:
+ return {}
+
+ def get_app_cache(self, app_name):
+ pods = {
+ 'pods': {
+ 'metadata': {
+ 'selfLink': '/api/v1/namespaces/default/pods',
+ 'resourceVersion': '780480990'
+ }
+ },
+ 'items': [{
+ 'metadata': {
+ 'name': f'{app_name}-0'
+ }
+ }, {
+ 'metadata': {
+ 'name': f'{app_name}-1'
+ }
+ }]
+ }
+ flapp = {
+ 'kind': 'FLAPP',
+ 'metadata': {
+ 'name': app_name,
+ 'namespace': 'default'
+ },
+ 'status': {
+ 'appState': 'FLStateRunning',
+ 'flReplicaStatus': {
+ 'Master': {
+ 'active': {
+ 'laomiao-raw-data-1223-v1-follower'
+ '-master-0-717b53c4-'
+ 'fef7-4d65-a309-63cf62494286': {}
+ }
+ },
+ 'Worker': {
+ 'active': {
+ 'laomiao-raw-data-1223-v1-follower'
+ '-worker-0-61e49961-'
+ 'e6dd-4015-a246-b6d25e69a61c': {},
+ 'laomiao-raw-data-1223-v1-follower'
+ '-worker-1-accef16a-'
+ '317f-440f-8f3f-7dd5b3552d25': {}
+ }
+ }
+ }
+ }
+ }
+ return {'flapp': flapp, 'pods': pods}
+
+ def get_sparkapplication(self, name: str, namespace: str = 'default') -> dict:
+ logging.info('======================')
+ logging.info(f'get spark application, name: {name}, namespace: {namespace}')
+ return {
+ 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
+ 'kind': 'SparkApplication',
+ 'metadata': {
+ 'creationTimestamp': '2021-04-15T10:43:15Z',
+ 'generation': 1,
+ 'name': name,
+ 'namespace': namespace,
+ },
+ 'status': {
+ 'applicationState': {
+ 'state': 'COMPLETED'
+ },
+ }
+ }
+
+ def create_sparkapplication(self, json_object: dict, namespace: str = 'default') -> dict:
+ logging.info('======================')
+ logging.info(f'create spark application, namespace: {namespace}, ' f'json: {json_object}')
+ return {
+ 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
+ 'kind': 'SparkApplication',
+ 'metadata': {
+ 'creationTimestamp': '2021-04-15T10:43:15Z',
+ 'generation': 1,
+ 'name': 'fl-transformer-yaml',
+ 'namespace': 'fedlearner',
+ 'resourceVersion': '348817823',
+ },
+ 'spec': {
+ 'arguments': ['hdfs://user/feature/data.csv', 'hdfs://user/feature/data_tfrecords/'],
+ }
+ }
+
+ def delete_sparkapplication(self, name: str, namespace: str = 'default') -> dict:
+ logging.info('======================')
+ logging.info(f'delete spark application, name: {name}, namespace: {namespace}')
+ return {
+ 'kind': 'Status',
+ 'apiVersion': 'v1',
+ 'metadata': {},
+ 'status': 'Success',
+ 'details': {
+ 'name': name,
+ 'group': 'sparkoperator.k8s.io',
+ 'kind': 'sparkapplications',
+ 'uid': '790603b6-9dd6-11eb-9282-b8599fb51ea8'
+ }
+ }
+
+ def get_pod_log(self, name: str, namespace: str, tail_lines: int):
+ return str(datetime.datetime.now())
+
+ def get_pods(self, namespace, label_selector):
+ fake_pod = client.V1Pod(metadata=client.V1ObjectMeta(name='fake_pod',
+ labels={},
+ creation_timestamp=datetime.datetime.utcnow()),
+ status=client.V1PodStatus(phase='Running'),
+ spec=client.V1PodSpec(containers=[
+ client.V1Container(name='fake_container',
+ resources=client.V1ResourceRequirements(limits={
+ 'cpu': '2000m',
+ 'memory': '4Gi'
+ },
+ requests={
+ 'cpu': '2000m',
+ 'memory': '4Gi'
+ }))
+ ]))
+ return client.V1PodList(items=[fake_pod])
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/k8s_cache.py b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_cache.py
new file mode 100644
index 000000000..f74d248bc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_cache.py
@@ -0,0 +1,122 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from enum import Enum
+
+from fedlearner_webconsole.k8s.models import get_app_name_from_metadata
+
+
+class EventType(Enum):
+ ADDED = 'ADDED'
+ MODIFIED = 'MODIFIED'
+ DELETED = 'DELETED'
+
+
+class ObjectType(Enum):
+ POD = 'POD'
+ FLAPP = 'FLAPP'
+ SPARKAPP = 'SPARKAPP'
+ FEDAPP = 'FEDAPP'
+
+
+class Event(object):
+
+ def __init__(self, app_name: str, event_type: EventType, obj_type: ObjectType, obj_dict: dict):
+ self.app_name = app_name
+ self.event_type = event_type
+ self.obj_type = obj_type
+ # {'status': {}, 'metadata': {}}
+ self.obj_dict = obj_dict
+
+ @staticmethod
+ def from_json(event, obj_type):
+ # TODO(xiangyuxuan): move this to k8s/models.py
+ event_type = event['type']
+ obj = event['object']
+ if obj_type == ObjectType.POD:
+ app_name = get_app_name_from_metadata(obj.metadata)
+ obj = obj.to_dict()
+ status = obj.get('status')
+ return Event(app_name,
+ EventType(event_type),
+ obj_type,
+ obj_dict={
+ 'status': status,
+ 'metadata': obj.get('metadata', {})
+ })
+
+ metadata = obj.get('metadata', {})
+ # Build the event; the caller is responsible for putting it onto the queue.
+ return Event(metadata.get('name', None), EventType(event_type), obj_type, obj_dict=obj)
+
+
+class K8sCache(object):
+
+ def __init__(self):
+ # key: app_name, value: a dict
+ # {'flapp': flapp cache, 'pods': pods cache,
+ # 'deleted': is flapp deleted}
+ self._cache = {}
+ self._pod_cache = {}
+
+ def inspect(self) -> dict:
+ c = {}
+ c['pod_cache'] = self._pod_cache
+ c['app_cache'] = self._cache
+ return c
+
+ def update_cache(self, event: Event):
+ if event.obj_type == ObjectType.POD:
+ self._update_pod_cache(event)
+ else:
+ self._update_app_cache(event)
+
+ def get_cache(self, app_name: str) -> dict:
+ return self._get_app_cache(app_name)
+
+ def _update_app_cache(self, event: Event):
+ app_name = event.app_name
+
+ self._cache[app_name] = {'app': event.obj_dict}
+ if app_name not in self._pod_cache:
+ self._pod_cache[app_name] = {'items': [], 'deleted': False}
+ self._pod_cache[app_name]['deleted'] = False
+ if event.event_type == EventType.DELETED:
+ self._cache[app_name] = {'app': None}
+ self._pod_cache[app_name] = {'items': [], 'deleted': True}
+
+ def _get_app_cache(self, app_name) -> dict:
+ if app_name not in self._cache:
+ return {'app': None, 'pods': {'items': []}}
+ app = {**self._cache[app_name], 'pods': self._pod_cache[app_name]}
+ return app
+
+ def _update_pod_cache(self, event: Event):
+ app_name = event.app_name
+ if app_name not in self._pod_cache:
+ self._pod_cache[app_name] = {'items': [], 'deleted': False}
+ if self._pod_cache[app_name]['deleted']:
+ return
+ existed = False
+ for index, pod in enumerate(self._pod_cache[app_name]['items']):
+ if pod['metadata']['name'] == event.obj_dict['metadata']['name']:
+ existed = True
+ self._pod_cache[app_name]['items'][index] = event.obj_dict
+ break
+ if not existed:
+ self._pod_cache[app_name]['items'].append(event.obj_dict)
+
+
+k8s_cache = K8sCache()
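+
+# Illustrative usage (an assumption for documentation, not executed here): the watcher
+# pushes events into this shared singleton and readers get a merged view of the app
+# together with its pods, e.g.:
+#
+#   event = Event(app_name='demo-app', event_type=EventType.ADDED,
+#                 obj_type=ObjectType.FLAPP, obj_dict={'status': {}, 'metadata': {}})
+#   k8s_cache.update_cache(event)
+#   snapshot = k8s_cache.get_cache('demo-app')
+#   # snapshot == {'app': {'status': {}, 'metadata': {}}, 'pods': {'items': [], 'deleted': False}}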
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/k8s_cache_test.py b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_cache_test.py
new file mode 100644
index 000000000..f4cb1f630
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_cache_test.py
@@ -0,0 +1,33 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from fedlearner_webconsole.k8s.k8s_cache import Event, ObjectType, EventType
+
+
+class EventTest(unittest.TestCase):
+
+ def test_from_json(self):
+ app_object = {'metadata': {'name': 'test'}, 'status': None, 'spec': {'test': 1}}
+ test_event_dict = {'type': 'ADDED', 'object': app_object}
+ event = Event.from_json(test_event_dict, ObjectType.FLAPP)
+ self.assertEqual(event.app_name, 'test')
+ self.assertEqual(event.obj_type, ObjectType.FLAPP)
+ self.assertEqual(event.event_type, EventType.ADDED)
+ self.assertEqual(event.obj_dict, app_object)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/k8s_client.py b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_client.py
new file mode 100644
index 000000000..60614beb4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_client.py
@@ -0,0 +1,445 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+# pylint: disable=inconsistent-return-statements
+import enum
+import logging
+from http import HTTPStatus
+from typing import Callable, Optional
+
+import kubernetes
+from kubernetes import client
+from kubernetes.client import V1ServiceList, NetworkingV1beta1IngressList
+from kubernetes.client.exceptions import ApiException
+
+from envs import Envs
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+from fedlearner_webconsole.exceptions import (NotFoundException, InternalException)
+from fedlearner_webconsole.k8s.fake_k8s_client import FakeK8sClient
+from fedlearner_webconsole.utils.es import es
+from fedlearner_webconsole.utils.hooks import parse_and_get_fn
+
+# This is the default k8s client hook.
+# Args:
+# app_yaml [dict]: the app yaml definition of the k8s resource.
+# Returns:
+# [dict]: the modified app yaml of the k8s resource.
+# Note:
+# To customize the k8s client hook:
+# 1. write a hook function that follows this interface, and
+# 2. set `K8S_HOOK_MODULE_PATH` to the module path of your hook function.
+DEFAULT_K8S_CLIENT_HOOK: Callable[[dict], dict] = lambda o: o
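+
+# Illustrative example of a custom hook (an assumption for documentation, not part of
+# this change): a hook that stamps a label onto every resource before submission,
+# enabled by pointing `K8S_HOOK_MODULE_PATH` at it, e.g. 'my_hooks.k8s:add_owner_label'.
+#
+#   def add_owner_label(app_yaml: dict) -> dict:
+#       labels = app_yaml.setdefault('metadata', {}).setdefault('labels', {})
+#       labels['managed-by'] = 'fedlearner-webconsole'
+#       return app_yaml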
+
+
+# TODO(xiangyuxuan.prs): they are just used in dataset, should be deprecated.
+class CrdKind(enum.Enum):
+ FLAPP = 'flapps'
+ SPARK_APPLICATION = 'sparkapplications'
+
+
+FEDLEARNER_CUSTOM_GROUP = 'fedlearner.k8s.io'
+FEDLEARNER_CUSTOM_VERSION = 'v1alpha1'
+
+SPARKOPERATOR_CUSTOM_GROUP = 'sparkoperator.k8s.io'
+SPARKOPERATOR_CUSTOM_VERSION = 'v1beta2'
+SPARKOPERATOR_NAMESPACE = Envs.K8S_NAMESPACE
+
+REQUEST_TIMEOUT_IN_SECOND = 10
+
+
+# TODO(wangsen.0914): remove create_deployment etc.; add UT for client
+class K8sClient(object):
+
+ def __init__(self):
+ self.core = None
+ self.crds = None
+ self._networking = None
+ self._app = None
+ self._hook_fn = DEFAULT_K8S_CLIENT_HOOK
+
+ def init(self, config_path: Optional[str] = None, hook_module_path: Optional[str] = None):
+ # Sets config
+ if config_path is None:
+ kubernetes.config.load_incluster_config()
+ else:
+ kubernetes.config.load_kube_config(config_path)
+
+ # Initialize hook
+ if hook_module_path:
+ self._hook_fn = parse_and_get_fn(hook_module_path)
+
+ # Inits API clients
+ self.core = client.CoreV1Api()
+ self.crds = client.CustomObjectsApi()
+ self._networking = client.NetworkingV1beta1Api()
+ self._app = client.AppsV1Api()
+
+ def close(self):
+ self.core.api_client.close()
+ self._networking.api_client.close()
+
+ def _raise_runtime_error(self, exception: ApiException):
+ logging.error(f'[k8s_client]: runtime error {exception}')
+ raise RuntimeError(str(exception))
+
+ def create_or_update_secret(self, data, metadata, secret_type, name, namespace='default'):
+ """Create a secret. If it already exists, patch it."""
+ request = client.V1Secret(api_version='v1', data=data, kind='Secret', metadata=metadata, type=secret_type)
+ try:
+ self.core.read_namespaced_secret(name, namespace)
+ # If the secret already exists, then we use patch to replace it.
+ # We don't use replace method because it requires `resourceVersion`.
+ self.core.patch_namespaced_secret(name, namespace, request)
+ return
+ except ApiException as e:
+ # 404 is expected if the secret does not exist
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+ try:
+ self.core.create_namespaced_secret(namespace, request)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def delete_secret(self, name, namespace='default'):
+ try:
+ self.core.delete_namespaced_secret(name, namespace)
+ except ApiException as e:
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+
+ def get_secret(self, name, namespace='default'):
+ try:
+ return self.core.read_namespaced_secret(name, namespace)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def create_or_update_config_map(self, metadata, data, name, namespace='default'):
+ """Create a configMap. If it already exists, patch it."""
+ request = client.V1ConfigMap(api_version='v1', kind='ConfigMap', metadata=metadata, data=data)
+ try:
+ self.core.read_namespaced_config_map(name, namespace)
+ # If the configMap already exists, then we use patch to replace it.
+ # We don't use replace method because it requires `resourceVersion`.
+ self.core.patch_namespaced_config_map(name, namespace, request)
+ return
+ except ApiException as e:
+ # 404 is expected if the configMap does not exist
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+ try:
+ self.core.create_namespaced_config_map(namespace, request)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def delete_config_map(self, name, namespace='default'):
+ try:
+ self.core.delete_namespaced_config_map(name, namespace)
+ except ApiException as e:
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+
+ def get_config_map(self, name, namespace='default'):
+ try:
+ return self.core.read_namespaced_config_map(name, namespace)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def create_or_update_service(self, metadata, spec, name, namespace='default'):
+ """Create a service. If it already exists, patch it."""
+ request = client.V1Service(api_version='v1', kind='Service', metadata=metadata, spec=spec)
+ try:
+ self.core.read_namespaced_service(name, namespace)
+ # If the service already exists, then we use patch to replace it.
+ # We don't use replace method because it requires `resourceVersion`.
+ self.core.patch_namespaced_service(name, namespace, request)
+ return
+ except ApiException as e:
+ # 404 is expected if the service does not exist
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+ try:
+ self.core.create_namespaced_service(namespace, request)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def delete_service(self, name, namespace='default'):
+ try:
+ self.core.delete_namespaced_service(name, namespace)
+ except ApiException as e:
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+
+ def get_service(self, name, namespace='default'):
+ try:
+ return self.core.read_namespaced_service(name, namespace)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def list_service(self, namespace: str = 'default') -> V1ServiceList:
+ try:
+ return self.core.list_namespaced_service(namespace)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def create_or_update_ingress(self, metadata, spec, name, namespace='default'):
+ request = client.NetworkingV1beta1Ingress(api_version='networking.k8s.io/v1beta1',
+ kind='Ingress',
+ metadata=metadata,
+ spec=spec)
+ try:
+ self._networking.read_namespaced_ingress(name, namespace)
+ # If the ingress already exists, then we use patch to replace it.
+ # We don't use replace method because it requires `resourceVersion`.
+ self._networking.patch_namespaced_ingress(name, namespace, request)
+ return
+ except ApiException as e:
+ # 404 is expected if the ingress does not exist
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+ try:
+ self._networking.create_namespaced_ingress(namespace, request)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def delete_ingress(self, name, namespace='default'):
+ try:
+ self._networking.delete_namespaced_ingress(name, namespace)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def get_ingress(self, name, namespace='default'):
+ try:
+ return self._networking.read_namespaced_ingress(name, namespace)
+ except ApiException as e:
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+
+ def list_ingress(self, namespace: str = 'default') -> NetworkingV1beta1IngressList:
+ try:
+ return self._networking.list_namespaced_ingress(namespace)
+ except ApiException as e:
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+
+ def create_or_update_deployment(self, metadata, spec, name, namespace='default'):
+ request = client.V1Deployment(api_version='apps/v1', kind='Deployment', metadata=metadata, spec=spec)
+ try:
+ self._app.read_namespaced_deployment(name, namespace)
+ # If the deployment already exists, then we use patch to replace it.
+ # We don't use replace method because it requires `resourceVersion`.
+ self._app.patch_namespaced_deployment(name, namespace, request)
+ return
+ except ApiException as e:
+ # 404 is expected if the deployment does not exist
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+ try:
+ self._app.create_namespaced_deployment(namespace, request)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def delete_deployment(self, name, namespace='default'):
+ try:
+ self._app.delete_namespaced_deployment(name, namespace)
+ except ApiException as e:
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+
+ def get_deployment(self, name):
+ try:
+ return self._app.read_namespaced_deployment(name, Envs.K8S_NAMESPACE)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def get_sparkapplication(self, name: str, namespace: str = SPARKOPERATOR_NAMESPACE) -> dict:
+ """get sparkapp
+
+ Args:
+ name (str): sparkapp name
+ namespace (str, optional): namespace to submit.
+
+ Raises:
+ InternalException: if any error occurs during API call
+ NotFoundException: if the spark app is not found
+
+ Returns:
+ dict: resp of k8s
+ """
+ try:
+ return self.crds.get_namespaced_custom_object(group=SPARKOPERATOR_CUSTOM_GROUP,
+ version=SPARKOPERATOR_CUSTOM_VERSION,
+ namespace=namespace,
+ plural=CrdKind.SPARK_APPLICATION.value,
+ name=name)
+ except ApiException as err:
+ if err.status == 404:
+ raise NotFoundException() from err
+ raise InternalException(details=err.body) from err
+
+ def create_sparkapplication(self, json_object: dict, namespace: str = SPARKOPERATOR_NAMESPACE) -> dict:
+ """ create sparkapp
+
+ Args:
+ json_object (dict): json object of config
+ namespace (str, optional): namespace to submit.
+
+ Returns:
+ dict: resp of k8s
+ """
+ logging.debug(f'create sparkapp json is {json_object}')
+ return self.crds.create_namespaced_custom_object(group=SPARKOPERATOR_CUSTOM_GROUP,
+ version=SPARKOPERATOR_CUSTOM_VERSION,
+ namespace=namespace,
+ plural=CrdKind.SPARK_APPLICATION.value,
+ body=json_object)
+
+ def delete_sparkapplication(self, name: str, namespace: str = SPARKOPERATOR_NAMESPACE) -> dict:
+ """ delete sparkapp
+
+ Args:
+ name (str): sparkapp name
+ namespace (str, optional): namespace to delete.
+
+ Raises:
+ NotFoundException: if the spark app is not found
+ InternalException: if any error occurs during API call
+
+ Returns:
+ dict: resp of k8s
+ """
+ try:
+ return self.crds.delete_namespaced_custom_object(group=SPARKOPERATOR_CUSTOM_GROUP,
+ version=SPARKOPERATOR_CUSTOM_VERSION,
+ namespace=namespace,
+ plural=CrdKind.SPARK_APPLICATION.value,
+ name=name,
+ body=client.V1DeleteOptions())
+ except ApiException as err:
+ if err.status == 404:
+ raise NotFoundException() from err
+ raise InternalException(details=err.body) from err
+
+ def get_pod_log(self, name: str, namespace: str, tail_lines: int):
+ # this is not necessary for now
+ del namespace
+ return es.query_log(Envs.ES_INDEX, '', name)[:tail_lines][::-1]
+
+ def get_pods(self, namespace, label_selector):
+ try:
+ return self.core.list_namespaced_pod(namespace=namespace, label_selector=label_selector)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def create_app(self,
+ app_yaml: dict,
+ group: str,
+ version: str,
+ plural: str,
+ namespace: str = Envs.K8S_NAMESPACE) -> dict:
+ try:
+ app_yaml = self._hook_fn(app_yaml)
+ return self.crds.create_namespaced_custom_object(group=group,
+ version=version,
+ namespace=namespace,
+ plural=plural,
+ body=app_yaml,
+ _request_timeout=REQUEST_TIMEOUT_IN_SECOND)
+ except ApiException as e:
+ # 409 CONFLICT is expected if the custom resource already exists
+ if e.status != HTTPStatus.CONFLICT:
+ self._raise_runtime_error(e)
+ logging.warning(f'Crd object: {app_yaml} already exists!')
+
+ @retry_fn(retry_times=3)
+ def delete_app(self, app_name, group, version: str, plural: str, namespace: str = Envs.K8S_NAMESPACE):
+ try:
+ self.crds.delete_namespaced_custom_object(group=group,
+ version=version,
+ namespace=namespace,
+ plural=plural,
+ name=app_name,
+ _request_timeout=REQUEST_TIMEOUT_IN_SECOND)
+ except ApiException as e:
+ # If the custom resource has been deleted then the exception gets ignored
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+
+ def get_custom_object(self,
+ name: str,
+ group: str,
+ version: str,
+ plural: str,
+ namespace: str = Envs.K8S_NAMESPACE) -> dict:
+ try:
+ return self.crds.get_namespaced_custom_object(group=group,
+ version=version,
+ namespace=namespace,
+ plural=plural,
+ name=name,
+ _request_timeout=REQUEST_TIMEOUT_IN_SECOND)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+ def update_app(self, app_yaml: dict, group: str, version: str, plural: str, namespace: str = Envs.K8S_NAMESPACE):
+ try:
+ app_yaml = self._hook_fn(app_yaml)
+ name = app_yaml['metadata']['name']
+ self.crds.patch_namespaced_custom_object(group=group,
+ version=version,
+ namespace=namespace,
+ plural=plural,
+ name=name,
+ body=app_yaml,
+ _request_timeout=REQUEST_TIMEOUT_IN_SECOND)
+ except ApiException as e:
+ if e.status == HTTPStatus.NOT_FOUND:
+ logging.error(f'[k8s_client] Resource: {app_yaml} doesn\'t exist!')
+ self._raise_runtime_error(e)
+
+ def create_or_update_app(self,
+ app_yaml: dict,
+ group: str,
+ version: str,
+ plural: str,
+ namespace: str = Envs.K8S_NAMESPACE):
+ name = app_yaml['metadata']['name']
+ try:
+ # Why not use `get_custom_object`?
+ # Because `get_custom_object` wraps the exception, it's difficult to parse 404 info.
+ self.crds.get_namespaced_custom_object(group=group,
+ version=version,
+ namespace=namespace,
+ plural=plural,
+ name=name,
+ _request_timeout=REQUEST_TIMEOUT_IN_SECOND)
+ # If the resource already exists, then we use patch to replace it.
+ # We don't use replace method because it requires `resourceVersion`.
+ self.update_app(app_yaml=app_yaml, group=group, version=version, plural=plural, namespace=namespace)
+ return
+ except ApiException as e:
+ # 404 is expected if the custom resource does not exist
+ if e.status != HTTPStatus.NOT_FOUND:
+ self._raise_runtime_error(e)
+ try:
+ self.create_app(app_yaml=app_yaml, group=group, version=version, plural=plural, namespace=namespace)
+ except ApiException as e:
+ self._raise_runtime_error(e)
+
+
+k8s_client = FakeK8sClient()
+if Envs.FLASK_ENV == 'production' or \
+ Envs.K8S_CONFIG_PATH is not None:
+ k8s_client = K8sClient()
+ k8s_client.init(config_path=Envs.K8S_CONFIG_PATH, hook_module_path=Envs.K8S_HOOK_MODULE_PATH)
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/k8s_client_test.py b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_client_test.py
new file mode 100644
index 000000000..3d38d8914
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_client_test.py
@@ -0,0 +1,142 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import MagicMock, patch
+
+from kubernetes.client import ApiException
+
+from fedlearner_webconsole.k8s.k8s_client import K8sClient, REQUEST_TIMEOUT_IN_SECOND
+
+
+class K8sClientTest(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._load_incluster_config_patcher = patch(
+ 'fedlearner_webconsole.k8s.k8s_client.kubernetes.config.load_incluster_config', lambda: None)
+ self._load_incluster_config_patcher.start()
+
+ self._k8s_client = K8sClient()
+ self._k8s_client.init()
+
+ def tearDown(self):
+ self._load_incluster_config_patcher.stop()
+ super().tearDown()
+
+ def test_delete_flapp(self):
+ mock_crds = MagicMock()
+ self._k8s_client.crds = mock_crds
+ # Test delete successfully
+ mock_crds.delete_namespaced_custom_object = MagicMock()
+ self._k8s_client.delete_app('test_flapp', 'fedlearner.k8s.io', 'v1alpha1', 'flapps')
+ mock_crds.delete_namespaced_custom_object.assert_called_once_with(group='fedlearner.k8s.io',
+ name='test_flapp',
+ namespace='default',
+ plural='flapps',
+ version='v1alpha1',
+ _request_timeout=REQUEST_TIMEOUT_IN_SECOND)
+ # Tests that the flapp has been deleted
+ mock_crds.delete_namespaced_custom_object = MagicMock(side_effect=ApiException(status=404))
+ self._k8s_client.delete_app('test_flapp2', 'fedlearner.k8s.io', 'v1alpha1', 'flapps')
+ self.assertEqual(mock_crds.delete_namespaced_custom_object.call_count, 1)
+ # Tests with other exceptions
+ mock_crds.delete_namespaced_custom_object = MagicMock(side_effect=ApiException(status=500))
+ with self.assertRaises(RuntimeError):
+ self._k8s_client.delete_app('test_flapp3', 'fedlearner.k8s.io', 'v1alpha1', 'flapps')
+ self.assertEqual(mock_crds.delete_namespaced_custom_object.call_count, 3)
+
+ def test_create_flapp(self):
+ test_yaml = {'metadata': {'name': 'test app'}, 'kind': 'flapp', 'apiVersion': 'fedlearner.k8s.io/v1alpha1'}
+ mock_crds = MagicMock()
+ self._k8s_client.crds = mock_crds
+ # Test create successfully
+ mock_crds.create_namespaced_custom_object = MagicMock()
+ self._k8s_client.create_app(test_yaml, plural='flapps', version='v1alpha1', group='fedlearner.k8s.io')
+ mock_crds.create_namespaced_custom_object.assert_called_once_with(group='fedlearner.k8s.io',
+ namespace='default',
+ plural='flapps',
+ version='v1alpha1',
+ _request_timeout=REQUEST_TIMEOUT_IN_SECOND,
+ body=test_yaml)
+ self._k8s_client.create_app(test_yaml, plural='flapps', version='v1alpha1', group='fedlearner.k8s.io')
+ self.assertEqual(mock_crds.create_namespaced_custom_object.call_count, 2)
+
+ @patch('fedlearner_webconsole.k8s.k8s_client.parse_and_get_fn')
+ @patch('fedlearner_webconsole.k8s.k8s_client.client.CustomObjectsApi.create_namespaced_custom_object')
+ def test_create_app_with_hook(self, mock_create_namespaced_custom_object: MagicMock,
+ mock_parse_and_get_fn: MagicMock):
+
+ def custom_magic_fn(app_yaml: dict) -> dict:
+ app_yaml['metadata']['name'] = app_yaml['metadata']['name'] + '_hello'
+ return app_yaml
+
+ mock_parse_and_get_fn.return_value = custom_magic_fn
+ self._k8s_client.init(hook_module_path='test.hook:custom_magic_fn')
+ deployment_app_yaml = {
+ 'apiVersion': 'apps/v1',
+ 'kind': 'Deployment',
+ 'metadata': {
+ 'name': 'world',
+ },
+ 'spec': {
+ 'selector': {
+ 'matchLabels': {
+ 'app': 'test-app'
+ }
+ },
+ 'replicas': 1,
+ 'template': {
+ 'metadata': {
+ 'labels': {
+ 'app': 'test-app'
+ }
+ },
+ 'spec': {
+ 'volumes': [{
+ 'name': 'test-app-config',
+ 'configMap': {
+ 'name': 'test-app-config'
+ }
+ }],
+ 'containers': [{
+ 'name': 'test-app',
+ 'image': 'serving:latest',
+ 'args': [
+ '--port=8500', '--rest_api_port=8501', '--model_config_file=/app/config/config.pb'
+ ],
+ 'ports': [{
+ 'containerPort': 8500
+ }, {
+ 'containerPort': 8501
+ }],
+ 'volumeMounts': [{
+ 'name': 'test-app-config',
+ 'mountPath': '/app/config/'
+ }]
+ }]
+ }
+ }
+ }
+ }
+ self._k8s_client.create_app(app_yaml=deployment_app_yaml, group='apps', version='v1', plural='Deployment')
+
+ mock_parse_and_get_fn.assert_called_once_with('test.hook:custom_magic_fn')
+ mock_create_namespaced_custom_object.assert_called_once()
+ self.assertEqual(mock_create_namespaced_custom_object.call_args[1]['body']['metadata']['name'], 'world_hello')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/k8s_watcher.py b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_watcher.py
new file mode 100644
index 000000000..84a2ad336
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_watcher.py
@@ -0,0 +1,265 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from abc import ABC, abstractmethod
+import logging
+import multiprocessing
+import threading
+import queue
+import traceback
+from http import HTTPStatus
+from typing import Generator, NamedTuple, Optional, Tuple
+from kubernetes import client, watch
+from kubernetes.client import V1ObjectMeta
+from envs import Envs
+from fedlearner_webconsole.job.event_listener import JobEventListener
+from fedlearner_webconsole.k8s.k8s_cache import k8s_cache, Event, ObjectType
+from fedlearner_webconsole.utils.metrics import emit_store, emit_counter
+from fedlearner_webconsole.k8s.k8s_client import k8s_client
+from fedlearner_webconsole.k8s.models import CrdKind, get_app_name_from_metadata
+from fedlearner_webconsole.mmgr.event_listener import ModelEventListener
+
+
+class CrdWatcherConfig(NamedTuple):
+ version: str
+ group: str
+ plural: str
+ object_type: ObjectType
+
+
+WATCHER_CONFIG = {
+ CrdKind.FLAPP:
+ CrdWatcherConfig(
+ version='v1alpha1',
+ group='fedlearner.k8s.io',
+ plural='flapps',
+ object_type=ObjectType.FLAPP,
+ ),
+ CrdKind.SPARKAPPLICATION:
+ CrdWatcherConfig(
+ version='v1beta2',
+ group='sparkoperator.k8s.io',
+ plural='sparkapplications',
+ object_type=ObjectType.SPARKAPP,
+ ),
+ CrdKind.FEDAPP:
+ CrdWatcherConfig(
+ version='v1alpha1',
+ group='fedlearner.k8s.io',
+ plural='fedapps',
+ object_type=ObjectType.FEDAPP,
+ ),
+}
+
+_REQUEST_TIMEOUT_IN_SECOND = 900
+
+
+class AbstractWatcher(ABC):
+
+ @property
+ @abstractmethod
+ def kind(self) -> str:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _watch(self, watcher: watch.Watch, resource_version: str) -> Generator[Tuple[Event, Optional[str]], None, None]:
+ """An abstract method that subclasses implement to watch and produce events."""
+ raise NotImplementedError()
+
+ def _watch_forever(self, event_queue: 'queue.Queue[Event]'):
+ """Watches forever, handling exceptions so that the watcher keeps working (hopefully)."""
+ # resource_version '0' means getting a recent resource without
+ # consistency guarantee, this is to reduce the load of etcd.
+ # Ref: https://kubernetes.io/docs/reference/using-api/api-concepts/#the-resourceversion-parameter
+ resource_version = '0'
+ watcher = watch.Watch()
+ while True:
+ try:
+ logging.info(f'[K8s watcher] [{self.kind}] start watching, resource version: {resource_version}')
+ # Each round we re-watch to k8s
+ watch_stream = self._watch(watcher, resource_version)
+ emit_counter('k8s.watcher.watch', 1, tags={'kind': self.kind})
+ for (event, new_version) in watch_stream:
+ if new_version:
+ # Updates newest resource version, note that resource version is string,
+ # using max is a little hacky.
+ resource_version = max(resource_version, new_version)
+ logging.debug(f'[K8s watcher] [{self.kind}] new resource version: {new_version}')
+ event_queue.put(event)
+ except client.exceptions.ApiException as e:
+ logging.exception(f'[K8s watcher] [{self.kind}] API error')
+ if e.status == HTTPStatus.GONE:
+ # It has been too old, resources should be relisted
+ resource_version = '0'
+ # TODO(xiangyuxuan.prs): remove in the future.
+ elif e.status == HTTPStatus.NOT_FOUND:
+ logging.exception(f'[K8s watcher] [{self.kind}] unsupported')
+ break
+ except Exception as e: # pylint: disable=broad-except
+ logging.exception(f'[K8s watcher] [{self.kind}] unexpected error')
+ watcher.stop()
+
+ def run(self, event_queue: queue.Queue, retry_timeout_in_second: int):
+ """Starts the watcher.
+
+ Historically the watcher (process) may hang and never send requests to the k8s API server,
+ so as a workaround the watcher is restarted after a timeout.
+
+ Args:
+ event_queue: A queue through which the watcher passes events out.
+ retry_timeout_in_second: If no event is received within this threshold, the watcher gets restarted.
+ """
+ mp_context = multiprocessing.get_context('spawn')
+ internal_queue = mp_context.Queue()
+ process_name = f'k8s-watcher-{self.kind}'
+ process = mp_context.Process(name=process_name, daemon=True, target=self._watch_forever, args=(internal_queue,))
+ process.start()
+ logging.info(f'[K8s watcher] [{self.kind}] process started')
+ while True:
+ try:
+ # Waits for a new event with timeout, if it gets stuck, then we restart the watcher.
+ event = internal_queue.get(timeout=retry_timeout_in_second)
+ # Puts to outside
+ event_queue.put(event)
+ except queue.Empty:
+ logging.info(f'[K8s watcher] [{self.kind}] no event in queue, restarting...')
+ process.terminate()
+ process.join()
+ # TODO(wangsen.0914): add process.close() here after upgrade to python 3.8
+ internal_queue.close()
+ internal_queue = mp_context.Queue()
+ process = mp_context.Process(name=process_name,
+ daemon=True,
+ target=self._watch_forever,
+ args=(internal_queue,))
+ process.start()
+ logging.info(f'[K8s watcher] [{self.kind}] process restarted')
+
+
+class PodWatcher(AbstractWatcher):
+
+ @property
+ def kind(self) -> str:
+ return ObjectType.POD.name
+
+ def _watch(self, watcher: watch.Watch, resource_version: str) -> Generator[Tuple[Event, Optional[str]], None, None]:
+ stream = watcher.stream(
+ k8s_client.core.list_namespaced_pod,
+ namespace=Envs.K8S_NAMESPACE,
+ resource_version=resource_version,
+ # Sometimes watch gets stuck
+ _request_timeout=_REQUEST_TIMEOUT_IN_SECOND,
+ )
+ for event in stream:
+ metadata: V1ObjectMeta = event['object'].metadata
+ if get_app_name_from_metadata(metadata):
+ yield Event.from_json(event, ObjectType.POD), metadata.resource_version
+
+
+class CrdWatcher(AbstractWatcher):
+
+ def __init__(self, config: CrdWatcherConfig):
+ super().__init__()
+ self.config = config
+
+ @property
+ def kind(self) -> str:
+ return self.config.object_type.name
+
+ def _watch(self, watcher: watch.Watch, resource_version: str) -> Generator[Tuple[Event, Optional[str]], None, None]:
+ stream = watcher.stream(
+ k8s_client.crds.list_namespaced_custom_object,
+ group=self.config.group,
+ version=self.config.version,
+ namespace=Envs.K8S_NAMESPACE,
+ plural=self.config.plural,
+ resource_version=resource_version,
+ # Sometimes watch gets stuck
+ _request_timeout=_REQUEST_TIMEOUT_IN_SECOND,
+ )
+ for event in stream:
+ new_resource_version = event['object'].get('metadata', {}).get('resourceVersion', None)
+ yield Event.from_json(event, self.config.object_type), new_resource_version
+
+
+class K8sWatcher(object):
+
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._running = False
+ self._event_consumer_thread = None
+ self._event_listeners = [JobEventListener(), ModelEventListener()]
+
+ # https://stackoverflow.com/questions/62223424/simplequeue-vs-queue-in-python-what-is-the-advantage-of-using-simplequeue
+ # With SimpleQueue, the put operation never blocks.
+ # TODO(xiangyuxuan): change to simplequeue
+ self._queue = queue.Queue()
+ self._cache = {}
+ self._cache_lock = threading.Lock()
+
+ def start(self):
+ with self._lock:
+ if self._running:
+ logging.warning('K8s watcher has already started')
+ return
+ self._running = True
+
+ watchers = [PodWatcher()]
+ for _, crd_config in WATCHER_CONFIG.items():
+ watchers.append(CrdWatcher(config=crd_config))
+ watcher_threads = [
+ threading.Thread(
+ name=f'k8s-watcher-{watcher.kind}',
+ target=watcher.run,
+ args=(
+ self._queue,
+ # Keep consistent with k8s watcher event timeout
+ _REQUEST_TIMEOUT_IN_SECOND,
+ ),
+ daemon=True,
+ ) for watcher in watchers
+ ]
+
+ self._event_consumer_thread = threading.Thread(target=self._event_consumer,
+ name='cache_consumer',
+ daemon=True)
+ for wthread in watcher_threads:
+ wthread.start()
+ self._event_consumer_thread.start()
+ logging.info('K8s watcher started')
+
+ def _event_consumer(self):
+ # TODO(xiangyuxuan): do more business level operations
+ while True:
+ try:
+ event = self._queue.get()
+ k8s_cache.update_cache(event)
+ self._listen_crd_event(event)
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'K8s event_consumer: {str(e)}. traceback: {traceback.format_exc()}')
+
+ def _listen_crd_event(self, event: Event):
+ if event.obj_type == ObjectType.POD:
+ return
+ for listener in self._event_listeners:
+ try:
+ listener.update(event)
+ # pylint: disable=broad-except
+ except Exception as e:
+ emit_store('event_listener_update_error', 1)
+ logging.warning(f'[K8sWatcher] listener update with error {str(e)}')
+
+
+k8s_watcher = K8sWatcher()
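+
+# Illustrative startup wiring (an assumption, not part of this change): the application
+# would start this singleton once during initialization; the watcher threads then keep
+# k8s_cache up to date and notify the registered event listeners, e.g.:
+#
+#   from fedlearner_webconsole.k8s.k8s_watcher import k8s_watcher
+#   k8s_watcher.start()  # safe to call repeatedly; subsequent calls only log a warning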
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/k8s_watcher_test.py b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_watcher_test.py
new file mode 100644
index 000000000..6cd6b480d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/k8s_watcher_test.py
@@ -0,0 +1,291 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from http import HTTPStatus
+import multiprocessing
+import threading
+from typing import Callable, Generator, List, NamedTuple, Optional, Tuple
+import unittest
+from unittest.mock import MagicMock, Mock, patch
+from kubernetes import client, watch
+from kubernetes.client import V1Pod, V1ObjectMeta, V1OwnerReference
+
+from fedlearner_webconsole.k8s.k8s_watcher import AbstractWatcher, CrdWatcher, CrdWatcherConfig, PodWatcher
+from fedlearner_webconsole.k8s.k8s_cache import Event, EventType, ObjectType
+
+
+class Response(NamedTuple):
+ # gone/unknown/not_found/normal
+ type: str
+ events: Optional[List[Event]]
+
+
+# Why use a fake implementation instead of a mock?
+# Because mocks do not work across multiprocessing.
+class FakeWatcher(AbstractWatcher):
+
+ def __init__(self, resp_sequence: Optional[List[Response]]):
+ super().__init__()
+ self.calls = []
+ self._resp_sequence = resp_sequence or []
+ self._round = 0
+
+ @property
+ def kind(self) -> str:
+ return 'fake'
+
+ def _watch(self, watcher: watch.Watch, resource_version: str) -> Generator[Tuple[Event, Optional[str]], None, None]:
+ assert isinstance(watcher, watch.Watch)
+ self.calls.append(resource_version)
+ if self._round < len(self._resp_sequence):
+ resp = self._resp_sequence[self._round]
+ self._round += 1
+ if resp.type == 'gone':
+ raise client.exceptions.ApiException(status=HTTPStatus.GONE)
+ if resp.type == 'not_found':
+ raise client.exceptions.ApiException(status=HTTPStatus.NOT_FOUND)
+ if resp.type == 'unknown':
+ raise RuntimeError('fake unknown')
+ for i, event in enumerate(resp.events or []):
+ yield event, str(i)
+ return
+ # Dead loop
+ while True:
+ pass
+
+
+def _fake_event_from_json(event, object_type):
+ return event, object_type
+
+
+class AbstractWatcherTest(unittest.TestCase):
+
+ def _start_thread(self, func: Callable, args) -> threading.Thread:
+ t = threading.Thread(daemon=True, target=func, args=args)
+ t.start()
+ return t
+
+ def test_watch_normally(self):
+ events = [
+ Event(app_name='t1', event_type=EventType.ADDED, obj_type=ObjectType.POD, obj_dict={}),
+ Event(app_name='t2', event_type=EventType.MODIFIED, obj_type=ObjectType.POD, obj_dict={}),
+ Event(app_name='t3', event_type=EventType.MODIFIED, obj_type=ObjectType.POD, obj_dict={}),
+ ]
+
+ watcher = FakeWatcher(resp_sequence=[
+ Response(type='normal', events=events),
+ ])
+ q = multiprocessing.Queue()
+ self._start_thread(watcher.run, args=(
+ q,
+ 1000,
+ ))
+ actual_events = [q.get(), q.get(), q.get()]
+ q.close()
+
+ app_names = [e.app_name for e in actual_events]
+ self.assertEqual(app_names, ['t1', 't2', 't3'])
+
+ def test_watch_k8s_api_gone(self):
+ events = [
+ Event(app_name='t1', event_type=EventType.ADDED, obj_type=ObjectType.POD, obj_dict={}),
+ Event(app_name='t2', event_type=EventType.MODIFIED, obj_type=ObjectType.POD, obj_dict={}),
+ ]
+
+ watcher = FakeWatcher(resp_sequence=[
+ Response(type='gone', events=None),
+ Response(type='normal', events=events),
+ ])
+ q = multiprocessing.Queue()
+ self._start_thread(watcher.run, args=(
+ q,
+ 1000,
+ ))
+ actual_events = [q.get(), q.get()]
+ q.close()
+
+ app_names = [e.app_name for e in actual_events]
+ self.assertEqual(app_names, ['t1', 't2'])
+
+ def test_watch_unknown_k8s_error(self):
+ events1 = [
+ Event(app_name='t1', event_type=EventType.ADDED, obj_type=ObjectType.POD, obj_dict={}),
+ Event(app_name='t2', event_type=EventType.MODIFIED, obj_type=ObjectType.POD, obj_dict={}),
+ ]
+ events2 = [
+ Event(app_name='t3', event_type=EventType.ADDED, obj_type=ObjectType.POD, obj_dict={}),
+ ]
+
+ watcher = FakeWatcher(resp_sequence=[
+ Response(type='normal', events=events1),
+ Response(type='unknown', events=None),
+ Response(type='normal', events=events2),
+ ])
+ q = multiprocessing.Queue()
+ self._start_thread(watcher.run, args=(
+ q,
+ 1000,
+ ))
+ actual_events = [q.get(), q.get(), q.get()]
+ q.close()
+
+ app_names = [e.app_name for e in actual_events]
+ self.assertEqual(app_names, ['t1', 't2', 't3'])
+
+ def test_watch_client_hangs(self):
+ events = [
+ Event(app_name='t1', event_type=EventType.ADDED, obj_type=ObjectType.POD, obj_dict={}),
+ Event(app_name='t2', event_type=EventType.MODIFIED, obj_type=ObjectType.POD, obj_dict={}),
+ ]
+
+ watcher = FakeWatcher(resp_sequence=[
+ Response(type='normal', events=events),
+ ])
+ q = multiprocessing.Queue()
+ self._start_thread(watcher.run, args=(
+ q,
+ 15,
+ ))
+ # When no new event arrives within the timeout, the client re-watches, so the same events should be emitted again (at least 4 in total)
+ actual_events = [q.get(), q.get(), q.get(), q.get()]
+ q.close()
+
+ app_names = [e.app_name for e in actual_events]
+ self.assertEqual(app_names, ['t1', 't2', 't1', 't2'])
+
+
+class PodWatcherTest(unittest.TestCase):
+
+ def test_kind(self):
+ watcher = PodWatcher()
+ self.assertEqual(watcher.kind, 'POD')
+
+ @patch('fedlearner_webconsole.k8s.k8s_watcher.Envs.K8S_NAMESPACE', 'fedlearner')
+ @patch('fedlearner_webconsole.k8s.k8s_watcher.Event.from_json', _fake_event_from_json)
+ @patch('fedlearner_webconsole.k8s.k8s_watcher.k8s_client')
+ def test_watch(self, mock_k8s_client: Mock):
+ mock_k8s_client.return_value = MagicMock(core=MagicMock(list_namespaced_pod=MagicMock()))
+ events = [
+ {
+ 'object':
+ V1Pod(metadata=V1ObjectMeta(
+ resource_version='123',
+ owner_references=[
+ V1OwnerReference(
+ api_version='v1',
+ controller=True,
+ kind='Pod',
+ name='test-driver',
+ uid='812c1a48-5585-400f-9174-471d311fbec3',
+ )
+ ],
+ labels={
+ 'sparkoperator.k8s.io/app-name': 'spark-app-name',
+ },
+ )),
+ },
+ {
+ 'object':
+ V1Pod(metadata=V1ObjectMeta(
+ resource_version='234',
+ owner_references=[
+ V1OwnerReference(
+ api_version='v1',
+ controller=True,
+ kind='Pod',
+ name='test-driver',
+ uid='812c1a48-5585-400f-9174-471d311fbec3',
+ )
+ ],
+ labels={
+ 'sparkoperator.k8s.io/app-name': 'spark-app-name',
+ },
+ )),
+ },
+ ]
+ mock_stream = MagicMock(return_value=events)
+ mock_watcher_client = MagicMock(stream=mock_stream)
+
+ watcher = PodWatcher()
+ self.assertEqual(
+ list(watcher._watch(mock_watcher_client, '0')), # pylint: disable=protected-access
+ [
+ ((events[0], ObjectType.POD), '123'),
+ ((events[1], ObjectType.POD), '234'),
+ ])
+ mock_stream.assert_called_once_with(
+ mock_k8s_client.core.list_namespaced_pod,
+ namespace='fedlearner',
+ resource_version='0',
+ _request_timeout=900,
+ )
+
+
+class CrdWatcherTest(unittest.TestCase):
+ WATCHER_CONFIG = CrdWatcherConfig(
+ version='v1alpha1',
+ group='fedlearner.k8s.io',
+ plural='fedapps',
+ object_type=ObjectType.FEDAPP,
+ )
+
+ def test_kind(self):
+ watcher = CrdWatcher(self.WATCHER_CONFIG)
+ self.assertEqual(watcher.kind, 'FEDAPP')
+
+ @patch('fedlearner_webconsole.k8s.k8s_watcher.Envs.K8S_NAMESPACE', 'fedlearner')
+ @patch('fedlearner_webconsole.k8s.k8s_watcher.Event.from_json', _fake_event_from_json)
+ @patch('fedlearner_webconsole.k8s.k8s_watcher.k8s_client')
+ def test_watch(self, mock_k8s_client: Mock):
+ mock_k8s_client.return_value = MagicMock(crds=MagicMock(list_namespaced_custom_object=MagicMock()))
+ events = [
+ {
+ 'object': {
+ 'metadata': {
+ 'resourceVersion': '1111',
+ },
+ },
+ },
+ {
+ 'object': {
+ 'metadata': {
+ 'resourceVersion': '2222',
+ },
+ },
+ },
+ ]
+ mock_stream = MagicMock(return_value=events)
+ mock_watcher_client = MagicMock(stream=mock_stream)
+
+ watcher = CrdWatcher(self.WATCHER_CONFIG)
+ self.assertEqual(
+ list(watcher._watch(mock_watcher_client, '1000')), # pylint: disable=protected-access
+ [
+ ((events[0], ObjectType.FEDAPP), '1111'),
+ ((events[1], ObjectType.FEDAPP), '2222'),
+ ])
+ mock_stream.assert_called_once_with(
+ mock_k8s_client.crds.list_namespaced_custom_object,
+ group=self.WATCHER_CONFIG.group,
+ version=self.WATCHER_CONFIG.version,
+ namespace='fedlearner',
+ plural=self.WATCHER_CONFIG.plural,
+ resource_version='1000',
+ _request_timeout=900,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/models.py b/web_console_v2/api/fedlearner_webconsole/k8s/models.py
index 458f39f81..23969acff 100644
--- a/web_console_v2/api/fedlearner_webconsole/k8s/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/models.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
@@ -17,28 +17,17 @@
from abc import ABCMeta, abstractmethod
from datetime import datetime, timezone
from enum import Enum, unique
-from typing import Optional, List
+from typing import Optional, List, Dict, NamedTuple
+from google.protobuf.json_format import ParseDict
+from kubernetes.client import V1ObjectMeta
+from fedlearner_webconsole.proto.job_pb2 import PodPb
+from fedlearner_webconsole.proto.k8s_pb2 import Condition
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
-# Please keep the value consistent with operator's definition
-@unique
-class PodType(Enum):
- UNKNOWN = 'UNKNOWN'
- # Parameter server
- PS = 'PS'
- # Master worker
- MASTER = 'MASTER'
- WORKER = 'WORKER'
-
- @staticmethod
- def from_value(value: str) -> 'PodType':
- try:
- if isinstance(value, str):
- value = value.upper()
- return PodType(value)
- except ValueError:
- logging.error(f'Unexpected value of PodType: {value}')
- return PodType.UNKNOWN
+class PodMessage(NamedTuple):
+ summary: Optional[str]
+ details: str
@unique
@@ -64,16 +53,30 @@ def from_value(value: str) -> 'PodState':
return PodState.UNKNOWN
+class CrdKind(Enum):
+ FLAPP = 'FLApp'
+ SPARKAPPLICATION = 'SparkApplication'
+ FEDAPP = 'FedApp'
+ UNKNOWN = 'Unknown'
+
+ @staticmethod
+ def from_value(value: str) -> 'CrdKind':
+ try:
+ return CrdKind(value)
+ except ValueError:
+ return CrdKind.UNKNOWN
+
+
class MessageProvider(metaclass=ABCMeta):
+
@abstractmethod
def get_message(self, private: bool = False) -> Optional[str]:
pass
class ContainerState(MessageProvider):
- def __init__(self, state: str,
- message: Optional[str] = None,
- reason: Optional[str] = None):
+
+ def __init__(self, state: str, message: Optional[str] = None, reason: Optional[str] = None):
self.state = state
self.message = message
self.reason = reason
@@ -95,9 +98,8 @@ def __eq__(self, other):
class PodCondition(MessageProvider):
- def __init__(self, cond_type: str,
- message: Optional[str] = None,
- reason: Optional[str] = None):
+
+ def __init__(self, cond_type: str, message: Optional[str] = None, reason: Optional[str] = None):
self.cond_type = cond_type
self.message = message
self.reason = reason
@@ -119,19 +121,24 @@ def __eq__(self, other):
class Pod(object):
+
def __init__(self,
name: str,
state: PodState,
- pod_type: PodType,
+ pod_type: str = 'UNKNOWN',
pod_ip: str = None,
container_states: List[ContainerState] = None,
- pod_conditions: List[PodCondition] = None):
+ pod_conditions: List[PodCondition] = None,
+ creation_timestamp: int = None,
+ status_message: str = None):
self.name = name
self.state = state or PodState.UNKNOWN
self.pod_type = pod_type
self.pod_ip = pod_ip
self.container_states = container_states or []
self.pod_conditions = pod_conditions or []
+ self.creation_timestamp = creation_timestamp or 0
+ self.status_message = status_message or ''
def __eq__(self, other):
if not isinstance(other, Pod):
@@ -149,27 +156,32 @@ def __eq__(self, other):
return self.name == other.name and \
self.state == other.state and \
self.pod_type == other.pod_type and \
- self.pod_ip == other.pod_ip
+ self.pod_ip == other.pod_ip and \
+ self.creation_timestamp == other.creation_timestamp
- def to_dict(self, include_private_info: bool = False):
- # TODO: to reuse to_dict from db.py
- messages = []
+ def to_proto(self, include_private_info: bool = False) -> PodPb:
+
+ return PodPb(name=self.name,
+ pod_type=self.pod_type,
+ state=self.state.name,
+ pod_ip=self.pod_ip,
+ creation_timestamp=self.creation_timestamp,
+ message=self.get_message(include_private_info).details)
+
+ def get_message(self, include_private_info: bool = False) -> PodMessage:
+ summary = None
+ messages = [self.status_message] if self.status_message else []
for container_state in self.container_states:
message = container_state.get_message(include_private_info)
if message is not None:
messages.append(message)
+ if container_state.state == 'terminated':
+ summary = message
for pod_condition in self.pod_conditions:
message = pod_condition.get_message(include_private_info)
if message is not None:
messages.append(message)
-
- return {
- 'name': self.name,
- 'pod_type': self.pod_type.name,
- 'state': self.state.name,
- 'pod_ip': self.pod_ip,
- 'message': ', '.join(messages)
- }
+ return PodMessage(summary=summary, details=', '.join(messages))
@classmethod
def from_json(cls, p: dict) -> 'Pod':
@@ -179,32 +191,83 @@ def from_json(cls, p: dict) -> 'Pod':
master/v1.6.5-standalone/pod.json"""
container_states: List[ContainerState] = []
pod_conditions: List[PodCondition] = []
- if 'containerStatuses' in p['status'] and \
- isinstance(p['status']['containerStatuses'], list) and \
- len(p['status']['containerStatuses']) > 0:
+ if 'container_statuses' in p['status'] and \
+ isinstance(p['status']['container_statuses'], list) and \
+ len(p['status']['container_statuses']) > 0:
for state, detail in \
- p['status']['containerStatuses'][0]['state'].items():
- container_states.append(ContainerState(
- state=state,
- message=detail.get('message'),
- reason=detail.get('reason')
- ))
+ p['status']['container_statuses'][0]['state'].items():
+ # detail may be None, so use a short-circuit 'and' to guard the attribute access
+ container_states.append(
+ ContainerState(state=state,
+ message=detail and detail.get('message'),
+ reason=detail and detail.get('reason')))
if 'conditions' in p['status'] and \
isinstance(p['status']['conditions'], list):
for cond in p['status']['conditions']:
- pod_conditions.append(PodCondition(
- cond_type=cond['type'],
- message=cond.get('message'),
- reason=cond.get('reason')
- ))
- return cls(
- name=p['metadata']['name'],
- pod_type=PodType.from_value(
- p['metadata']['labels']['fl-replica-type']),
- state=PodState.from_value(p['status']['phase']),
- pod_ip=p['status'].get('pod_ip'),
- container_states=container_states,
- pod_conditions=pod_conditions)
+ pod_conditions.append(
+ PodCondition(cond_type=cond['type'], message=cond.get('message'), reason=cond.get('reason')))
+
+ return cls(name=p['metadata']['name'],
+ pod_type=get_pod_type(p),
+ state=PodState.from_value(p['status']['phase']),
+ pod_ip=p['status'].get('pod_ip'),
+ container_states=container_states,
+ pod_conditions=pod_conditions,
+ creation_timestamp=to_timestamp(p['metadata']['creation_timestamp']),
+ status_message=p['status'].get('message'))
+
+
+def get_pod_type(pod: dict) -> str:
+ labels = pod['metadata']['labels']
+ # SparkApplication -> pod.metadata.labels.spark-role
+ # FlApp -> pod.metadata.labels.fl-replica-type
+ pod_type = labels.get('fl-replica-type', None) or labels.get('spark-role', 'UNKNOWN')
+ return pod_type.upper()
+
+
+def get_creation_timestamp_from_k8s_app(app: dict) -> int:
+ if 'metadata' in app and 'creationTimestamp' in app['metadata']:
+ return to_timestamp(app['metadata']['creationTimestamp'])
+ return 0
+
+
+class K8sApp(metaclass=ABCMeta):
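+ # Abstract read-only view over a CR and its pods; implemented by FlApp, SparkApp, FedApp and UnknownCrd below.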
+
+ @classmethod
+ @abstractmethod
+ def from_json(cls, app_detail: dict):
+ pass
+
+ @property
+ @abstractmethod
+ def is_completed(self) -> bool:
+ pass
+
+ @property
+ @abstractmethod
+ def is_failed(self) -> bool:
+ pass
+
+ @property
+ @abstractmethod
+ def completed_at(self) -> int:
+ pass
+
+ @property
+ @abstractmethod
+ def pods(self) -> List[Pod]:
+ pass
+
+ @property
+ @abstractmethod
+ def error_message(self) -> Optional[str]:
+ pass
+
+ @property
+ @abstractmethod
+ def creation_timestamp(self) -> int:
+ pass
# Please keep the value consistent with operator's definition
@@ -229,14 +292,19 @@ def from_value(value: str) -> 'FlAppState':
return FlAppState.UNKNOWN
-class FlApp(object):
+class FlApp(K8sApp):
+
def __init__(self,
state: FlAppState = FlAppState.UNKNOWN,
pods: Optional[List[Pod]] = None,
- completed_at: Optional[int] = None):
+ completed_at: Optional[int] = None,
+ creation_timestamp: Optional[int] = None):
self.state = state
- self.pods = pods or []
- self.completed_at = completed_at
+ self._pods = pods or []
+ self._completed_at = completed_at
+ self._is_failed = self.state == FlAppState.FAILED
+ self._is_completed = self.state == FlAppState.COMPLETED
+ self._creation_timestamp = creation_timestamp
def __eq__(self, other):
if not isinstance(other, FlApp):
@@ -250,7 +318,8 @@ def __eq__(self, other):
self.completed_at == other.completed_at
@classmethod
- def from_json(cls, flapp: dict) -> 'FlApp':
+ def from_json(cls, app_detail: dict) -> 'FlApp':
+ flapp = app_detail.get('app', None)
if flapp is None \
or 'status' not in flapp \
or not isinstance(flapp['status'], dict):
@@ -261,24 +330,277 @@ def from_json(cls, flapp: dict) -> 'FlApp':
# Parses pod related info
replicas = flapp['status'].get('flReplicaStatus', {})
for pod_type in replicas:
- for state in ['failed', 'succeeded']:
+ for state in ['active', 'failed', 'succeeded']:
for pod_name in replicas[pod_type].get(state, {}):
+ if state == 'active':
+ pod_state = PodState.RUNNING
if state == 'failed':
pod_state = PodState.FAILED_AND_FREED
- else:
+ if state == 'succeeded':
pod_state = PodState.SUCCEEDED_AND_FREED
- pods.append(Pod(
- name=pod_name,
- pod_type=PodType.from_value(pod_type),
- state=pod_state))
+ pods.append(Pod(name=pod_name, pod_type=pod_type.upper(), state=pod_state))
state = flapp['status'].get('appState')
if flapp['status'].get('completionTime', None):
# Completion time is a iso formatted datetime in UTC timezone
- completed_at = int(datetime.strptime(
- flapp['status']['completionTime'],
- '%Y-%m-%dT%H:%M:%SZ').replace(tzinfo=timezone.utc)
- .timestamp())
-
+ completed_at = int(
+ datetime.strptime(flapp['status']['completionTime'],
+ '%Y-%m-%dT%H:%M:%SZ').replace(tzinfo=timezone.utc).timestamp())
+
+ name_to_pod = get_pod_dict_from_detail(app_detail.get('pods', {}))
+ for pod in pods:
+ # Only the master and PS pods take their state from the flapp status,
+ # because they do not exit immediately when the flapp is deleted.
+ if pod.name not in name_to_pod:
+ name_to_pod[pod.name] = pod
+ elif pod.pod_type in ['MASTER', 'PS']:
+ name_to_pod[pod.name].state = pod.state
+
+ pods = list(name_to_pod.values())
return cls(state=FlAppState.from_value(state),
pods=pods,
- completed_at=completed_at)
+ completed_at=completed_at,
+ creation_timestamp=get_creation_timestamp_from_k8s_app(flapp))
+
+ @property
+ def is_completed(self) -> bool:
+ return self._is_completed
+
+ @property
+ def is_failed(self) -> bool:
+ return self._is_failed
+
+ @property
+ def completed_at(self) -> int:
+ return self._completed_at or 0
+
+ @property
+ def pods(self) -> List[Pod]:
+ return self._pods
+
+ @property
+ def error_message(self) -> Optional[str]:
+ return None
+
+ @property
+ def creation_timestamp(self) -> int:
+ return self._creation_timestamp or 0
+
+
+@unique
+class SparkAppState(Enum):
+ # state: https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/ \
+ # blob/075e5383e4678ddd70d7f3fdd71904aa3c9113c2 \
+ # /pkg/apis/sparkoperator.k8s.io/v1beta2/types.go#L332
+
+ # core state transition: SUBMITTED -> RUNNING -> COMPLETED/FAILED
+ NEW = ''
+ SUBMITTED = 'SUBMITTED'
+ RUNNING = 'RUNNING'
+ COMPLETED = 'COMPLETED'
+ FAILED = 'FAILED'
+ SUBMISSION_FAILED = 'SUBMISSION_FAILED'
+ PENDING_RERUN = 'PENDING_RERUN'
+ INVALIDATING = 'INVALIDATING'
+ SUCCEEDING = 'SUCCEEDING'
+ FAILING = 'FAILING'
+ UNKNOWN = 'UNKNOWN'
+
+ @staticmethod
+ def from_value(value: str) -> 'SparkAppState':
+ try:
+ return SparkAppState(value)
+ except ValueError:
+ logging.error(f'Unexpected value of SparkAppState: {value}')
+ return SparkAppState.UNKNOWN
+
+
+class SparkApp(K8sApp):
+
+ def __init__(self,
+ pods: List[Pod],
+ state: SparkAppState = SparkAppState.UNKNOWN,
+ completed_at: Optional[int] = None,
+ err_message: Optional[str] = None,
+ creation_timestamp: Optional[int] = None):
+ self.state = state
+ self._completed_at = completed_at
+ self._is_failed = self.state in [SparkAppState.FAILED]
+ self._is_completed = self.state in [SparkAppState.COMPLETED]
+ self._pods = pods
+ self._error_message = err_message
+ self._creation_timestamp = creation_timestamp
+
+ def __eq__(self, other):
+ if not isinstance(other, SparkApp):
+ return False
+ return self.state == other.state and \
+ self.completed_at == other.completed_at
+
+ @classmethod
+ def from_json(cls, app_detail: dict) -> 'SparkApp':
+ sparkapp = app_detail.get('app', None)
+ if sparkapp is None \
+ or 'status' not in sparkapp \
+ or not isinstance(sparkapp['status'], dict):
+ return cls(pods=[])
+
+ status = sparkapp['status']
+ application_state = status.get('applicationState', {})
+ state = application_state.get('state', SparkAppState.UNKNOWN)
+ completed_at: Optional[int] = None
+ termination_time = status.get('terminationTime', None)
+ if termination_time:
+ # Completion time is an ISO-formatted datetime in UTC timezone
+ completed_at = int(
+ datetime.strptime(termination_time, '%Y-%m-%dT%H:%M:%SZ').replace(tzinfo=timezone.utc).timestamp())
+ pods = list(get_pod_dict_from_detail(app_detail.get('pods', {})).values())
+ err_message = application_state.get('errorMessage', None)
+ return cls(state=SparkAppState.from_value(state),
+ completed_at=completed_at,
+ pods=pods,
+ err_message=err_message,
+ creation_timestamp=get_creation_timestamp_from_k8s_app(sparkapp))
+
+ @property
+ def is_completed(self) -> bool:
+ return self._is_completed
+
+ @property
+ def is_failed(self) -> bool:
+ return self._is_failed
+
+ @property
+ def completed_at(self) -> int:
+ return self._completed_at or 0
+
+ @property
+ def pods(self) -> List[Pod]:
+ return self._pods
+
+ @property
+ def error_message(self) -> Optional[str]:
+ return self._error_message
+
+ @property
+ def creation_timestamp(self) -> int:
+ return self._creation_timestamp or 0
+
+
+class FedApp(K8sApp):
+
+ def __init__(self, pods: List[Pod], success_condition: Condition, creation_timestamp: Optional[int] = None):
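+ # The 'succeeded' condition drives the app status: TRUE means completed, FALSE means failed,
+ # and its lastTransitionTime is used as the completion timestamp.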
+ self.success_condition = success_condition
+ self._pods = pods
+ self._completed_at = self.success_condition.last_transition_time and to_timestamp(
+ self.success_condition.last_transition_time)
+ self._is_failed = self.success_condition.status == Condition.FALSE
+ self._is_completed = self.success_condition.status == Condition.TRUE
+ self._creation_timestamp = creation_timestamp
+
+ @classmethod
+ def from_json(cls, app_detail: dict) -> 'FedApp':
+ app = app_detail.get('app', None)
+ if app is None \
+ or 'status' not in app \
+ or not isinstance(app['status'], dict):
+ return cls([], Condition())
+
+ status = app['status']
+ success_condition = Condition()
+ for c in status.get('conditions', []):
+ c_proto: Condition = ParseDict(c, Condition())
+ if c_proto.type == Condition.SUCCEEDED:
+ success_condition = c_proto
+ pods = list(get_pod_dict_from_detail(app_detail.get('pods', {})).values())
+ return cls(success_condition=success_condition,
+ pods=pods,
+ creation_timestamp=get_creation_timestamp_from_k8s_app(app))
+
+ @property
+ def is_completed(self) -> bool:
+ return self._is_completed
+
+ @property
+ def is_failed(self) -> bool:
+ return self._is_failed
+
+ @property
+ def completed_at(self) -> int:
+ return self._completed_at or 0
+
+ @property
+ def pods(self) -> List[Pod]:
+ return self._pods
+
+ @property
+ def error_message(self) -> str:
+ return f'{self.success_condition.reason}: {self.success_condition.message}'
+
+ @property
+ def creation_timestamp(self) -> int:
+ return self._creation_timestamp or 0
+
+
+class UnknownCrd(K8sApp):
+
+ @classmethod
+ def from_json(cls, app_detail: dict) -> 'UnknownCrd':
+ return UnknownCrd()
+
+ @property
+ def is_completed(self) -> bool:
+ return False
+
+ @property
+ def is_failed(self) -> bool:
+ return False
+
+ @property
+ def completed_at(self) -> int:
+ return 0
+
+ @property
+ def pods(self) -> List[Pod]:
+ return []
+
+ @property
+ def error_message(self) -> Optional[str]:
+ return None
+
+ @property
+ def creation_timestamp(self) -> int:
+ return 0
+
+
+def get_pod_dict_from_detail(pod_detail: dict) -> Dict[str, Pod]:
+ """
+ Generates a name-to-Pod dict from the pod JSON detail fetched from the pod cache.
+ """
+ name_to_pod = {}
+ pods_json = pod_detail.get('items', [])
+ for p in pods_json:
+ pod = Pod.from_json(p)
+ name_to_pod[pod.name] = pod
+ return name_to_pod
+
+
+def get_app_name_from_metadata(metadata: V1ObjectMeta) -> Optional[str]:
+ """Extracts the CR app name from the metadata.
+
+ The metadata comes from a k8s watch event; we only care about events
+ related to CRs, so the owner references are checked."""
+ owner_refs = metadata.owner_references or []
+ if not owner_refs:
+ return None
+
+ # Spark app uses labels to get app name instead of owner references,
+ # because executors' owner reference will be driver, not the spark app.
+ labels = metadata.labels or {}
+ sparkapp_name = labels.get('sparkoperator.k8s.io/app-name', None)
+ if sparkapp_name:
+ return sparkapp_name
+
+ owner = owner_refs[0]
+ if CrdKind.from_value(owner.kind) == CrdKind.UNKNOWN:
+ return None
+ return owner.name
diff --git a/web_console_v2/api/fedlearner_webconsole/k8s/models_test.py b/web_console_v2/api/fedlearner_webconsole/k8s/models_test.py
new file mode 100644
index 000000000..85e98f12e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/k8s/models_test.py
@@ -0,0 +1,374 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from datetime import datetime, timezone
+
+from kubernetes.client import V1ObjectMeta, V1OwnerReference
+
+from fedlearner_webconsole.k8s.models import PodState, ContainerState, \
+ PodCondition, Pod, FlAppState, FlApp, SparkApp, SparkAppState, FedApp, get_app_name_from_metadata, PodMessage
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.proto.job_pb2 import PodPb
+
+
+class PodStateTest(unittest.TestCase):
+
+ def test_from_string(self):
+ self.assertEqual(PodState.from_value('Running'), PodState.RUNNING)
+ self.assertEqual(PodState.from_value('Unknown'), PodState.UNKNOWN)
+
+ def test_from_unknown(self):
+ self.assertEqual(PodState.from_value('hhhhhhhh'), PodState.UNKNOWN)
+
+
+class ContainerStateTest(unittest.TestCase):
+
+ def test_get_message(self):
+ state = ContainerState(state='haha', message='test message', reason='test reason')
+ self.assertEqual(state.get_message(), 'haha:test reason')
+ self.assertEqual(state.get_message(private=True), 'haha:test message')
+ state.message = None
+ self.assertEqual(state.get_message(), 'haha:test reason')
+ self.assertEqual(state.get_message(private=True), 'haha:test reason')
+
+
+class PodConditionTest(unittest.TestCase):
+
+ def test_get_message(self):
+ cond = PodCondition(cond_type='t1', message='test message', reason='test reason')
+ self.assertEqual(cond.get_message(), 't1:test reason')
+ self.assertEqual(cond.get_message(private=True), 't1:test message')
+ cond.message = None
+ self.assertEqual(cond.get_message(), 't1:test reason')
+ self.assertEqual(cond.get_message(private=True), 't1:test reason')
+
+
+class PodTest(unittest.TestCase):
+
+ def test_to_proto(self):
+ pod = Pod(name='this-is-a-pod',
+ state=PodState.RUNNING,
+ pod_type='WORKER',
+ pod_ip='172.10.0.20',
+ container_states=[ContainerState(state='h1', message='test message')],
+ pod_conditions=[PodCondition(cond_type='h2', reason='test reason')],
+ creation_timestamp=100)
+ self.assertEqual(
+ pod.to_proto(include_private_info=True),
+ PodPb(
+ name='this-is-a-pod',
+ pod_type='WORKER',
+ state='RUNNING',
+ pod_ip='172.10.0.20',
+ message='h1:test message, h2:test reason',
+ creation_timestamp=100,
+ ))
+
+ def test_from_json(self):
+ creation_timestamp = datetime.utcnow()
+ json = {
+ 'metadata': {
+ 'name': 'test-pod',
+ 'labels': {
+ 'app-name': 'u244777dac51949c5b2b-data-join-job',
+ 'fl-replica-type': 'master'
+ },
+ 'creation_timestamp': creation_timestamp,
+ },
+ 'status': {
+ 'pod_ip':
+ '172.10.0.20',
+ 'message':
+ 'test',
+ 'phase':
+ 'Running',
+ 'conditions': [{
+ 'type': 'Failed',
+ 'reason': 'Test reason'
+ }],
+ 'container_statuses': [{
+ 'containerID':
+ 'docker://034eaf58d4e24581232832661636da9949b6e2fb05\
+ 6398939fc2c0f2809d4c64',
+ 'image':
+ 'artifact.bytedance.com/fedlearner/fedlearner:438d603',
+ 'state': {
+ 'running': {
+ 'message': 'Test message'
+ }
+ }
+ }]
+ },
+ 'spec': {
+ 'containers': [{
+ 'name': 'test-container',
+ 'resources': {
+ 'limits': {
+ 'cpu': '2000m',
+ 'memory': '4Gi',
+ },
+ 'requests': {
+ 'cpu': '2000m',
+ 'memory': '4Gi',
+ }
+ }
+ }]
+ }
+ }
+ expected_pod = Pod(name='test-pod',
+ state=PodState.RUNNING,
+ pod_type='MASTER',
+ pod_ip='172.10.0.20',
+ container_states=[ContainerState(state='running', message='Test message')],
+ pod_conditions=[PodCondition(cond_type='Failed', reason='Test reason')],
+ creation_timestamp=to_timestamp(creation_timestamp),
+ status_message='test')
+ self.assertEqual(Pod.from_json(json), expected_pod)
+
+ def test_get_message(self):
+ pod = Pod(name='test',
+ state=PodState.FAILED,
+ container_states=[
+ ContainerState(state='terminated', message='0101010'),
+ ContainerState(state='running', message='11')
+ ])
+ self.assertEqual(pod.get_message(True),
+ PodMessage(summary='terminated:0101010', details='terminated:0101010, running:11'))
+ self.assertEqual(pod.get_message(False), PodMessage(summary=None, details=''))
+ pod.container_states = [ContainerState(state='terminated')]
+ self.assertEqual(pod.get_message(True), PodMessage(summary=None, details=''))
+
+
+class FlAppStateTest(unittest.TestCase):
+
+ def test_from_string(self):
+ self.assertEqual(FlAppState.from_value('FLStateComplete'), FlAppState.COMPLETED)
+ self.assertEqual(FlAppState.from_value('Unknown'), FlAppState.UNKNOWN)
+
+ def test_from_unknown(self):
+ self.assertEqual(FlAppState.from_value('hhh123hhh'), FlAppState.UNKNOWN)
+
+
+class FlAppTest(unittest.TestCase):
+
+ def test_from_json(self):
+ json = {
+ 'app': {
+ 'metadata': {
+ 'creationTimestamp': '2022-09-27T09:07:01Z',
+ },
+ 'status': {
+ 'appState': 'FLStateComplete',
+ 'completionTime': '2021-04-26T08:33:45Z',
+ 'flReplicaStatus': {
+ 'Master': {
+ 'active': {
+ 'test-pod1': {}
+ },
+ 'failed': {
+ 'test-pod2': {}
+ },
+ },
+ 'Worker': {
+ 'succeeded': {
+ 'test-pod3': {},
+ 'test-pod4': {}
+ }
+ }
+ }
+ }
+ }
+ }
+ completed_at = int(datetime(2021, 4, 26, 8, 33, 45, tzinfo=timezone.utc).timestamp())
+ expected_flapp = FlApp(state=FlAppState.COMPLETED,
+ completed_at=completed_at,
+ pods=[
+ Pod(name='test-pod1', state=PodState.RUNNING, pod_type='MASTER'),
+ Pod(name='test-pod2', state=PodState.FAILED_AND_FREED, pod_type='MASTER'),
+ Pod(name='test-pod3', state=PodState.SUCCEEDED_AND_FREED, pod_type='WORKER'),
+ Pod(name='test-pod4', state=PodState.SUCCEEDED_AND_FREED, pod_type='WORKER')
+ ])
+ actual_flapp = FlApp.from_json(json)
+ self.assertEqual(actual_flapp, expected_flapp)
+ self.assertEqual(actual_flapp.is_completed, True)
+ self.assertEqual(actual_flapp.completed_at, 1619426025)
+ self.assertEqual(actual_flapp.creation_timestamp, 1664269621)
+
+
+class SparkAppTest(unittest.TestCase):
+
+ def test_from_json(self):
+ json = {
+ 'app': {
+ 'metadata': {
+ 'creationTimestamp': '2022-09-27T09:07:01Z',
+ },
+ 'status': {
+ 'applicationState': {
+ 'state': 'COMPLETED',
+ 'errorMessage': 'OOMKilled'
+ },
+ 'driverInfo': {
+ 'podName': 'fl-transformer-yaml-driver',
+ },
+ 'executionAttempts': 1,
+ 'executorState': {
+ 'fedlearnertransformer-4a859f78d5210f41-exec-1': 'RUNNING'
+ },
+ 'lastSubmissionAttemptTime': '2021-04-15T10:43:28Z',
+ 'sparkApplicationId': 'spark-adade63e9071431881d6a16666ec1c87',
+ 'submissionAttempts': 1,
+ 'submissionID': '37a07c69-516b-48fe-ae70-701eec529eda',
+ 'terminationTime': '2021-04-15T10:43:53Z'
+ }
+ }
+ }
+
+ completed_at = int(datetime(2021, 4, 15, 10, 43, 53, tzinfo=timezone.utc).timestamp())
+ expected_sparkapp = SparkApp(state=SparkAppState.COMPLETED, completed_at=completed_at, pods=[])
+ actual_sparkapp = SparkApp.from_json(json)
+ self.assertEqual(actual_sparkapp, expected_sparkapp)
+ self.assertEqual(actual_sparkapp.is_completed, True)
+ self.assertEqual(actual_sparkapp.is_failed, False)
+ self.assertEqual(actual_sparkapp.completed_at, 1618483433)
+ self.assertEqual(actual_sparkapp.error_message, 'OOMKilled')
+ self.assertEqual(actual_sparkapp.creation_timestamp, 1664269621)
+
+
+class FedAppTest(unittest.TestCase):
+
+ def test_from_json(self):
+ json = {
+ 'app': {
+ 'metadata': {
+ 'creationTimestamp': '2022-09-27T09:07:01Z',
+ },
+ 'status': {
+ 'conditions': [{
+ 'type': 'succeeded',
+ 'status': 'False',
+ 'lastTransitionTime': '2022-01-17T12:06:33Z',
+ 'reason': 'OutOfLimitation',
+ 'message': 'detail'
+ }]
+ }
+ }
+ }
+ fed_app = FedApp.from_json(json)
+ self.assertEqual(fed_app.is_failed, True)
+ self.assertEqual(fed_app.is_completed, False)
+ self.assertEqual(fed_app.completed_at, 1642421193)
+ self.assertEqual(fed_app.error_message, 'OutOfLimitation: detail')
+ self.assertEqual(fed_app.creation_timestamp, 1664269621)
+ json = {'app': {'status': {'conditions': []}}}
+ fed_app = FedApp.from_json(json)
+ self.assertEqual(fed_app.is_failed, False)
+ self.assertEqual(fed_app.is_completed, False)
+ self.assertEqual(fed_app.completed_at, 0)
+ self.assertEqual(fed_app.creation_timestamp, 0)
+
+
+class GetAppNameFromMetadataTest(unittest.TestCase):
+
+ def test_pure_pod(self):
+ metadata = V1ObjectMeta(
+ name='test-pod',
+ namespace='fedlearner',
+ )
+ self.assertIsNone(get_app_name_from_metadata(metadata))
+
+ def test_sparkapp(self):
+ metadata = V1ObjectMeta(
+ name='test-driver',
+ namespace='fedlearner',
+ owner_references=[
+ V1OwnerReference(
+ api_version='sparkoperator.k8s.io/v1beta2',
+ controller=True,
+ kind='SparkApplication',
+ name='spark-app-name',
+ uid='812c1a48-5585-400f-9174-471d311fbec3',
+ )
+ ],
+ labels={
+ 'sparkoperator.k8s.io/app-name': 'spark-app-name',
+ 'sparkoperator.k8s.io/launched-by-spark-operator': 'true',
+ },
+ )
+ self.assertEqual(get_app_name_from_metadata(metadata), 'spark-app-name')
+ metadata = V1ObjectMeta(
+ name='test-executor',
+ namespace='fedlearner',
+ owner_references=[
+ V1OwnerReference(
+ api_version='v1',
+ controller=True,
+ kind='Pod',
+ name='test-driver',
+ uid='812c1a48-5585-400f-9174-471d311fbec3',
+ )
+ ],
+ labels={
+ 'sparkoperator.k8s.io/app-name': 'spark-app-name',
+ },
+ )
+ self.assertEqual(get_app_name_from_metadata(metadata), 'spark-app-name')
+
+ def test_fedapp(self):
+ metadata = V1ObjectMeta(name='test-pod',
+ namespace='default',
+ owner_references=[
+ V1OwnerReference(
+ api_version='fedlearner.k8s.io/v1alpha1',
+ controller=True,
+ kind='FedApp',
+ name='test-fedapp-job',
+ uid='bcf5324c-aa2b-4918-bdee-42ac464e18d5',
+ )
+ ])
+ self.assertEqual(get_app_name_from_metadata(metadata), 'test-fedapp-job')
+
+ def test_flapp(self):
+ metadata = V1ObjectMeta(name='test-pod',
+ namespace='fedlearner',
+ owner_references=[
+ V1OwnerReference(
+ api_version='fedlearner.k8s.io/v1alpha1',
+ controller=True,
+ kind='FLApp',
+ name='u130eaab6eec64552945-nn-model',
+ uid='bcf5324c-aa2b-4918-bdee-42ac464e18d5',
+ )
+ ])
+ self.assertEqual(get_app_name_from_metadata(metadata), 'u130eaab6eec64552945-nn-model')
+
+ def test_unknown_metadata(self):
+ metadata = V1ObjectMeta(name='test-pod',
+ namespace='fedlearner',
+ owner_references=[
+ V1OwnerReference(
+ api_version='fedlearner.k8s.io/v1alpha1',
+ controller=True,
+ kind='NewApp',
+ name='u130eaab6eec64552945-nn-model',
+ uid='bcf5324c-aa2b-4918-bdee-42ac464e18d5',
+ )
+ ])
+ self.assertIsNone(get_app_name_from_metadata(metadata))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/middleware/BUILD.bazel
new file mode 100644
index 000000000..9f51a2ad8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/BUILD.bazel
@@ -0,0 +1,80 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "api_latency_lib",
+ srcs = ["api_latency.py"],
+ imports = ["../.."],
+ deps = [
+ "@common_flask//:pkg",
+ "@common_opentelemetry_instrumentation_flask//:pkg",
+ ],
+)
+
+py_test(
+ name = "api_latency_test",
+ srcs = [
+ "api_latency_test.py",
+ ],
+ imports = ["../.."],
+ main = "api_latency_test.py",
+ deps = [
+ ":api_latency_lib",
+ "@common_flask//:pkg",
+ "@common_flask_testing//:pkg",
+ "@common_opentelemetry_sdk//:pkg",
+ ],
+)
+
+py_library(
+ name = "log_filter_lib",
+ srcs = ["log_filter.py"],
+ imports = ["../.."],
+ deps = [
+ ":middlewares_lib",
+ ":request_id_lib",
+ ],
+)
+
+py_test(
+ name = "log_filter_test",
+ srcs = [
+ "log_filter_test.py",
+ ],
+ imports = ["../.."],
+ main = "log_filter_test.py",
+ deps = [
+ ":log_filter_lib",
+ ],
+)
+
+py_library(
+ name = "middlewares_lib",
+ srcs = ["middlewares.py"],
+ imports = ["../.."],
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "request_id_lib",
+ srcs = ["request_id.py"],
+ imports = ["../.."],
+ deps = [
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@common_flask//:pkg",
+ ],
+)
+
+py_test(
+ name = "request_id_test",
+ srcs = [
+ "request_id_test.py",
+ ],
+ imports = ["../.."],
+ main = "request_id_test.py",
+ deps = [
+ ":request_id_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/api_latency.py b/web_console_v2/api/fedlearner_webconsole/middleware/api_latency.py
new file mode 100644
index 000000000..96ac176b7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/api_latency.py
@@ -0,0 +1,22 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from flask import Flask
+from opentelemetry.instrumentation.flask import FlaskInstrumentor
+
+
+def api_latency_middleware(app: Flask) -> Flask:
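+ # Instruments the Flask app with OpenTelemetry so every incoming request is recorded as a server span (API latency included).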
+ FlaskInstrumentor().instrument_app(app)
+ return app
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/api_latency_test.py b/web_console_v2/api/fedlearner_webconsole/middleware/api_latency_test.py
new file mode 100644
index 000000000..dfaf806cd
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/api_latency_test.py
@@ -0,0 +1,78 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import unittest
+import flask_testing
+from flask import Flask
+from io import StringIO
+from opentelemetry import trace
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor, ConsoleSpanExporter
+
+from fedlearner_webconsole.middleware.api_latency import api_latency_middleware
+
+
+class ApiLatencyTest(flask_testing.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._string_io = StringIO()
+ trace.set_tracer_provider(
+ TracerProvider(resource=Resource.create({'service.name': 'test_api_latency'}),
+ active_span_processor=SimpleSpanProcessor(ConsoleSpanExporter(out=self._string_io))))
+
+ def create_app(self):
+ app = Flask('test_api_latency')
+
+ @app.route('/test', methods=['GET'])
+ def test():
+ return {'data': 'Hello'}
+
+ app = api_latency_middleware(app)
+ return app
+
+ def test_api_latency(self):
+ get_response = self.client.get('/test')
+ self.assertEqual(get_response.json, {'data': 'Hello'})
+ span = json.loads(self._string_io.getvalue())
+ self.assertEqual(span['name'], '/test')
+ self.assertEqual(span['kind'], 'SpanKind.SERVER')
+ self.assertEqual(
+ span['attributes'], {
+ 'http.method': 'GET',
+ 'http.server_name': 'localhost',
+ 'http.scheme': 'http',
+ 'net.host.port': 80,
+ 'http.host': 'localhost',
+ 'http.target': '/test',
+ 'net.peer.ip': '127.0.0.1',
+ 'http.user_agent': 'werkzeug/1.0.1',
+ 'http.flavor': '1.1',
+ 'http.route': '/test',
+ 'http.status_code': 200
+ })
+ self.assertEqual(
+ span['resource'], {
+ 'telemetry.sdk.language': 'python',
+ 'telemetry.sdk.name': 'opentelemetry',
+ 'telemetry.sdk.version': '1.10.0',
+ 'service.name': 'test_api_latency'
+ })
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/log_filter.py b/web_console_v2/api/fedlearner_webconsole/middleware/log_filter.py
new file mode 100644
index 000000000..82c234728
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/log_filter.py
@@ -0,0 +1,27 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+
+from fedlearner_webconsole.middleware.request_id import get_current_request_id
+
+
+class RequestIdLogFilter(logging.Filter):
+ """Log filter to inject the current request id.
+ """
+
+ def filter(self, record) -> bool:
+ record.request_id = get_current_request_id()
+ return True
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/log_filter_test.py b/web_console_v2/api/fedlearner_webconsole/middleware/log_filter_test.py
new file mode 100644
index 000000000..be3ccf0d8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/log_filter_test.py
@@ -0,0 +1,33 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fedlearner_webconsole.middleware.log_filter import RequestIdLogFilter
+
+
+class RequestIdLogFilterTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.middleware.log_filter.get_current_request_id')
+ def test_attach_request_id(self, mock_get_current_request_id):
+ mock_get_current_request_id.return_value = '123'
+ log_record = MagicMock()
+ self.assertEqual(RequestIdLogFilter().filter(log_record), True)
+ self.assertEqual(log_record.request_id, '123')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/middlewares.py b/web_console_v2/api/fedlearner_webconsole/middleware/middlewares.py
new file mode 100644
index 000000000..e44248163
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/middlewares.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+
+
+class _MiddlewareRegistry(object):
+
+ def __init__(self):
+ self.middlewares = []
+
+ def register(self, middleware):
+ self.middlewares.append(middleware)
+
+ def init_app(self, app):
+ logging.info('Initializing app with middlewares')
+ # Wraps app with middlewares
+ for middleware in self.middlewares:
+ app = middleware(app)
+ return app
+
+
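+# Global registries: call register(...) during startup and wrap the app via init_app(...) when creating it.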
+flask_middlewares = _MiddlewareRegistry()
+wsgi_middlewares = _MiddlewareRegistry()
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/request_id.py b/web_console_v2/api/fedlearner_webconsole/middleware/request_id.py
new file mode 100644
index 000000000..fb9548560
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/request_id.py
@@ -0,0 +1,137 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import random
+import threading
+from abc import ABCMeta, abstractmethod
+from datetime import datetime
+from typing import Optional, List, Tuple, Union
+
+import grpc
+from flask import Flask, Response, g, request, has_request_context
+
+
+class RequestIdContext(metaclass=ABCMeta):
+
+ @abstractmethod
+ def is_current_context(self) -> bool:
+ pass
+
+ @abstractmethod
+ def set_request_id(self, request_id: str):
+ pass
+
+ @abstractmethod
+ def get_request_id(self) -> Optional[str]:
+ pass
+
+
+class FlaskRequestIdContext(RequestIdContext):
+
+ def is_current_context(self) -> bool:
+ return has_request_context()
+
+ def set_request_id(self, request_id: str):
+ g.request_id = request_id
+
+ def get_request_id(self) -> Optional[str]:
+ # Defensively getting request id from flask.g
+ if hasattr(g, 'request_id'):
+ return g.request_id
+ return None
+
+
+thread_local = threading.local()
+
+
+class ThreadLocalContext(RequestIdContext):
+
+ def is_current_context(self) -> bool:
+ return hasattr(thread_local, 'request_id')
+
+ def set_request_id(self, request_id: str):
+ thread_local.request_id = request_id
+
+ def get_request_id(self) -> Optional[str]:
+ # Defensively getting request id
+ if hasattr(thread_local, 'request_id'):
+ return thread_local.request_id
+ return None
+
+
+_flask_request_id_context = FlaskRequestIdContext()
+_thread_local_context = ThreadLocalContext()
+
+
+def _gen_request_id() -> str:
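+ # Example: '20230101120000-123456-0042' (timestamp, microseconds, 4-digit random suffix; 26 chars)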
+ # Random 4-digit number, zero-padded
+ r = f'{random.randint(0, 9999):04}'
+ dt = datetime.now().strftime('%Y%m%d%H%M%S-%f')
+ return f'{dt}-{r}'
+
+
+class FlaskRequestId(object):
+
+ def __init__(self, header_name='X-TT-LOGID'):
+ self.header_name = header_name
+
+ def __call__(self, app: Flask) -> Flask:
+ app.before_request(self._set_request_id)
+ app.after_request(self._add_header)
+ return app
+
+ def _set_request_id(self):
+ # Gets the existing request id or generates a new one
+ request_id = request.headers.get(self.header_name) or \
+ _gen_request_id()
+ _flask_request_id_context.set_request_id(request_id)
+
+ def _add_header(self, response: Response) -> Response:
+ response.headers[self.header_name] = \
+ _flask_request_id_context.get_request_id()
+ return response
+
+
+class GrpcRequestIdMiddleware(object):
+ REQUEST_HEADER_NAME = 'x-tt-logid'
+
+ @classmethod
+ def add_header(cls, metadata: List[Tuple[str, Union[str, bytes]]]):
+ """Appends request id in metadata."""
+ # Reuses the request id from the current context or generates a new one
+ request_id = get_current_request_id() or _gen_request_id()
+ metadata.append((cls.REQUEST_HEADER_NAME, request_id))
+
+ # Sets thread local context if we get a request id
+ _thread_local_context.set_request_id(request_id)
+ return metadata
+
+ @classmethod
+ def set_request_id_in_context(cls, context: grpc.ServicerContext):
+ """Sets request id to thread local context for gRPC service."""
+ for key, value in context.invocation_metadata():
+ if key == cls.REQUEST_HEADER_NAME:
+ # Sets context per gRPC metadata
+ _thread_local_context.set_request_id(value)
+ return
+
+
+def get_current_request_id() -> str:
+ request_id = None
+ if _flask_request_id_context.is_current_context():
+ request_id = _flask_request_id_context.get_request_id()
+ elif _thread_local_context.is_current_context():
+ request_id = _thread_local_context.get_request_id()
+ return request_id or ''
diff --git a/web_console_v2/api/fedlearner_webconsole/middleware/request_id_test.py b/web_console_v2/api/fedlearner_webconsole/middleware/request_id_test.py
new file mode 100644
index 000000000..e95900a75
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/middleware/request_id_test.py
@@ -0,0 +1,70 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import threading
+import time
+import unittest
+
+from fedlearner_webconsole.middleware.request_id import FlaskRequestId, _thread_local_context, get_current_request_id
+from testing.common import BaseTestCase
+
+
+class FlaskRequestIdTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ # Wraps with middleware
+ self.app = FlaskRequestId()(self.app)
+
+ @self.app.route('/test', methods=['GET'])
+ def test_api():
+ return ''
+
+ def test_response_with_request_id(self):
+ response = self.client.get('/test')
+ self.assertEqual(len(response.headers['X-TT-LOGID']), 26, 'request id should be a 26-char generated id')
+
+ def test_request_with_request_id(self):
+ response = self.client.get('/test', headers={'X-TT-LOGID': 'test-id'})
+ self.assertEqual(response.headers['X-TT-LOGID'], 'test-id')
+
+
+class ThreadLocalContextTest(unittest.TestCase):
+
+ def test_multi_thread_context(self):
+ ids = {}
+
+ def process(index: str):
+ if index != 't1':
+ _thread_local_context.set_request_id(index)
+ time.sleep(0.2)
+ ids[index] = get_current_request_id()
+
+ # t1 executes first
+ # t2 and t3 will be in parallel
+ t1 = threading.Thread(target=process, args=['t1'])
+ t1.start()
+ t1.join()
+ t2 = threading.Thread(target=process, args=['t2'])
+ t3 = threading.Thread(target=process, args=['t3'])
+ t2.start()
+ t3.start()
+ t3.join()
+ t2.join()
+ self.assertDictEqual(ids, {'t1': '', 't2': 't2', 't3': 't3'})
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/mmgr/BUILD.bazel
new file mode 100644
index 000000000..65ea2f9f1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/BUILD.bazel
@@ -0,0 +1,384 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "model_job_configer_lib",
+ srcs = ["model_job_configer.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "model_job_configer_lib_test",
+ size = "small",
+ srcs = [
+ "model_job_configer_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_job_configer_test.py",
+ deps = [
+ ":model_job_configer_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:utils_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "service_lib",
+ srcs = ["service.py"],
+ imports = ["../.."],
+ deps = [
+ ":model_job_configer_lib",
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr/metrics:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "service_lib_test",
+ size = "medium",
+ srcs = [
+ "service_test.py",
+ ],
+ imports = ["../.."],
+ main = "service_test.py",
+ deps = [
+ ":service_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_library(
+ name = "controller_lib",
+ srcs = [
+ "controller.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:transaction_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_job_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "scheduler_lib",
+ srcs = ["scheduler.py"],
+ imports = ["../.."],
+ deps = [
+ ":controller_lib",
+ ":model_job_configer_lib",
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "scheduler_lib_test",
+ size = "small",
+ srcs = [
+ "scheduler_test.py",
+ ],
+ imports = ["../.."],
+ main = "scheduler_test.py",
+ deps = [
+ ":scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ ],
+)
+
+py_test(
+ name = "controller_lib_test",
+ size = "medium",
+ srcs = [
+ "controller_test.py",
+ ],
+ imports = ["../.."],
+ main = "controller_test.py",
+ deps = [
+ ":controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ ],
+)
+
+py_test(
+ name = "utils_lib_test",
+ size = "small",
+ srcs = [
+ "utils_test.py",
+ ],
+ imports = ["../.."],
+ main = "utils_test.py",
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "cronjob_lib",
+ srcs = [
+ "cronjob.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":controller_lib",
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "cronjob_lib_test",
+ size = "medium",
+ srcs = [
+ "cronjob_test.py",
+ ],
+ imports = ["../.."],
+ main = "cronjob_test.py",
+ deps = [
+ ":cronjob_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = [
+ "model_apis.py",
+ "model_job_apis.py",
+ "model_job_group_apis.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":controller_lib",
+ ":model_job_configer_lib",
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/scheduler:scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:sorting_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_job_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask//:pkg",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_sqlalchemy//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "model_apis_lib_test",
+ size = "medium",
+ srcs = [
+ "model_apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "model_job_apis_lib_test",
+ size = "medium",
+ srcs = [
+ "model_job_apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_job_apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ ],
+)
+
+py_test(
+ name = "model_job_group_apis_lib_test",
+ size = "medium",
+ srcs = [
+ "model_job_group_apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_job_group_apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "event_listener_lib",
+ srcs = ["event_listener.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:event_listener_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_cache_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ ],
+)
+
+py_test(
+ name = "event_listener_lib_test",
+ size = "small",
+ srcs = [
+ "event_listener_test.py",
+ ],
+ imports = ["../.."],
+ main = "event_listener_test.py",
+ deps = [
+ ":event_listener_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/__init__.py b/web_console_v2/api/fedlearner_webconsole/mmgr/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/apis.py b/web_console_v2/api/fedlearner_webconsole/mmgr/apis.py
deleted file mode 100644
index cbae04a07..000000000
--- a/web_console_v2/api/fedlearner_webconsole/mmgr/apis.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-
-from http import HTTPStatus
-from flask import request
-from flask_restful import Resource
-from fedlearner_webconsole.db import db_handler
-from fedlearner_webconsole.exceptions import NotFoundException
-from fedlearner_webconsole.mmgr.models import Model, ModelType, ModelGroup
-from fedlearner_webconsole.mmgr.service import ModelService
-from fedlearner_webconsole.utils.decorators import jwt_required
-
-
-class ModelApi(Resource):
- @jwt_required()
- def get(self, model_id):
- detail_level = request.args.get('detail_level', '')
- with db_handler.session_scope() as session:
- model_json = ModelService(session).query(model_id, detail_level)
- if not model_json:
- raise NotFoundException(
- f'Failed to find model: {model_id}')
- return {'data': model_json}, HTTPStatus.OK
-
- @jwt_required()
- def put(self, model_id):
- with db_handler.session_scope() as session:
- model = session.query(Model).filter_by(id=model_id).one_or_none()
- if not model:
- raise NotFoundException(
- f'Failed to find model: {model_id}')
- model.extra = request.args.get('extra', model.extra)
- session.commit()
- return {'data': model.to_dict()}, HTTPStatus.OK
-
- @jwt_required()
- def delete(self, model_id):
- with db_handler.session_scope() as session:
- model = ModelService(session).drop(model_id)
- if not model:
- raise NotFoundException(
- f'Failed to find model: {model_id}')
- return {'data': model.to_dict()}, HTTPStatus.OK
-
-
-class ModelListApi(Resource):
- @jwt_required()
- def get(self):
- detail_level = request.args.get('detail_level', '')
- # TODO serialized query may incur performance penalty
- with db_handler.session_scope() as session:
- model_list = [
- ModelService(session).query(m.id, detail_level)
- for m in Model.query.filter(
- Model.type.in_([
- ModelType.NN_MODEL.value, ModelType.TREE_MODEL.value
- ])).all()
- ]
- return {'data': model_list}, HTTPStatus.OK
-
-
-class GroupListApi(Resource):
- @jwt_required()
- def get(self):
- group_list = [o.to_dict() for o in ModelGroup.query.all()]
- return {'data': group_list}, HTTPStatus.OK
-
- @jwt_required()
- def post(self):
- group = ModelGroup()
-
- group.name = request.args.get('name', group.name)
- group.extra = request.args.get('extra', group.extra)
- with db_handler.session_scope() as session:
- session.add(group)
- session.commit()
-
- return {'data': group.to_dict()}, HTTPStatus.OK
-
-
-class GroupApi(Resource):
- @jwt_required()
- def patch(self, group_id):
- group = ModelGroup.query.filter_by(id=group_id).one_or_none()
- if not group:
- raise NotFoundException(
- f'Failed to find group: {group_id}')
-
- group.name = request.args.get('name', group.name)
- group.extra = request.args.get('extra', group.extra)
- with db_handler.session_scope() as session:
- session.add(group)
- session.commit()
-
- return {'data': group.to_dict()}, HTTPStatus.OK
-
-
-def initialize_mmgr_apis(api):
- api.add_resource(ModelListApi, '/models')
- api.add_resource(ModelApi, '/models/')
-
- api.add_resource(GroupListApi, '/model_groups')
- api.add_resource(GroupApi, '/model_groups/')
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/controller.py b/web_console_v2/api/fedlearner_webconsole/mmgr/controller.py
new file mode 100644
index 000000000..40b2d7275
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/controller.py
@@ -0,0 +1,321 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import logging
+from google.protobuf import json_format
+from typing import Tuple, Optional
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, DataBatch
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobGroup, ModelJobType, GroupCreateStatus, \
+ GroupAutoUpdateStatus
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.mmgr.service import check_model_job_group, ModelJobGroupService
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.two_pc.transaction_manager import TransactionManager
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TransactionData, CreateModelJobData, \
+ CreateModelJobGroupData
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGroupPb
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.rpc.v2.job_service_client import JobServiceClient
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+from fedlearner_webconsole.workflow.workflow_job_controller import start_workflow, stop_workflow
+from fedlearner_webconsole.workflow_template.utils import set_value
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.flag.models import Flag
+
+
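+# Helper that builds a TransactionManager spanning all platform participants of the
+# project, so the operations below can be committed as a single 2PC transaction.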
+def _get_transaction_manager(project_id: int, two_pc_type: TwoPcType) -> TransactionManager:
+ with db.session_scope() as session:
+ project = session.query(Project).get(project_id)
+ participants = ParticipantService(session).get_platform_participants_by_project(project_id)
+ tm = TransactionManager(project_name=project.name,
+ project_token=project.token,
+ two_pc_type=two_pc_type,
+ participants=[participant.domain_name for participant in participants])
+ return tm
+
+
+class CreateModelJob:
+
+ @staticmethod
+ def _create_model_job_by_2pc(project_id: int,
+ name: str,
+ model_job_type: ModelJobType,
+ algorithm_type: AlgorithmType,
+ coordinator_pure_domain_name: str,
+ dataset_id: Optional[int] = None,
+ model_id: Optional[int] = None,
+ group_id: Optional[int] = None) -> Tuple[bool, str]:
+ tm = _get_transaction_manager(project_id=project_id, two_pc_type=TwoPcType.CREATE_MODEL_JOB)
+ with db.session_scope() as session:
+ project_name = session.query(Project).get(project_id).name
+ dataset_uuid = None
+ if dataset_id is not None:
+ dataset_uuid = session.query(Dataset).get(dataset_id).uuid
+ model_uuid = None
+ if model_id is not None:
+ model_uuid = session.query(Model).get(model_id).uuid
+ group_name = None
+ if group_id is not None:
+ group_name = session.query(ModelJobGroup).get(group_id).name
+ model_job_uuid = resource_uuid()
+ workflow_uuid = model_job_uuid
+ succeeded, message = tm.run(
+ TransactionData(
+ create_model_job_data=CreateModelJobData(model_job_name=name,
+ model_job_uuid=model_job_uuid,
+ model_job_type=model_job_type.name,
+ group_name=group_name,
+ algorithm_type=algorithm_type.name,
+ workflow_uuid=workflow_uuid,
+ model_uuid=model_uuid,
+ project_name=project_name,
+ coordinator_pure_domain_name=coordinator_pure_domain_name,
+ dataset_uuid=dataset_uuid)))
+ return succeeded, message
+
+ def run(self,
+ project_id: int,
+ name: str,
+ model_job_type: ModelJobType,
+ algorithm_type: AlgorithmType,
+ coordinator_pure_domain_name: str,
+ dataset_id: Optional[int],
+ model_id: Optional[int] = None,
+ group_id: Optional[int] = None) -> Tuple[bool, str]:
+        # No need to create model jobs at participants when evaluating or predicting with a horizontal model.
+ if algorithm_type in [AlgorithmType.TREE_VERTICAL, AlgorithmType.NN_VERTICAL
+ ] or model_job_type == ModelJobType.TRAINING:
+ succeeded, msg = self._create_model_job_by_2pc(project_id=project_id,
+ name=name,
+ model_job_type=model_job_type,
+ algorithm_type=algorithm_type,
+ coordinator_pure_domain_name=coordinator_pure_domain_name,
+ dataset_id=dataset_id,
+ model_id=model_id,
+ group_id=group_id)
+ return succeeded, msg
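+        # Horizontal evaluation/prediction does not involve the participants, so only a
+        # local model job record is created and no 2PC transaction is issued.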
+ with db.session_scope() as session:
+ model_job = ModelJob(name=name,
+ group_id=group_id,
+ project_id=project_id,
+ model_job_type=model_job_type,
+ algorithm_type=algorithm_type,
+ model_id=model_id)
+ model_job.uuid = resource_uuid()
+ model_job.workflow_uuid = model_job.uuid
+ session.add(model_job)
+ session.commit()
+ return True, ''
+
+
+class CreateModelJobGroup:
+
+ @staticmethod
+ def run(project_id: int, name: str, algorithm_type: AlgorithmType, dataset_id: Optional[str],
+ coordinator_pure_domain_name: str, model_job_group_uuid: str) -> Tuple[bool, str]:
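+        """Creates the model job group on every participant through a single 2PC transaction."""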
+ with db.session_scope() as session:
+ project_name = session.query(Project).get(project_id).name
+ dataset_uuid = None
+ if dataset_id is not None:
+ dataset_uuid = session.query(Dataset).get(dataset_id).uuid
+ tm = _get_transaction_manager(project_id=project_id, two_pc_type=TwoPcType.CREATE_MODEL_JOB_GROUP)
+ create_model_job_group_data = CreateModelJobGroupData(model_job_group_name=name,
+ model_job_group_uuid=model_job_group_uuid,
+ project_name=project_name,
+ algorithm_type=algorithm_type.name,
+ coordinator_pure_domain_name=coordinator_pure_domain_name,
+ dataset_uuid=dataset_uuid)
+ succeeded, msg = tm.run(data=TransactionData(create_model_job_group_data=create_model_job_group_data))
+ return succeeded, msg
+
+
+class LaunchModelJob:
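+    """Launches a new version of a group's model job on all parties through 2PC."""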
+
+ @staticmethod
+ def run(project_id: int, group_id: int, version: int):
+ tm = _get_transaction_manager(project_id=project_id, two_pc_type=TwoPcType.LAUNCH_MODEL_JOB)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(group_id)
+ model_job_name = f'{group.name}-v{group.latest_version}'
+ data = TransactionData(create_model_job_data=CreateModelJobData(
+ model_job_name=model_job_name, model_job_uuid=resource_uuid(), group_uuid=group.uuid, version=version))
+ succeeded, msg = tm.run(data)
+ return succeeded, msg
+
+
+class ModelJobGroupController:
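+    """Fans out model-job-group gRPC calls to every participant of the given project."""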
+
+ def __init__(self, session: Session, project_id: int):
+ self._session = session
+ self._clients = []
+ self._participants = ParticipantService(self._session).get_participants_by_project(project_id)
+ self._project = self._session.query(Project).get(project_id)
+ for p in self._participants:
+ self._clients.append(JobServiceClient.from_project_and_participant(p.domain_name, self._project.name))
+
+ def inform_auth_status_to_participants(self, group: ModelJobGroup):
+ participants_info = group.get_participants_info()
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ participants_info.participants_map[pure_domain_name].auth_status = group.auth_status.name
+ group.set_participants_info(participants_info)
+ for client, p in zip(self._clients, self._participants):
+ try:
+ client.inform_model_job_group(group.uuid, group.auth_status)
+ except grpc.RpcError as e:
+ logging.warning(f'[model-job-group] failed to inform participant {p.id}\'s '
+ f'model job group {group.uuid} with grpc code {e.code()} and details {e.details()}')
+
+ def update_participants_model_job_group(self,
+ uuid: str,
+ auto_update_status: Optional[GroupAutoUpdateStatus] = None,
+ start_data_batch_id: Optional[int] = None):
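+        # Participants address data batches by dataset job stage uuid rather than by the
+        # local batch id, so resolve the uuid before fanning out the update.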
+ start_dataset_job_stage_uuid = None
+ if start_data_batch_id:
+ start_dataset_job_stage_uuid = self._session.query(DataBatch).get(
+ start_data_batch_id).latest_parent_dataset_job_stage.uuid
+ for client, p in zip(self._clients, self._participants):
+ try:
+ client.update_model_job_group(uuid=uuid,
+ auto_update_status=auto_update_status,
+ start_dataset_job_stage_uuid=start_dataset_job_stage_uuid)
+ except grpc.RpcError as e:
+ logging.warning(f'[model-job-group] failed to update participant {p.id}\'s '
+ f'model job group {uuid} with grpc code {e.code()} and details {e.details()}')
+
+ def update_participants_auth_status(self, group: ModelJobGroup):
+ participants_info = group.get_participants_info()
+ for client, p in zip(self._clients, self._participants):
+ try:
+ resp = client.get_model_job_group(group.uuid)
+ if resp.auth_status:
+ auth_status = resp.auth_status
+ else:
+ # Use 'authorized' if the field 'auth_status' is not in the ModelJobGroupPb of the opposite side
+ if resp.authorized:
+ auth_status = AuthStatus.AUTHORIZED.name
+ else:
+ auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map[p.pure_domain_name()].auth_status = auth_status
+ except grpc.RpcError as e:
+ logging.warning(f'[model-job-group] failed to get participant {p.id}\'s '
+ f'model job group {group.uuid} with grpc code {e.code()} and details {e.details()}')
+ group.set_participants_info(participants_info)
+ self._session.commit()
+
+ def get_model_job_group_from_participant(self, participant_id: int,
+ model_job_group_uuid: str) -> Optional[ModelJobGroupPb]:
+        resp = None
+        flag_resp = {}  # fallback when flags cannot be fetched from the participant
+ for client, p in zip(self._clients, self._participants):
+ if p.id == participant_id:
+ try:
+ resp = client.get_model_job_group(uuid=model_job_group_uuid)
+ system_client = SystemServiceClient.from_participant(domain_name=p.domain_name)
+ flag_resp = system_client.list_flags()
+ break
+ except grpc.RpcError as e:
+ logging.warning(
+ f'[model-job-group] failed to get participant {p.id}\'s '
+ f'model job group {model_job_group_uuid} with grpc code {e.code()} and details {e.details()}')
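+        # Rewrite the peer's algorithm variable with ids resolved by the local AlgorithmFetcher
+        # so the returned config can be used directly on this platform.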
+ if resp and len(resp.config.job_definitions) and flag_resp.get(Flag.MODEL_JOB_GLOBAL_CONFIG_ENABLED.name):
+ variables = resp.config.job_definitions[0].variables
+ for variable in variables:
+ if variable.name == 'algorithm':
+ algo_dict = json_format.MessageToDict(variable.typed_value)
+ algo = AlgorithmFetcher(self._project.id).get_algorithm(algo_dict['algorithmUuid'])
+ algo_dict['algorithmId'] = algo.id
+ algo_dict['participantId'] = algo.participant_id
+ algo_dict['algorithmProjectId'] = algo.algorithm_project_id
+ set_value(variable=variable, typed_value=algo_dict)
+ break
+ return resp
+
+ def create_model_job_group_for_participants(self, model_job_group_id: int):
+ group = self._session.query(ModelJobGroup).get(model_job_group_id)
+ for client, p in zip(self._clients, self._participants):
+ try:
+ client.create_model_job_group(name=group.name,
+ uuid=group.uuid,
+ algorithm_type=group.algorithm_type,
+ dataset_uuid=group.dataset.uuid,
+ algorithm_project_list=group.get_algorithm_project_uuid_list())
+ except grpc.RpcError as e:
+ logging.warning(f'[model-job-group] failed to create model job group for the participant {p.id} '
+ f'with grpc code {e.code()} and details {e.details()}')
+ group.status = GroupCreateStatus.FAILED
+ return
+ group.status = GroupCreateStatus.SUCCEEDED
+
+
+class ModelJobController:
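+    """Launches model jobs and synchronizes their auth status with participants."""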
+
+ def __init__(self, session: Session, project_id: int):
+ self._session = session
+ self._client = []
+ self._participants = ParticipantService(self._session).get_participants_by_project(project_id)
+ self._project_id = project_id
+ project = self._session.query(Project).get(project_id)
+ for p in self._participants:
+ self._client.append(JobServiceClient.from_project_and_participant(p.domain_name, project.name))
+
+ def launch_model_job(self, group_id: int) -> ModelJob:
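+        """Locks the group, bumps its version and launches the next model job via 2PC."""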
+ check_model_job_group(self._project_id, group_id, self._session)
+ group = ModelJobGroupService(self._session).lock_and_update_version(group_id)
+ self._session.commit()
+ succeeded, msg = LaunchModelJob().run(project_id=self._project_id,
+ group_id=group_id,
+ version=group.latest_version)
+ if not succeeded:
+            raise InternalException(f'failed to launch model job by 2PC with message: {msg}')
+ model_job = self._session.query(ModelJob).filter_by(group_id=group_id, version=group.latest_version).first()
+ return model_job
+
+ def inform_auth_status_to_participants(self, model_job: ModelJob):
+ for client, p in zip(self._client, self._participants):
+ try:
+ client.inform_model_job(model_job.uuid, model_job.auth_status)
+ except grpc.RpcError as e:
+ logging.warning(f'[model-job] failed to inform participants {p.id}\'s model job '
+ f'{model_job.uuid} with grpc code {e.code()} and details {e.details()}')
+
+ def update_participants_auth_status(self, model_job: ModelJob):
+ participants_info = model_job.get_participants_info()
+ for client, p in zip(self._client, self._participants):
+ try:
+ resp = client.get_model_job(model_job.uuid)
+ participants_info.participants_map[p.pure_domain_name()].auth_status = resp.auth_status
+ except grpc.RpcError as e:
+ logging.warning(f'[model-job] failed to get participant {p.id}\'s model job {model_job.uuid} '
+ f'with grpc code {e.code()} and details {e.details()}')
+ model_job.set_participants_info(participants_info)
+
+
+# TODO(gezhengqiang): provide start model job rpc
+def start_model_job(model_job_id: int):
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(model_job_id)
+ start_workflow(workflow_id=model_job.workflow_id)
+
+
+def stop_model_job(model_job_id: int):
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(model_job_id)
+ stop_workflow(workflow_id=model_job.workflow_id)
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/controller_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/controller_test.py
new file mode 100644
index 000000000..bd6e26183
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/controller_test.py
@@ -0,0 +1,397 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch, call
+
+import grpc
+from google.protobuf import json_format
+from google.protobuf.empty_pb2 import Empty
+from datetime import datetime
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.fake_model_job_config import get_workflow_config
+from testing.rpc.client import FakeRpcError
+
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmType, Source
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobState, DatasetJobKind, DatasetType, \
+ DatasetJobStage, DataBatch
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobGroup, ModelJobType, ModelJobRole, GroupCreateStatus, \
+ GroupAutoUpdateStatus, AuthStatus as ModelJobAuthStatus
+from fedlearner_webconsole.mmgr.controller import start_model_job, stop_model_job, ModelJobGroupController, \
+ ModelJobController
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmPb
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGroupPb, AlgorithmProjectList, ModelJobPb
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.workflow_template.utils import set_value
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+
+
+class StartModelJobTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ session.add(ModelJob(id=1, name='name', uuid='uuid', workflow_id=2))
+ session.commit()
+
+ @patch('fedlearner_webconsole.mmgr.controller.start_workflow')
+ def test_start_model_job(self, mock_start_workflow: MagicMock):
+ start_model_job(model_job_id=1)
+ mock_start_workflow.assert_called_with(workflow_id=2)
+
+ @patch('fedlearner_webconsole.mmgr.controller.stop_workflow')
+ def test_stop_model_job(self, mock_stop_workflow: MagicMock):
+ stop_model_job(model_job_id=1)
+ mock_stop_workflow.assert_called_with(workflow_id=2)
+
+
+class ModelJobGroupControllerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part1', domain_name='fl-demo1.com')
+ participant2 = Participant(id=2, name='part2', domain_name='fl-demo2.com')
+ dataset = Dataset(id=1, name='dataset', uuid='dataset_uuid')
+ pro_part1 = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ pro_part2 = ProjectParticipant(id=2, project_id=1, participant_id=2)
+ group = ModelJobGroup(id=1,
+ name='group',
+ uuid='uuid',
+ project_id=1,
+ dataset_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL)
+ dataset_job_stage = DatasetJobStage(id=1,
+ name='data_join',
+ uuid='stage_uuid',
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ dataset_job_id=1,
+ data_batch_id=1)
+ data_batch = DataBatch(id=1,
+ name='20221213',
+ dataset_id=1,
+ path='/data/dataset/haha/batch/20221213',
+ latest_parent_dataset_job_stage_id=1,
+ event_time=datetime(2022, 12, 13, 16, 37, 37))
+ algorithm_project_list = AlgorithmProjectList()
+ algorithm_project_list.algorithm_projects['test'] = 'algorithm-project-uuid1'
+ algorithm_project_list.algorithm_projects['part1'] = 'algorithm-project-uuid2'
+ algorithm_project_list.algorithm_projects['part2'] = 'algorithm-project-uuid3'
+ group.set_algorithm_project_uuid_list(algorithm_project_list)
+ algo = Algorithm(id=1, algorithm_project_id=1, name='test-algo', uuid='uuid')
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['demo0'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['demo1'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['demo2'].auth_status = AuthStatus.PENDING.name
+ group.set_participants_info(participants_info)
+ session.add_all([
+ project, participant1, participant2, pro_part1, pro_part2, group, algo, dataset, dataset_job_stage,
+ data_batch
+ ])
+ session.commit()
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_model_job_group')
+ def test_inform_auth_status_to_participants(self, mock_client: MagicMock, mock_system_info: MagicMock):
+ system_info = SystemInfo()
+ system_info.pure_domain_name = 'demo0'
+ mock_system_info.return_value = system_info
+ mock_client.return_value = Empty()
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ group.auth_status = AuthStatus.AUTHORIZED
+ ModelJobGroupController(session, 1).inform_auth_status_to_participants(group)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map[system_info.pure_domain_name].auth_status,
+ AuthStatus.AUTHORIZED.name)
+ self.assertEqual(mock_client.call_args_list, [(('uuid', AuthStatus.AUTHORIZED),),
+ (('uuid', AuthStatus.AUTHORIZED),)])
+ # fail due to grpc abort
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'model job group uuid is not found')
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ group.auth_status = AuthStatus.PENDING
+ ModelJobGroupController(session, 1).inform_auth_status_to_participants(group)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map[system_info.pure_domain_name].auth_status,
+ AuthStatus.PENDING.name)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.update_model_job_group')
+ def test_update_participants_model_job_group(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ group.auto_update_status = GroupAutoUpdateStatus.ACTIVE
+ group.start_data_batch_id = 1
+ ModelJobGroupController(session, 1).update_participants_model_job_group(
+ uuid=group.uuid,
+ auto_update_status=group.auto_update_status,
+ start_data_batch_id=group.start_data_batch_id)
+ self.assertEqual(mock_client.call_args_list, [
+ call(uuid='uuid',
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE,
+ start_dataset_job_stage_uuid='stage_uuid'),
+ call(uuid='uuid',
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE,
+ start_dataset_job_stage_uuid='stage_uuid')
+ ])
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_model_job_group')
+ def test_update_participants_auth_status(self, mock_client: MagicMock):
+ mock_client.side_effect = [
+ ModelJobGroupPb(auth_status=AuthStatus.AUTHORIZED.name),
+ ModelJobGroupPb(auth_status=AuthStatus.AUTHORIZED.name)
+ ]
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ ModelJobGroupController(session, 1).update_participants_auth_status(group)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map['demo1'].auth_status, AuthStatus.AUTHORIZED.name)
+ self.assertEqual(participants_info.participants_map['demo2'].auth_status, AuthStatus.AUTHORIZED.name)
+ # if the 'auth_status' is not in ModelJobGroupPb
+ mock_client.side_effect = [ModelJobGroupPb(authorized=False), ModelJobGroupPb(authorized=False)]
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ ModelJobGroupController(session, 1).update_participants_auth_status(group)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map['demo1'].auth_status, AuthStatus.PENDING.name)
+ self.assertEqual(participants_info.participants_map['demo2'].auth_status, AuthStatus.PENDING.name)
+ # fail due to grpc abort
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'model job group uuid is not found')
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ ModelJobGroupController(session, 1).update_participants_auth_status(group)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map['demo1'].auth_status, AuthStatus.PENDING.name)
+ self.assertEqual(participants_info.participants_map['demo2'].auth_status, AuthStatus.PENDING.name)
+
+ @patch('fedlearner_webconsole.rpc.v2.system_service_client.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.algorithm.fetcher.AlgorithmFetcher.get_algorithm_from_participant')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_model_job_group')
+ def test_get_model_job_group_from_participant(self, mock_client: MagicMock, mock_algo_fetcher: MagicMock,
+ mock_list_flags: MagicMock):
+ algo_dict1 = {
+ 'algorithmId': 1,
+ 'algorithmUuid': 'uuid',
+ 'algorithmProjectId': 1,
+ 'algorithmProjectUuid': 'project_uuid',
+ 'participantId': 0,
+ 'path': '/path'
+ }
+ variable = Variable(name='algorithm')
+ set_value(variable=variable, typed_value=algo_dict1)
+ config = WorkflowDefinition(job_definitions=[JobDefinition(variables=[variable])])
+ mock_client.return_value = ModelJobGroupPb(name='group', uuid='uuid', config=config)
+ mock_list_flags.return_value = {'model_job_global_config_enabled': True}
+ with db.session_scope() as session:
+ resp = ModelJobGroupController(session, 1).get_model_job_group_from_participant(1, 'uuid')
+ self.assertEqual(resp.name, 'group')
+ self.assertEqual(resp.uuid, 'uuid')
+ variables = resp.config.job_definitions[0].variables
+ for variable in variables:
+ if variable.name == 'algorithm':
+ self.assertEqual(json_format.MessageToDict(variable.typed_value), algo_dict1)
+ algo_dict2 = {
+ 'algorithmId': 2,
+ 'algorithmUuid': 'peer-uuid',
+ 'algorithmProjectId': 1,
+ 'algorithmProjectUuid': 'project_uuid',
+ 'participantId': 2,
+ 'path': '/path'
+ }
+ set_value(variable, typed_value=algo_dict2)
+ config = WorkflowDefinition(job_definitions=[JobDefinition(variables=[variable])])
+ mock_client.return_value = ModelJobGroupPb(name='group', uuid='uuid', config=config)
+ mock_algo_fetcher.return_value = AlgorithmPb(name='test-peer-algo',
+ uuid='peer-uuid',
+ participant_id=1,
+ source=Source.PARTICIPANT.name)
+ with db.session_scope() as session:
+ resp = ModelJobGroupController(session, 1).get_model_job_group_from_participant(1, 'uuid')
+ variables = resp.config.job_definitions[0].variables
+ algo_dict2['algorithmId'] = 0
+ algo_dict2['algorithmProjectId'] = 0
+ algo_dict2['participantId'] = 1
+ for variable in variables:
+ if variable.name == 'algorithm':
+ self.assertEqual(json_format.MessageToDict(variable.typed_value), algo_dict2)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_model_job_group')
+ def test_create_model_job_group_for_participants(self, mock_client: MagicMock):
+ with db.session_scope() as session:
+ ModelJobGroupController(session, 1).create_model_job_group_for_participants(1)
+ session.commit()
+ algorithm_project_list = AlgorithmProjectList()
+ algorithm_project_list.algorithm_projects['test'] = 'algorithm-project-uuid1'
+ algorithm_project_list.algorithm_projects['part1'] = 'algorithm-project-uuid2'
+ algorithm_project_list.algorithm_projects['part2'] = 'algorithm-project-uuid3'
+ mock_client.assert_called_with(name='group',
+ uuid='uuid',
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ dataset_uuid='dataset_uuid',
+ algorithm_project_list=algorithm_project_list)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ self.assertEqual(group.status, GroupCreateStatus.SUCCEEDED)
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT, 'dataset with uuid is not found')
+ with db.session_scope() as session:
+ ModelJobGroupController(session, 1).create_model_job_group_for_participants(1)
+ session.commit()
+ mock_client.assert_called()
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ self.assertEqual(group.status, GroupCreateStatus.FAILED)
+
+
+class ModelJobControllerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ project = Project(id=1, name='test-project')
+ participant1 = Participant(id=1, name='part1', domain_name='fl-demo1.com')
+ participant2 = Participant(id=2, name='part2', domain_name='fl-demo2.com')
+ pro_part1 = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ pro_part2 = ProjectParticipant(id=2, project_id=1, participant_id=2)
+ dataset_job = DatasetJob(id=1,
+ name='datasetjob',
+ uuid='uuid',
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=3,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN)
+ dataset = Dataset(id=3,
+ uuid='uuid',
+ name='datasetjob',
+ dataset_type=DatasetType.PSI,
+ path='/data/dataset/haha')
+ algorithm = Algorithm(id=2, name='algorithm')
+ group = ModelJobGroup(id=1,
+ name='group',
+ uuid='uuid',
+ project_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_id=2,
+ role=ModelJobRole.COORDINATOR,
+ dataset_id=3)
+ model_job = ModelJob(id=1,
+ name='model_job',
+ uuid='uuid',
+ project_id=1,
+ auth_status=ModelJobAuthStatus.AUTHORIZED)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ model_job.set_participants_info(participants_info)
+ group.set_config(get_workflow_config(ModelJobType.TRAINING))
+ session.add_all([
+ dataset_job, dataset, project, group, algorithm, participant1, participant2, pro_part1, pro_part2,
+ model_job
+ ])
+ session.commit()
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_launch_model_job(self, mock_remote_do_two_pc):
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).filter_by(uuid='uuid').first()
+ mock_remote_do_two_pc.return_value = True, ''
+ with db.session_scope() as session:
+ ModelJobController(session=session, project_id=1).launch_model_job(group_id=1)
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(name='group').first()
+ model_job = group.model_jobs[0]
+ self.assertEqual(model_job.group_id, group.id)
+ self.assertTrue(model_job.project_id, group.project_id)
+ self.assertEqual(model_job.version, 1)
+ self.assertEqual(group.latest_version, 1)
+ self.assertTrue(model_job.algorithm_type, group.algorithm_type)
+ self.assertTrue(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertTrue(model_job.dataset_id, group.dataset_id)
+ self.assertTrue(model_job.workflow.get_config(), group.get_config())
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_model_job')
+ def test_inform_auth_status_to_participants(self, mock_inform_model_job: MagicMock):
+ mock_inform_model_job.return_value = Empty()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ ModelJobController(session, 1).inform_auth_status_to_participants(model_job)
+ self.assertEqual(mock_inform_model_job.call_args_list, [(('uuid', ModelJobAuthStatus.AUTHORIZED),),
+ (('uuid', ModelJobAuthStatus.AUTHORIZED),)])
+ # fail due to grpc abort
+ mock_inform_model_job.reset_mock()
+ mock_inform_model_job.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'model job uuid is not found')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ ModelJobController(session, 1).inform_auth_status_to_participants(model_job)
+ self.assertEqual(mock_inform_model_job.call_args_list, [(('uuid', ModelJobAuthStatus.AUTHORIZED),),
+ (('uuid', ModelJobAuthStatus.AUTHORIZED),)])
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_model_job')
+ def test_get_participants_auth_status(self, mock_get_model_job: MagicMock):
+ mock_get_model_job.side_effect = [
+ ModelJobPb(auth_status=AuthStatus.AUTHORIZED.name),
+ ModelJobPb(auth_status=AuthStatus.AUTHORIZED.name)
+ ]
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ ModelJobController(session, 1).update_participants_auth_status(model_job)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ self.assertEqual(model_job.get_participants_info(), participants_info)
+ # fail due to grpc abort
+ mock_get_model_job.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'model job uuid is not found')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ ModelJobController(session, 1).update_participants_auth_status(model_job)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ self.assertEqual(model_job.get_participants_info(), participants_info)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/cronjob.py b/web_console_v2/api/fedlearner_webconsole/mmgr/cronjob.py
new file mode 100644
index 000000000..6c08c047f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/cronjob.py
@@ -0,0 +1,102 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple, Optional
+from sqlalchemy.orm import Session
+from google.protobuf.struct_pb2 import Value
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, ModelTrainingCronJobOutput
+from fedlearner_webconsole.mmgr.service import ModelJobGroupService
+from fedlearner_webconsole.mmgr.controller import LaunchModelJob
+from fedlearner_webconsole.mmgr.models import ModelJobGroup
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.proto.common_pb2 import Variable
+
+LOAD_MODEL_NAME = 'load_model_name'
+
+
+class ModelTrainingCronJob(IRunnerV2):
+ """Launch model job periodically."""
+
+ @staticmethod
+ def _set_load_model_name(config: WorkflowDefinition, job_name: str):
+        """Set the load_model_name variable in place."""
+ assert len(config.job_definitions) == 1
+ for variable in config.job_definitions[0].variables:
+ if variable.name == LOAD_MODEL_NAME:
+ assert variable.value_type == Variable.ValueType.STRING
+ variable.value = job_name
+ variable.typed_value.MergeFrom(Value(string_value=job_name))
+
+ def _update_local_and_peer_config(self, session: Session, group_id: int):
+ group: ModelJobGroup = session.query(ModelJobGroup).get(group_id)
+ model_job = group.latest_completed_job()
+ if model_job is None:
+ return
+ job_name = model_job.job_name
+ config = group.get_config()
+ self._set_load_model_name(config=config, job_name=job_name)
+ group.set_config(config)
+ for party in group.project.participants:
+ client = RpcClient.from_project_and_participant(project_name=group.project.name,
+ project_token=group.project.token,
+ domain_name=party.domain_name)
+ config = client.get_model_job_group(model_job_group_uuid=group.uuid).config
+ self._set_load_model_name(config=config, job_name=job_name)
+ client.update_model_job_group(model_job_group_uuid=group.uuid, config=config)
+
+ def _check_peer_auth_status(self, session: Session, group_id: int) -> Tuple[bool, Optional[str]]:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(group_id)
+ for party in group.project.participants:
+ client = RpcClient.from_project_and_participant(project_name=group.project.name,
+ project_token=group.project.token,
+ domain_name=party.domain_name)
+ resp = client.get_model_job_group(model_job_group_uuid=group.uuid)
+ if not resp.authorized:
+ return False, party.domain_name
+ return True, None
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
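+        """Checks peer authorization, bumps the group version, syncs load_model_name into the
+        local and peer configs, then launches a new model job through 2PC."""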
+ output = ModelTrainingCronJobOutput()
+ group_id = context.input.model_training_cron_job_input.group_id
+ with db.session_scope() as session:
+ authorized, domain_name = self._check_peer_auth_status(session=session, group_id=group_id)
+ if not authorized:
+ message = f'party {domain_name} is not authorized for group {group_id}'
+ logging.warning(f'[ModelTrainingCronJob] {message}')
+ return RunnerStatus.FAILED, RunnerOutput(error_message=message)
+ group = ModelJobGroupService(session).lock_and_update_version(group_id)
+ session.commit()
+ with db.session_scope() as session:
+ self._update_local_and_peer_config(session, group.id)
+ session.commit()
+ succeeded, msg = LaunchModelJob().run(project_id=group.project_id,
+ group_id=group_id,
+ version=group.latest_version)
+ if not succeeded:
+            message = f'failed to launch model job for group {group_id} by 2PC with message: {msg}'
+ output.message = message
+ logging.warning(f'[ModelTrainingCronJob] {message}')
+ return RunnerStatus.FAILED, RunnerOutput(model_training_cron_job_output=output)
+        message = f'succeeded in launching model job for group {group_id}'
+ logging.info(f'[ModelTrainingCronJob] {message}')
+ output.message = message
+ return RunnerStatus.DONE, RunnerOutput(model_training_cron_job_output=output)
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/cronjob_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/cronjob_test.py
new file mode 100644
index 000000000..53e0c6b18
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/cronjob_test.py
@@ -0,0 +1,91 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch
+from google.protobuf.struct_pb2 import Value
+
+from testing.common import NoWebServerTestCase
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.mmgr.cronjob import ModelTrainingCronJob
+from fedlearner_webconsole.mmgr.models import ModelJobGroup, ModelJob
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.composer_pb2 import ModelTrainingCronJobInput, RunnerInput
+from fedlearner_webconsole.proto.service_pb2 import GetModelJobGroupResponse
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+
+
+def _get_workflow_definition() -> WorkflowDefinition:
+ return WorkflowDefinition(
+ job_definitions=[JobDefinition(name='nn-model', variables=[Variable(name='load_model_name')])])
+
+
+class ModelTrainingCronJobTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=2, name='project')
+ party = Participant(id=3, name='test', domain_name='fl-test.com')
+ relationship = ProjectParticipant(project_id=2, participant_id=3)
+ job = Job(id=1,
+ name='uuid-nn-model',
+ job_type=JobType.NN_MODEL_TRANINING,
+ state=JobState.COMPLETED,
+ workflow_id=1,
+ project_id=2)
+ workflow = Workflow(id=1, name='workflow', uuid='uuid', state=WorkflowState.COMPLETED, project_id=2)
+ model_job = ModelJob(name='group-v1', workflow_uuid=workflow.uuid, job_name=job.name, group_id=1, project_id=2)
+ group = ModelJobGroup(id=1, name='group', uuid='uuid', latest_version=2, project_id=2)
+ group.set_config(_get_workflow_definition())
+ with db.session_scope() as session:
+ session.add_all([project, party, relationship, job, workflow, model_job, group])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.get_model_job_group')
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.update_model_job_group')
+ @patch('fedlearner_webconsole.mmgr.controller.LaunchModelJob.run')
+ def test_run(self, mock_run, mock_update_group, mock_get_group):
+ context = RunnerContext(index=0,
+ input=RunnerInput(model_training_cron_job_input=ModelTrainingCronJobInput(group_id=1)))
+ mock_run.return_value = True, ''
+ mock_get_group.return_value = GetModelJobGroupResponse(config=_get_workflow_definition(), authorized=False)
+ runner_status, runner_output = ModelTrainingCronJob().run(context)
+ # fail due to peer is not authorized
+ self.assertEqual(runner_status, RunnerStatus.FAILED)
+ self.assertEqual(runner_output.error_message, 'party fl-test.com is not authorized for group 1')
+ # succeeded
+ mock_get_group.return_value = GetModelJobGroupResponse(config=_get_workflow_definition(), authorized=True)
+ runner_status, runner_output = ModelTrainingCronJob().run(context)
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+ mock_run.assert_called_with(project_id=2, group_id=1, version=3)
+ config = _get_workflow_definition()
+ config.job_definitions[0].variables[0].typed_value.MergeFrom(Value(string_value='uuid-nn-model'))
+ config.job_definitions[0].variables[0].value = 'uuid-nn-model'
+ mock_update_group.assert_called_with(config=config, model_job_group_uuid='uuid')
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ self.assertEqual(group.latest_version, 3)
+ self.assertEqual(group.get_config(), config)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/event_listener.py b/web_console_v2/api/fedlearner_webconsole/mmgr/event_listener.py
new file mode 100644
index 000000000..5770a9907
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/event_listener.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.k8s.k8s_cache import Event, EventType, ObjectType
+from fedlearner_webconsole.job.models import Job, JobState
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobType
+from fedlearner_webconsole.mmgr.service import ModelService
+from fedlearner_webconsole.k8s.event_listener import EventListener
+
+
+def _is_model_event(event: Event) -> bool:
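+    # Only FLApp/FedApp MODIFIED or DELETED events are relevant for model creation.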
+ return event.obj_type in [ObjectType.FLAPP, ObjectType.FEDAPP
+ ] and event.event_type in [EventType.MODIFIED, EventType.DELETED]
+
+
+class ModelEventListener(EventListener):
+
+ def update(self, event: Event):
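+        """Creates a Model record when a completed training/evaluation model job is observed,
+        then links the created model back to its model job."""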
+ if not _is_model_event(event):
+ return
+ job_name = event.app_name
+ with db.session_scope() as session:
+            job: Job = session.query(Job).filter_by(name=job_name).first()
+            if job is None:
+                emit_store('job_not_found', 1)
+                logging.warning('[ModelEventListener] job %s is not found', job_name)
+                return
+            logging.debug('[ModelEventListener] job: %s, type: %s, state: %s', job.name, job.job_type, job.state)
+            if not job.is_training_job():
+                logging.debug(f'[ModelEventListener] skip creating model because job {job.name} is not a training job')
+                return
+            if job.state != JobState.COMPLETED:
+                logging.debug(f'[ModelEventListener] skip creating model because job {job.name} is not completed')
+                return
+ model = session.query(Model).filter_by(job_id=job.id).first()
+ if model is not None:
+                logging.debug(
+                    f'[ModelEventListener] skip creating model because a model already exists for job {job.name}')
+ return
+ model_job: ModelJob = session.query(ModelJob).filter_by(job_name=job.name).first()
+ if model_job is None:
+                logging.info(f'[ModelEventListener] skip creating model because {job.name} is not a model job')
+ return
+ if model_job.model_job_type not in [ModelJobType.TRAINING, ModelJobType.EVALUATION]:
+                logging.info(f'[ModelEventListener] skip creating model because model job {model_job.name} '
+                             'is not a training or evaluation job')
+ return
+ service = ModelService(session)
+ service.create_model_from_model_job(model_job=model_job)
+ logging.info(f'[ModelEventListener] model for job {job.name} is created')
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(job_name=job.name).first()
+ model = session.query(Model).filter_by(model_job_id=model_job.id).first()
+ if model is not None:
+ model_job.model_id = model.id
+ session.add(model_job)
+ session.commit()
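The listener above reacts only to MODIFIED/DELETED events for FLApp/FedApp objects and creates a Model once the backing training job completes. A minimal sketch of driving it directly, mirroring the unit test in the next file (the Job and ModelJob rows are assumed to already exist in the database):

```python
# Sketch only: feed a MODIFIED FLAPP event for a completed training job into
# the listener. Assumes a Job named 'train-job' and a matching ModelJob row
# were created beforehand (see the fixtures in event_listener_test.py).
from fedlearner_webconsole.k8s.k8s_cache import Event, EventType, ObjectType
from fedlearner_webconsole.mmgr.event_listener import ModelEventListener

event = Event(app_name='train-job',
              event_type=EventType.MODIFIED,
              obj_type=ObjectType.FLAPP,
              obj_dict={})
# Creates the Model via ModelService and back-fills ModelJob.model_id.
ModelEventListener().update(event)
```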
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/event_listener_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/event_listener_test.py
new file mode 100644
index 000000000..47201ea63
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/event_listener_test.py
@@ -0,0 +1,102 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobType
+from fedlearner_webconsole.mmgr.event_listener import _is_model_event, ModelEventListener
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.k8s.k8s_cache import Event, EventType, ObjectType
+
+
+class UtilsTest(NoWebServerTestCase):
+
+ def test_is_event_relevant(self):
+ self.assertFalse(
+ _is_model_event(Event(app_name='test', event_type=EventType.ADDED, obj_type=ObjectType.FLAPP, obj_dict={})))
+ self.assertFalse(
+ _is_model_event(Event(app_name='test', event_type=EventType.MODIFIED, obj_type=ObjectType.POD,
+ obj_dict={})))
+ self.assertTrue(
+ _is_model_event(
+ Event(app_name='test', event_type=EventType.MODIFIED, obj_type=ObjectType.FLAPP, obj_dict={})))
+
+
+class ListenerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ workflow = Workflow(id=1, name='test-workflow')
+ job = Job(name='test-job',
+ project_id=1,
+ job_type=JobType.NN_MODEL_TRANINING,
+ state=JobState.COMPLETED,
+ workflow_id=1)
+ session.add_all([workflow, job])
+ session.commit()
+
+ @staticmethod
+ def _get_job(session) -> Job:
+ return session.query(Job).filter_by(name='test-job').first()
+
+ @staticmethod
+ def _get_event() -> Event:
+ return Event(app_name='test-job', event_type=EventType.MODIFIED, obj_type=ObjectType.FLAPP, obj_dict={})
+
+ @patch('fedlearner_webconsole.mmgr.service.ModelService.create_model_from_model_job')
+ def test_model_update(self, mock_create_model: MagicMock):
+ event = self._get_event()
+ with db.session_scope() as session:
+ job = self._get_job(session)
+ job.state = JobState.STOPPED
+ session.commit()
+ ModelEventListener().update(event)
+ # not called since job state is stopped
+ mock_create_model.assert_not_called()
+ with db.session_scope() as session:
+ job = self._get_job(session)
+ job.state = JobState.COMPLETED
+ session.commit()
+ ModelEventListener().update(event)
+ # not called since model job is not found
+ mock_create_model.assert_not_called()
+
+ with db.session_scope() as session:
+ model_job = ModelJob(id=1, job_name=job.name, job_id=job.id, model_job_type=ModelJobType.TRAINING)
+ model = Model(id=1, model_job_id=1)
+ session.add_all([model_job, model])
+ session.commit()
+ ModelEventListener().update(event)
+ # create model
+ mock_create_model.assert_called()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ self.assertEqual(model_job.model_id, 1)
+ mock_create_model.reset_mock()
+
+ with db.session_scope() as session:
+ session.add(Model(name=job.name, job_id=job.id))
+ session.commit()
+ ModelEventListener().update(event)
+ # not called due to model is already created
+ mock_create_model.assert_not_called()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/BUILD.bazel
new file mode 100644
index 000000000..dd3fe32a9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/BUILD.bazel
@@ -0,0 +1,30 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "metrics_lib",
+ srcs = ["metrics_inquirer.py"],
+ imports = ["../../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:es_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:job_metrics_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "metrics_lib_test",
+ size = "small",
+ srcs = [
+ "metrics_inquirer_test.py",
+ ],
+ imports = ["../../.."],
+ main = "metrics_inquirer_test.py",
+ deps = [
+ ":metrics_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing/test_data:test_data_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/__init__.py b/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/metrics_inquirer.py b/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/metrics_inquirer.py
new file mode 100644
index 000000000..a69a56ca8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/metrics_inquirer.py
@@ -0,0 +1,187 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import List
+
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.proto.metrics_pb2 import ModelJobMetrics, Metric, ConfusionMatrix
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.es import ElasticSearchClient
+from fedlearner_webconsole.utils.job_metrics import get_feature_importance
+
+_ES_INDEX_NAME = 'apm*'
+_es_client = ElasticSearchClient()
+
+
+def _build_es_query_body(algorithm_type: str, job_name: str, metric_list: List[str]):
+ query = {
+ 'size': 0,
+ 'query': {
+ 'bool': {
+ 'must': [{
+ 'term': {
+ 'labels.k8s_job_name': job_name
+ }
+ }]
+ }
+ },
+ 'aggs': {
+ metric.upper(): {
+ 'filter': {
+ 'term': {
+ 'labels.k8s_job_name': job_name
+ }
+ },
+ 'aggs': {
+ mode.upper(): {
+ 'filter': {
+ 'exists': {
+ 'field': f'values.model.{mode}.{algorithm_type}.{metric}'
+ }
+ },
+ 'aggs': {
+ 'TOP': {
+ 'top_hits': {
+ 'size':
+ 0, # 0 means get all matching hits
+ '_source': [
+ f'values.model.{mode}.{algorithm_type}.{metric}', 'labels', '@timestamp'
+ ]
+ }
+ }
+ }
+ } for mode in ('train', 'eval')
+ }
+ } for metric in metric_list
+ }
+ }
+ return _es_client.search(index=_ES_INDEX_NAME, body=query, request_timeout=500)
+
+
+class TreeMetricsInquirer(object):
+ _CONF_METRIC_LIST = ['tp', 'tn', 'fp', 'fn']
+ _TREE_METRIC_LIST = ['acc', 'auc', 'precision', 'recall', 'f1', 'ks', 'mse', 'msre', 'abs'] + _CONF_METRIC_LIST
+ _ALGORITHM_TYPE = 'tree_vertical'
+
+ def _extract_metric(self, records: List[dict], mode: str, metric_name: str) -> Metric:
+ iter_to_values = {}
+ for record in records:
+ iteration = record['_source']['labels']['iteration']
+ value = record['_source'][f'values.model.{mode}.{self._ALGORITHM_TYPE}.{metric_name}']
+ if iteration not in iter_to_values:
+ iter_to_values[iteration] = []
+ iter_to_values[iteration].append(value)
+ ordered_iters = sorted(iter_to_values.keys())
+ values = [
+ # Avg
+ sum(iter_to_values[iteration]) / len(iter_to_values[iteration]) for iteration in ordered_iters
+ ]
+ return Metric(steps=ordered_iters, values=values)
+
+ def _extract_confusion_matrix(self, metrics: Metric) -> ConfusionMatrix:
+
+ def get_last_value(metric_name: str):
+ metric = metrics.get(metric_name)
+ if metric is not None and len(metric.values) > 0:
+ return int(metric.values[-1])
+ return 0
+
+ matrix = ConfusionMatrix(
+ tp=get_last_value('tp'),
+ tn=get_last_value('tn'),
+ fp=get_last_value('fp'),
+ fn=get_last_value('fn'),
+ )
+ # remove confusion relevant metrics
+ for key in self._CONF_METRIC_LIST:
+ metrics.pop(key)
+ return matrix
+
+ def _set_confusion_metric(self, metrics: ModelJobMetrics):
+
+ def is_training() -> bool:
+ iter_vals = metrics.train.get('tp')
+ if iter_vals is None:
+ return False
+ return len(iter_vals.values) > 0
+
+ if is_training():
+ confusion_matrix = self._extract_confusion_matrix(metrics.train)
+ else:
+ confusion_matrix = self._extract_confusion_matrix(metrics.eval)
+ metrics.confusion_matrix.CopyFrom(confusion_matrix)
+ return metrics
+
+ def query(self, job: Job, need_feature_importance: bool = False) -> ModelJobMetrics:
+ job_name = job.name
+ metrics = ModelJobMetrics()
+ query_result = _build_es_query_body(self._ALGORITHM_TYPE, job_name, self._TREE_METRIC_LIST)
+ if 'aggregations' not in query_result:
+ logging.warning(f'[METRICS] no aggregations found, job_name = {job_name}, result = {query_result}')
+ return metrics
+ aggregations = query_result['aggregations']
+ for name in self._TREE_METRIC_LIST:
+ train_item = aggregations[name.upper()]['TRAIN']['TOP']['hits']['hits']
+ eval_item = aggregations[name.upper()]['EVAL']['TOP']['hits']['hits']
+ if len(train_item) > 0:
+ metrics.train[name].MergeFrom(self._extract_metric(train_item, 'train', name))
+ if len(eval_item) > 0:
+ metrics.eval[name].MergeFrom(self._extract_metric(eval_item, 'eval', name))
+ self._set_confusion_metric(metrics)
+ if need_feature_importance:
+ metrics.feature_importance.update(get_feature_importance(job))
+ return metrics
+
+
+class NnMetricsInquirer(object):
+ _NN_METRIC_LIST = ['auc', 'loss']
+ _ALGORITHM_TYPE = 'nn_vertical'
+
+ def _extract_metric(self, records: List[dict], mode: str, metric_name: str) -> Metric:
+ timestamp_to_values = {}
+ for record in records:
+ timestamp_str = record['_source']['@timestamp']
+ timestamp = to_timestamp(timestamp_str) * 1000
+ value = record['_source'][f'values.model.{mode}.{self._ALGORITHM_TYPE}.{metric_name}']
+ if timestamp not in timestamp_to_values:
+ timestamp_to_values[timestamp] = []
+ timestamp_to_values[timestamp].append(value)
+ ordered_iters = sorted(timestamp_to_values.keys())
+ values = [
+ # Avg
+ sum(timestamp_to_values[timestamp]) / len(timestamp_to_values[timestamp]) for timestamp in ordered_iters
+ ]
+ return Metric(steps=ordered_iters, values=values)
+
+ def query(self, job: Job) -> ModelJobMetrics:
+ job_name = job.name
+ metrics = ModelJobMetrics()
+ query_result = _build_es_query_body(self._ALGORITHM_TYPE, job_name, self._NN_METRIC_LIST)
+ if 'aggregations' not in query_result:
+ logging.warning(f'[METRICS] no aggregations found, job_name = {job_name}, result = {query_result}')
+ return metrics
+ aggregations = query_result['aggregations']
+ for name in self._NN_METRIC_LIST:
+ train_item = aggregations[name.upper()]['TRAIN']['TOP']['hits']['hits']
+ eval_item = aggregations[name.upper()]['EVAL']['TOP']['hits']['hits']
+ if len(train_item) > 0:
+ metrics.train[name].MergeFrom(self._extract_metric(train_item, 'train', name))
+ if len(eval_item) > 0:
+ metrics.eval[name].MergeFrom(self._extract_metric(eval_item, 'eval', name))
+ return metrics
+
+
+tree_metrics_inquirer = TreeMetricsInquirer()
+nn_metrics_inquirer = NnMetricsInquirer()
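Despite its name, `_build_es_query_body` also executes the search, so both inquirers only ever walk the `aggregations` part of the response: one bucket per metric name, split into `TRAIN`/`EVAL` sub-buckets whose `top_hits` documents expose the metric value and labels in `_source`. A hedged sketch of the per-metric shape the parsing above relies on (values are illustrative; a real response carries one such entry for every metric in the list):

```python
# Illustrative only: the minimal per-metric aggregation bucket that
# TreeMetricsInquirer.query() indexes as aggregations['AUC']['TRAIN']['TOP'].
# NnMetricsInquirer reads '@timestamp' instead of labels.iteration.
fake_aggregation_bucket = {
    'AUC': {
        'TRAIN': {
            'TOP': {
                'hits': {
                    'hits': [{
                        '_source': {
                            'labels': {'iteration': 1},
                            'values.model.train.tree_vertical.auc': 0.80,
                            '@timestamp': '2022-05-10T00:00:00Z',
                        }
                    }]
                }
            }
        },
        'EVAL': {'TOP': {'hits': {'hits': []}}},
    },
}
```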
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/metrics_inquirer_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/metrics_inquirer_test.py
new file mode 100644
index 000000000..8d559304d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/metrics/metrics_inquirer_test.py
@@ -0,0 +1,107 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch
+
+from testing.common import BaseTestCase
+from testing.test_data import es_query_result
+from fedlearner_webconsole.mmgr.metrics.metrics_inquirer import tree_metrics_inquirer, nn_metrics_inquirer
+from fedlearner_webconsole.job.models import Job, JobType
+from fedlearner_webconsole.utils.proto import to_dict
+
+_EXPECTED_TREE_METRICS_RESULT = {
+ 'train': {
+ 'ks': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.47770564314760644, 0.5349813321918623, 0.5469192171410906, 0.5596894247461416, 0.5992009702504102,
+ 0.6175715202967825, 0.6366317091151221, 0.6989964566835509, 0.7088535349932226, 0.7418848541057288
+ ]
+ },
+ 'recall': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.40186915887850466, 0.4252336448598131, 0.45794392523364486, 0.46261682242990654, 0.5233644859813084,
+ 0.514018691588785, 0.5093457943925234, 0.5373831775700935, 0.5467289719626168, 0.5654205607476636
+ ]
+ },
+ 'acc': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [0.857, 0.862, 0.868, 0.872, 0.886, 0.883, 0.884, 0.895, 0.896, 0.902]
+ },
+ 'auc': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.8011640626857863, 0.8377684240565029, 0.8533328577203871, 0.860663242253454, 0.8797977455946351,
+ 0.8921428741290338, 0.9041610187629308, 0.9179270409740553, 0.928827495184419, 0.9439282062257736
+ ]
+ },
+ 'precision': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.8514851485148515, 0.8584905660377359, 0.8596491228070176, 0.8839285714285714, 0.9032258064516129,
+ 0.8943089430894309, 0.9083333333333333, 0.9504132231404959, 0.9435483870967742, 0.9603174603174603
+ ]
+ },
+ 'f1': {
+ 'steps': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
+ 'values': [
+ 0.546031746031746, 0.56875, 0.5975609756097561, 0.607361963190184, 0.6627218934911242,
+ 0.6528189910979227, 0.6526946107784432, 0.6865671641791044, 0.6923076923076923, 0.711764705882353
+ ]
+ }
+ },
+ 'confusion_matrix': {
+ 'tp': 121,
+ 'tn': 781,
+ 'fp': 5,
+ 'fn': 93
+ },
+ 'feature_importance': {
+ 'x': 0.3
+ },
+ 'eval': {}
+}
+
+
+class MetricsInquirerTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.mmgr.metrics.metrics_inquirer.get_feature_importance')
+ @patch('fedlearner_webconsole.mmgr.metrics.metrics_inquirer._build_es_query_body')
+ def test_query_tree_metrics(self, mock_es_query, mock_get_importance):
+ mock_es_query.return_value = es_query_result.fake_es_query_tree_metrics_result_v2
+ mock_get_importance.return_value = {'x': 0.3}
+ job = Job(name='test-job', job_type=JobType.TREE_MODEL_TRAINING)
+ metrics = tree_metrics_inquirer.query(job, need_feature_importance=True)
+ metrics_dict = to_dict(metrics)
+ self.assertIn('train', metrics_dict)
+ self.assertEqual(metrics_dict, _EXPECTED_TREE_METRICS_RESULT)
+
+ @patch('fedlearner_webconsole.mmgr.metrics.metrics_inquirer._build_es_query_body')
+ def test_query_nn_vertical_metrics(self, mock_es_query):
+ mock_es_query.return_value = es_query_result.fake_es_query_nn_metrics_result_v2
+ job = Job(name='test-job', job_type=JobType.NN_MODEL_TRANINING)
+ metrics = nn_metrics_inquirer.query(job)
+ metrics_dict = to_dict(metrics)
+ self.assertIn('train', metrics_dict)
+ self.assertEqual(1, len(metrics_dict['train']['loss']['values']))
+ self.assertEqual(1, len(metrics_dict['train']['auc']['values']))
+ self.assertIn(5.694229602813721, metrics_dict['train']['loss']['values'])
+ self.assertIn(0.6585884094238281, metrics_dict['train']['auc']['values'])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_apis.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_apis.py
new file mode 100644
index 000000000..a2e8d5e21
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_apis.py
@@ -0,0 +1,241 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from http import HTTPStatus
+from sqlalchemy.orm import joinedload
+from flask_restful import Resource
+from typing import Optional
+from webargs.flaskparser import use_args, use_kwargs
+from marshmallow import fields, validate
+
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.proto.filtering_pb2 import FilterOp
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.mmgr.models import Model, ModelJob
+from fedlearner_webconsole.mmgr.service import ModelService, get_model
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.utils.flask_utils import FilterExpField, make_flask_response
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.proto.audit_pb2 import Event
+
+
+class ModelsApi(Resource):
+ FILTER_FIELDS = {
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'group_id': SupportedField(type=FieldType.NUMBER, ops={FilterOp.EQUAL: None}),
+ 'algorithm_type': SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ }
+
+ def __init__(self):
+ self._filter_builder = FilterBuilder(model_class=Model, supported_fields=self.FILTER_FIELDS)
+
+ @credentials_required
+ @use_args(
+ {
+ 'group_id':
+ fields.Integer(required=False, load_default=None),
+ 'keyword':
+ fields.String(required=False, load_default=None),
+ 'algorithm_type':
+ fields.String(
+ required=False, load_default=None, validate=validate.OneOf([t.name for t in AlgorithmType])),
+ 'page':
+ fields.Integer(required=False, load_default=None),
+ 'page_size':
+ fields.Integer(required=False, load_default=None),
+ 'filter':
+ FilterExpField(required=False, load_default=None),
+ },
+ location='query')
+ def get(self, params: dict, project_id: int):
+ """Get the list of models.
+ ---
+ tags:
+ - mmgr
+ description: Get the list of models
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: query
+ name: group_id
+ schema:
+ type: integer
+ - in: query
+ name: keyword
+ schema:
+ type: string
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: filter
+ schema:
+ type: string
+ responses:
+ 200:
+ description: the list of models
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelPb'
+ """
+ with db.session_scope() as session:
+ query = session.query(Model).options(
+ joinedload(Model.job).load_only(Job.name, Job.workflow_id).options(
+ joinedload(Job.workflow).load_only(Workflow.name)),
+ joinedload(Model.model_job).load_only(ModelJob.name))
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ if params['group_id'] is not None:
+ query = query.filter_by(group_id=params['group_id'])
+ if params['keyword'] is not None:
+ query = query.filter(Model.name.like(f'%{params["keyword"]}%'))
+ if params['algorithm_type'] is not None:
+ query = query.filter(Model.algorithm_type == AlgorithmType[params['algorithm_type']])
+ if params['filter']:
+ try:
+ query = self._filter_builder.build_query(query, params['filter'])
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ query = query.order_by(Model.created_at.desc())
+ pagination = paginate(query, params['page'], params['page_size'])
+ data = [d.to_proto() for d in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+
+class ModelApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, model_id: int):
+ """Get the model.
+ ---
+ tags:
+ - mmgr
+ description: get the model.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the model
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelPb'
+ """
+ with db.session_scope() as session:
+ model = get_model(project_id=project_id, model_id=model_id, session=session)
+ return make_flask_response(data=model.to_proto())
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL, op_type=Event.OperationType.UPDATE)
+ @use_kwargs({'comment': fields.Str(required=False, load_default=None)}, location='json')
+ def patch(self, comment: Optional[str], project_id: int, model_id: int):
+ """Patch the model.
+ ---
+ tags:
+ - mmgr
+ description: patch the model.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: False
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ comment:
+ type: string
+ responses:
+ 200:
+ description: detail of the model
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelPb'
+ """
+ with db.session_scope() as session:
+ model = get_model(project_id=project_id, model_id=model_id, session=session)
+ if comment is not None:
+ model.comment = comment
+ session.commit()
+ return make_flask_response(model.to_proto())
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL, op_type=Event.OperationType.DELETE)
+ def delete(self, project_id: int, model_id: int):
+ """Delete the model.
+ ---
+ tags:
+ - mmgr
+ description: delete the model
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 204:
+ description: delete the model successfully
+ """
+ with db.session_scope() as session:
+ model = get_model(project_id=project_id, model_id=model_id, session=session)
+ ModelService(session).delete(model.id)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+def initialize_mmgr_model_apis(api):
+ api.add_resource(ModelsApi, '/projects/<int:project_id>/models')
+ api.add_resource(ModelApi, '/projects/<int:project_id>/models/<int:model_id>')
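The `filter` query parameter takes the console's filtering expression language and must be URL-encoded; the expression below is the same one exercised in the tests that follow.

```python
# Sketch: build the list-models URL with a filter expression. How the request
# is issued (Flask test client, requests, curl) is left to the caller.
import urllib.parse

filter_exp = urllib.parse.quote('(and(group_id=1)(name~="m"))')
url = f'/api/v2/projects/1/models?filter={filter_exp}&page=1&page_size=10'
# -> /api/v2/projects/1/models?filter=%28and%28group_id%3D1%29%28name~%3D%22m%22%29%29&page=1&page_size=10
```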
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_apis_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_apis_test.py
new file mode 100644
index 000000000..d3eaf7d00
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_apis_test.py
@@ -0,0 +1,127 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import urllib.parse
+from http import HTTPStatus
+from datetime import datetime
+
+from testing.common import BaseTestCase
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobType, ModelType
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.job.models import Job, JobType
+
+
+class ModelsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ m1 = Model(name='m1', project_id=1, group_id=1, algorithm_type=AlgorithmType.NN_VERTICAL)
+ m2 = Model(name='m2', project_id=1, group_id=2, algorithm_type=AlgorithmType.TREE_VERTICAL)
+ m3 = Model(name='m3', project_id=1, group_id=2, algorithm_type=AlgorithmType.TREE_VERTICAL)
+ session.add_all([m1, m2, m3])
+ session.commit()
+
+ def test_get_models(self):
+ resp = self.get_helper('/api/v2/projects/1/models')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 3)
+ resp = self.get_helper('/api/v2/projects/1/models?group_id=2')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['name'], 'm2')
+ resp = self.get_helper('/api/v2/projects/1/models?keyword=1')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'm1')
+ filter_param = urllib.parse.quote('(and(group_id=1)(name~="m"))')
+ resp = self.get_helper(f'/api/v2/projects/1/models?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'm1')
+ resp = self.get_helper('/api/v2/projects/1/models?algorithm_type=TREE_VERTICAL')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted([d['name'] for d in data]), ['m2', 'm3'])
+
+
+class ModelApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ job = Job(id=3, name='job', job_type=JobType.NN_MODEL_TRANINING, workflow_id=2, project_id=1)
+ workflow = Workflow(id=2, name='workflow', project_id=1)
+ model_job = ModelJob(id=1, name='model_job', model_job_type=ModelJobType.TRAINING, group_id=1)
+ model = Model(id=1,
+ name='m1',
+ uuid='uuid',
+ project_id=1,
+ group_id=1,
+ job_id=3,
+ model_job_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ model_type=ModelType.NN_MODEL,
+ version=1,
+ created_at=datetime(2022, 5, 10, 0, 0, 0),
+ updated_at=datetime(2022, 5, 10, 0, 0, 0))
+ session.add_all([model, job, workflow, model_job])
+ session.commit()
+
+ def test_get_model(self):
+ resp = self.get_helper('/api/v2/projects/1/models/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ resp, {
+ 'id': 1,
+ 'name': 'm1',
+ 'uuid': 'uuid',
+ 'algorithm_type': 'NN_VERTICAL',
+ 'model_path': '',
+ 'comment': '',
+ 'group_id': 1,
+ 'project_id': 1,
+ 'job_id': 3,
+ 'model_job_id': 1,
+ 'version': 1,
+ 'workflow_id': 2,
+ 'workflow_name': 'workflow',
+ 'job_name': 'job',
+ 'model_job_name': 'model_job',
+ 'created_at': 1652140800,
+ 'updated_at': 1652140800
+ })
+
+ def test_patch_model(self):
+ resp = self.patch_helper('/api/v2/projects/1/models/1', data={'comment': 'comment'})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ model: Model = session.query(Model).get(1)
+ self.assertEqual(model.comment, 'comment')
+
+ def test_delete_model(self):
+ resp = self.delete_helper('/api/v2/projects/1/models/1')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ model = session.query(Model).execution_options(include_deleted=True).get(1)
+ self.assertIsNotNone(model.deleted_at)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_apis.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_apis.py
new file mode 100644
index 000000000..fa247fe39
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_apis.py
@@ -0,0 +1,1085 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import logging
+import tempfile
+from http import HTTPStatus
+from flask import send_file
+from flask_restful import Resource
+from typing import Optional, List
+from webargs.flaskparser import use_args, use_kwargs
+from marshmallow import Schema, post_load, fields, validate
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.services import BatchService
+from fedlearner_webconsole.exceptions import NotFoundException, ResourceConflictException, InternalException, \
+ InvalidArgumentException, NoAccessException, UnauthorizedException
+from fedlearner_webconsole.utils.sorting import SorterBuilder, SortExpression, parse_expression
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowExternalState
+from fedlearner_webconsole.workflow_template.service import dict_to_workflow_definition
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.mmgr.controller import CreateModelJob, start_model_job, stop_model_job, \
+ ModelJobController, ModelJobGroupController
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobGroup, ModelJobType, ModelJobRole, \
+ is_federated, AuthStatus, GroupAutoUpdateStatus, GroupAuthFrontendStatus, ModelJobStatus
+from fedlearner_webconsole.mmgr.service import ModelJobService, ModelJobGroupService, get_sys_template_id, \
+ get_model_job, get_project, get_participant
+from fedlearner_webconsole.mmgr.model_job_configer import ModelJobConfiger, set_load_model_name
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.utils.decorators.pp_flask import input_validator
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.flask_utils import make_flask_response, get_current_user, FilterExpField, \
+ FilterExpression
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterOp, SimpleExpression, FilterBuilder
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.scheduler.scheduler import scheduler
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.mmgr_pb2 import PeerModelJobPb, ModelJobGlobalConfig
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+from fedlearner_webconsole.flag.models import Flag
+
+
+def _check_model_job_global_config_enable(project_id: int) -> bool:
+ with db.session_scope() as session:
+ participants = session.query(Project).get(project_id).participants
+ flag = True
+ for p in participants:
+ client = SystemServiceClient.from_participant(domain_name=p.domain_name)
+ resp = client.list_flags()
+ if not resp.get(Flag.MODEL_JOB_GLOBAL_CONFIG_ENABLED.name):
+ flag = False
+ return flag
+
+
+class CreateModelJobParams(Schema):
+ name = fields.Str(required=True)
+ group_id = fields.Integer(required=False, load_default=None)
+ model_job_type = fields.Str(required=True,
+ validate=validate.OneOf([
+ ModelJobType.TRAINING.name, ModelJobType.EVALUATION.name,
+ ModelJobType.PREDICTION.name
+ ]))
+ algorithm_type = fields.Str(required=True,
+ validate=validate.OneOf([
+ AlgorithmType.TREE_VERTICAL.name, AlgorithmType.NN_VERTICAL.name,
+ AlgorithmType.NN_HORIZONTAL.name
+ ]))
+ algorithm_id = fields.Integer(required=False, load_default=None)
+ eval_model_job_id = fields.Integer(required=False, load_default=None)
+ model_id = fields.Integer(required=False, load_default=None)
+ dataset_id = fields.Integer(required=False, load_default=None)
+ data_batch_id = fields.Integer(required=False, load_default=None)
+ config = fields.Dict(required=False, load_default={})
+ comment = fields.Str(required=False, load_default=None)
+ global_config = fields.Dict(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['config'] = dict_to_workflow_definition(data['config'])
+ data['model_job_type'] = ModelJobType[data['model_job_type']]
+ data['algorithm_type'] = AlgorithmType[data['algorithm_type']]
+ if data.get('eval_model_job_id') is not None:
+ with db.session_scope() as session:
+ model = session.query(Model).filter_by(model_job_id=data.get('eval_model_job_id')).first()
+ data['model_id'] = model.id
+ if data['global_config'] is not None:
+ data['global_config'] = ParseDict(data['global_config'], ModelJobGlobalConfig())
+ return data
+
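At load time the schema converts enum names into `ModelJobType`/`AlgorithmType` members and the `config` dict into a `WorkflowDefinition`; the database is only touched when `eval_model_job_id` is set. A hedged example of a training payload it accepts (field values are illustrative):

```python
# Sketch: loading a training request body through the schema defined above.
params = CreateModelJobParams().load({
    'name': 'nn-train-job',
    'group_id': 1,
    'model_job_type': 'TRAINING',
    'algorithm_type': 'NN_VERTICAL',
    'algorithm_id': 2,
    'dataset_id': 3,
})
# After post_load: params['model_job_type'] is ModelJobType.TRAINING,
# params['algorithm_type'] is AlgorithmType.NN_VERTICAL, and
# params['config'] is an (empty) WorkflowDefinition.
# A 'global_config' dict, when present, is parsed into ModelJobGlobalConfig.
```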
+
+# TODO(hangweiqiang): remove dataset_id in parameters
+class ConfigModelJobParams(Schema):
+ algorithm_id = fields.Integer(required=False, load_default=None)
+ dataset_id = fields.Integer(required=False, load_default=None)
+ config = fields.Dict(required=False, load_default=None)
+ global_config = fields.Dict(required=False, load_default=None)
+ comment = fields.Str(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ if data['config'] is not None:
+ data['config'] = dict_to_workflow_definition(data['config'])
+ if data['global_config'] is not None:
+ data['global_config'] = ParseDict(data['global_config'], ModelJobGlobalConfig())
+ return data
+
+
+class ListModelJobsSchema(Schema):
+ group_id = fields.Integer(required=False, load_default=None)
+ keyword = fields.String(required=False, load_default=None)
+ types = fields.List(fields.String(required=False,
+ validate=validate.OneOf([
+ ModelJobType.TRAINING.name, ModelJobType.EVALUATION.name,
+ ModelJobType.PREDICTION.name
+ ])),
+ required=False,
+ load_default=None)
+ configured = fields.Boolean(required=False, load_default=None)
+ algorithm_types = fields.List(fields.String(required=True,
+ validate=validate.OneOf([
+ AlgorithmType.TREE_VERTICAL.name, AlgorithmType.NN_VERTICAL.name,
+ AlgorithmType.NN_HORIZONTAL.name
+ ])),
+ required=False,
+ load_default=None)
+ states = fields.List(fields.String(required=False,
+ validate=validate.OneOf([s.name for s in WorkflowExternalState])),
+ required=False,
+ load_default=None)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
+ filter_exp = FilterExpField(data_key='filter', required=False, load_default=None)
+ sorter_exp = fields.String(required=False, load_default=None, data_key='order_by')
+
+ @post_load()
+ def make(self, data, **kwargs):
+ if data['types'] is not None:
+ data['types'] = [ModelJobType[t] for t in data['types']]
+ if data['states'] is not None:
+ data['states'] = [WorkflowExternalState[s] for s in data['states']]
+ if data['algorithm_types'] is not None:
+ data['algorithm_types'] = [AlgorithmType[t] for t in data['algorithm_types']]
+ return data
+
+
+class ModelJobApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, model_job_id: int):
+ """Get the model job by id
+ ---
+ tags:
+ - mmgr
+ description: get the model job by id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the model job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobPb'
+ """
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ ModelJobService(session).update_model_job_status(model_job)
+ ModelJobController(session, project_id).update_participants_auth_status(model_job)
+ session.commit()
+ return make_flask_response(model_job.to_proto())
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB, op_type=Event.OperationType.UPDATE)
+ @use_args(ConfigModelJobParams(), location='json')
+ def put(self, params: dict, project_id: int, model_job_id: int):
+ """Update the model job
+ ---
+ tags:
+ - mmgr
+ description: update the model job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/ConfigModelJobParams'
+ responses:
+ 200:
+ description: detail of the model job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobPb'
+ """
+ dataset_id = params['dataset_id']
+ algorithm_id = params['algorithm_id']
+ config = params['config']
+ global_config = params['global_config']
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ model_job.algorithm_id = algorithm_id
+ if dataset_id is not None:
+ model_job.dataset_id = dataset_id
+ model_job.comment = params['comment']
+ if global_config is not None:
+ configer = ModelJobConfiger(session=session,
+ model_job_type=model_job.model_job_type,
+ algorithm_type=model_job.algorithm_type,
+ project_id=project_id)
+ domain_name = SettingService.get_system_info().pure_domain_name
+ config = configer.get_config(dataset_id=model_job.dataset_id,
+ model_id=model_job.model_id,
+ model_job_config=global_config.global_config[domain_name])
+ ModelJobService(session).config_model_job(model_job, config=config, create_workflow=False)
+ model_job.role = ModelJobRole.PARTICIPANT
+ model_job.creator_username = get_current_user().username
+ # Compatible with old versions, use PUT for authorization
+ ModelJobService.update_model_job_auth_status(model_job=model_job, auth_status=AuthStatus.AUTHORIZED)
+ ModelJobController(session, project_id).inform_auth_status_to_participants(model_job)
+
+ session.commit()
+ scheduler.wakeup(model_job.workflow_id)
+ return make_flask_response(model_job.to_proto())
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB, op_type=Event.OperationType.UPDATE)
+ @use_kwargs(
+ {
+ 'metric_is_public':
+ fields.Boolean(required=False, load_default=None),
+ 'auth_status':
+ fields.String(required=False, load_default=None, validate=validate.OneOf([s.name for s in AuthStatus])),
+ 'comment':
+ fields.String(required=False, load_default=None)
+ },
+ location='json')
+ def patch(self, project_id: int, model_job_id: int, metric_is_public: Optional[bool], auth_status: Optional[str],
+ comment: Optional[str]):
+ """Patch the attribute of model job
+ ---
+ tags:
+ - mmgr
+ description: change the attribute of the model job, e.g. whether the metric is public
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ metric_is_public:
+ type: boolean
+ auth_status:
+ type: string
+ comment:
+ type: string
+ responses:
+ 200:
+ description: detail of the model job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobPb'
+ """
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ if metric_is_public is not None:
+ model_job.metric_is_public = metric_is_public
+ if auth_status is not None:
+ ModelJobService.update_model_job_auth_status(model_job=model_job, auth_status=AuthStatus[auth_status])
+ ModelJobController(session, project_id).inform_auth_status_to_participants(model_job)
+ model_job.creator_username = get_current_user().username
+ if comment is not None:
+ model_job.comment = comment
+ session.commit()
+ return make_flask_response(model_job.to_proto())
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB, op_type=Event.OperationType.DELETE)
+ def delete(self, project_id: int, model_job_id: int):
+ """Delete the model job
+ ---
+ tags:
+ - mmgr
+ description: delete the model job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 204:
+ description: delete the model job successfully
+ 409:
+ description: model job cannot be deleted
+ """
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ if not model_job.is_deletable():
+ raise ResourceConflictException(f'model job cannot be deleted because it is in state {model_job.state}')
+ ModelJobService(session).delete(model_job.id)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class StartModelJobApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB, op_type=Event.OperationType.UPDATE)
+ def post(self, project_id: int, model_job_id: int):
+ """Start the model job
+ ---
+ tags:
+ - mmgr
+ description: start the model job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: start the model job successfully
+ """
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(id=model_job_id, project_id=project_id).first()
+ if model_job is None:
+ raise NotFoundException(f'[StartModelJobApi] model job {model_job_id} is not found')
+ start_model_job(model_job_id=model_job_id)
+ return make_flask_response(status=HTTPStatus.OK)
+
+
+class StopModelJobApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB, op_type=Event.OperationType.UPDATE)
+ def post(self, project_id: int, model_job_id: int):
+ """Stop the model job
+ ---
+ tags:
+ - mmgr
+ description: stop the model job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: stop the model job successfully
+ """
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(id=model_job_id, project_id=project_id).first()
+ if model_job is None:
+ raise NotFoundException(f'[StopModelJobApi] model job {model_job_id} is not found')
+ stop_model_job(model_job_id=model_job_id)
+ return make_flask_response(status=HTTPStatus.OK)
+
+
+class ModelJobMetricsApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, model_job_id: int):
+ """Get the model job metrics by id
+ ---
+ tags:
+ - mmgr
+ description: get the model job metrics by id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the model job metrics
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobMetrics'
+ 500:
+ description: error occurred when querying metrics for the model job
+ """
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ try:
+ metrics = ModelJobService(session).query_metrics(model_job)
+ except ValueError as e:
+ logging.warning(f'[Model] error when querying metrics for model job {model_job_id}')
+ raise InternalException(details=str(e)) from e
+ return make_flask_response(metrics), HTTPStatus.OK
+
+
+class ModelJobResultsApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, model_job_id: int):
+ """Get the model job result by id
+ ---
+ tags:
+ - mmgr
+ description: get the model job result by id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: model_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: file of the model job results
+ content:
+ application/json:
+ schema:
+ type: string
+ 204:
+ description: the output path does not exist
+ """
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ output_path = model_job.get_output_path()
+ file_manager = FileManager()
+ if file_manager.exists(output_path):
+ with tempfile.NamedTemporaryFile(suffix='.tar') as temp_file:
+ FileOperator().archive_to([output_path], temp_file.name)
+ return send_file(temp_file.name,
+ attachment_filename=f'{model_job.name}_result.tar',
+ mimetype='application/x-tar',
+ as_attachment=True)
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+def _validate_create_model_job_params(project_id: int,
+ name: str,
+ model_job_type: ModelJobType,
+ model_id: int,
+ group_id: Optional[int] = None):
+ if model_job_type == ModelJobType.TRAINING and model_id is not None:
+ raise InvalidArgumentException(details='model id must be None for training job')
+ if model_job_type in [ModelJobType.EVALUATION, ModelJobType.PREDICTION] and model_id is None:
+ raise InvalidArgumentException(details='model id must not be None for eval or predict job')
+ if model_job_type == ModelJobType.TRAINING and group_id is None:
+ raise InvalidArgumentException(details='training model job must be in a group')
+ if model_job_type in [ModelJobType.EVALUATION, ModelJobType.PREDICTION] and group_id is not None:
+ raise InvalidArgumentException(details='eval or predict job must not be in a group')
+ with db.session_scope() as session:
+ if group_id is not None:
+ group = session.query(ModelJobGroup).filter_by(project_id=project_id, id=group_id).first()
+ if group is None:
+ raise InvalidArgumentException(f'group {group_id} is not found in project {project_id}')
+ if model_id:
+ model = session.query(Model).filter_by(project_id=project_id, id=model_id).first()
+ if model is None:
+ raise InvalidArgumentException(f'model {model_id} is not found in project {project_id}')
+ model_job = session.query(ModelJob).filter_by(name=name).first()
+ if model_job is not None:
+ raise ResourceConflictException(f'model job {name} already exists')
+
+
+def _build_model_job_configured_query(exp: SimpleExpression):
+ if exp.bool_value:
+ return Workflow.config.isnot(None)
+ return Workflow.config.is_(None)
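The `configured` filter resolves against the joined `Workflow.config` column rather than a `ModelJob` column; a brief sketch of the expression it receives (constructing `SimpleExpression` by keyword is an assumption about the filtering proto, which is not shown in this diff):

```python
# Sketch: how the custom builder maps a `configured` filter to a SQL clause.
from fedlearner_webconsole.utils.filtering import FilterOp, SimpleExpression

exp = SimpleExpression(field='configured', op=FilterOp.EQUAL, bool_value=True)
clause = _build_model_job_configured_query(exp)  # Workflow.config IS NOT NULL
# ModelJobsApi.get() applies it after outer-joining Workflow on workflow_uuid.
```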
+
+
+# TODO(hangweiqiang): use filtering expression
+class ModelJobsApi(Resource):
+ FILTER_FIELDS = {
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'algorithm_type': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: None}),
+ 'model_job_type': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: None}),
+ 'configured': SupportedField(type=FieldType.BOOL, ops={FilterOp.EQUAL: _build_model_job_configured_query}),
+ 'role': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: None}),
+ 'status': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: None}),
+ 'auth_status': SupportedField(type=FieldType.STRING, ops={FilterOp.IN: None})
+ }
+
+ SORTER_FIELDS = ['created_at']
+
+ def __init__(self):
+ self._filter_builder = FilterBuilder(model_class=ModelJob, supported_fields=self.FILTER_FIELDS)
+ self._sorter_builder = SorterBuilder(model_class=ModelJob, supported_fields=self.SORTER_FIELDS)
+
+ @credentials_required
+ @use_kwargs(ListModelJobsSchema(), location='query')
+ def get(self, project_id: int, group_id: Optional[int], keyword: Optional[str], types: Optional[List[ModelJobType]],
+ configured: Optional[bool], algorithm_types: Optional[List[AlgorithmType]],
+ states: Optional[List[WorkflowExternalState]], page: Optional[int], page_size: Optional[int],
+ filter_exp: Optional[FilterExpression], sorter_exp: Optional[str]):
+ """Get the list of model jobs
+ ---
+ tags:
+ - mmgr
+ description: get the list of model jobs
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: query
+ name: group_id
+ schema:
+ type: integer
+ - in: query
+ name: keyword
+ schema:
+ type: string
+ - in: query
+ name: types
+ schema:
+ type: array
+ items:
+ type: string
+ - in: query
+ name: algorithm_types
+ schema:
+ type: array
+ items:
+ type: string
+ - in: query
+ name: states
+ schema:
+ type: array
+ items:
+ type: string
+ - in: query
+ name: configured
+ schema:
+ type: boolean
+ - in: query
+ name: filter
+ schema:
+ type: string
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: list of model jobs
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobRef'
+ """
+ # update auth_status and participants_info of old data
+ with db.session_scope() as session:
+ model_jobs = session.query(ModelJob).filter_by(participants_info=None, project_id=project_id).all()
+ if model_jobs is not None:
+ participants = ParticipantService(session).get_participants_by_project(project_id)
+ participants_info = ParticipantsInfo(participants_map={
+ p.pure_domain_name(): ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name) for p in participants
+ })
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ participants_info.participants_map[pure_domain_name].auth_status = AuthStatus.AUTHORIZED.name
+ for model_job in model_jobs:
+ model_job.auth_status = AuthStatus.AUTHORIZED
+ model_job.set_participants_info(participants_info)
+ session.commit()
+ with db.session_scope() as session:
+ query = session.query(ModelJob)
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ if group_id is not None:
+ query = query.filter_by(group_id=group_id)
+ if types is not None:
+ query = query.filter(ModelJob.model_job_type.in_(types))
+ if algorithm_types is not None:
+ query = query.filter(ModelJob.algorithm_type.in_(algorithm_types))
+ if keyword is not None:
+ query = query.filter(ModelJob.name.like(f'%{keyword}%'))
+ if configured is not None:
+ if configured:
+ query = query.join(ModelJob.workflow).filter(Workflow.config.isnot(None))
+ else:
+ query = query.join(ModelJob.workflow).filter(Workflow.config.is_(None))
+ if filter_exp:
+ try:
+ query = query.outerjoin(Workflow, Workflow.uuid == ModelJob.workflow_uuid)
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ try:
+ if sorter_exp is not None:
+ sorter_exp = parse_expression(sorter_exp)
+ else:
+ sorter_exp = SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ pagination = paginate(query, page, page_size)
+ model_jobs = pagination.get_items()
+ for model_job in model_jobs:
+ ModelJobService(session).update_model_job_status(model_job)
+ if states is not None:
+ model_jobs = [m for m in model_jobs if m.state in states]
+ data = [m.to_ref() for m in model_jobs]
+ session.commit()
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB, op_type=Event.OperationType.CREATE)
+ @use_args(CreateModelJobParams(), location='json')
+ def post(self, params: dict, project_id: int):
+ """Create a model job
+ ---
+ tags:
+ - mmgr
+ description: create a model job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/CreateModelJobParams'
+ responses:
+ 201:
+ description: detail of the model job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobPb'
+ 500:
+ description: error occurred when creating the model job
+ """
+ name = params['name']
+ config = params['config']
+ model_job_type = params['model_job_type']
+ algorithm_type = params['algorithm_type']
+ model_id = params['model_id']
+ group_id = params['group_id']
+ dataset_id = params['dataset_id']
+ data_batch_id = params['data_batch_id']
+ algorithm_id = params['algorithm_id']
+ global_config = params['global_config']
+ comment = params['comment']
+ with db.session_scope() as session:
+ get_project(project_id, session)
+ _validate_create_model_job_params(project_id, name, model_job_type, model_id, group_id)
+ # if platform is old version or the peer's platform is old version
+ if not global_config or not _check_model_job_global_config_enable(project_id):
+ if data_batch_id is not None:
+ raise InternalException('auto update is not supported when our platform or the peer\'s platform is an old version')
+ # model job type is TRAINING
+ if model_job_type in [ModelJobType.TRAINING]:
+ with db.session_scope() as session:
+ model_job = ModelJobController(session, project_id).launch_model_job(group_id=group_id)
+ session.commit()
+ return make_flask_response(model_job.to_proto(), status=HTTPStatus.CREATED)
+ # model job type is EVALUATION or PREDICTION
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ succeeded, msg = CreateModelJob().run(project_id=project_id,
+ name=name,
+ model_job_type=model_job_type,
+ coordinator_pure_domain_name=pure_domain_name,
+ algorithm_type=algorithm_type,
+ dataset_id=dataset_id,
+ model_id=model_id,
+ group_id=group_id)
+ if not succeeded:
+ raise InternalException(f'error when creating model job with message: {msg}')
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name=name).first()
+ model_job.algorithm_id = algorithm_id
+ model_job.comment = comment
+ model_job.creator_username = get_current_user().username
+ workflow_uuid = model_job.workflow_uuid
+ ModelJobService(session).config_model_job(model_job=model_job,
+ config=config,
+ create_workflow=True,
+ workflow_uuid=workflow_uuid)
+ model_job.role = ModelJobRole.COORDINATOR
+ session.commit()
+ workflow = session.query(Workflow).filter_by(uuid=workflow_uuid).first()
+ # TODO(gezhengqiang): refactor config_model_job service and remove wake up after refactoring workflow
+ scheduler.wakeup(workflow.id)
+ return make_flask_response(data=model_job.to_proto(), status=HTTPStatus.CREATED)
+ # new version
+ with db.session_scope() as session:
+ version = None
+ # model job type is TRAINING
+ if group_id:
+ group: ModelJobGroup = ModelJobGroupService(session).lock_and_update_version(group_id)
+ if group.get_group_auth_frontend_status() not in [GroupAuthFrontendStatus.ALL_AUTHORIZED]:
+ raise UnauthorizedException(f'participants not all authorized in the group {group.name}')
+ version = group.latest_version
+ model_job = ModelJobService(session).create_model_job(name=name,
+ uuid=resource_uuid(),
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=model_job_type,
+ algorithm_type=algorithm_type,
+ global_config=global_config,
+ group_id=group_id,
+ project_id=project_id,
+ data_batch_id=data_batch_id,
+ comment=comment,
+ version=version)
+ model_job.creator_username = get_current_user().username
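+ # a data_batch_id marks this as an auto update (incremental training) job: mark the
+ # group as actively auto-updating and propagate the change to the participants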
+ if group_id and data_batch_id is not None:
+ group.auto_update_status = GroupAutoUpdateStatus.ACTIVE
+ group.start_data_batch_id = data_batch_id
+ ModelJobGroupController(session=session, project_id=project_id).update_participants_model_job_group(
+ uuid=group.uuid,
+ auto_update_status=group.auto_update_status,
+ start_data_batch_id=group.start_data_batch_id)
+ session.commit()
+ return make_flask_response(data=model_job.to_proto(), status=HTTPStatus.CREATED)
+
+
+class PeerModelJobApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, model_job_id: int, participant_id: int):
+ """Get the peer model job
+ ---
+ tags:
+ - mmgr
+ description: get the peer model job
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ - in: path
+ name: model_job_id
+ required: true
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ required: true
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: detail of the peer model job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.PeerModelJobPb'
+ """
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ project = model_job.project
+ participant = get_participant(participant_id, project)
+ client = RpcClient.from_project_and_participant(project.name, project.token, participant.domain_name)
+ resp = client.get_model_job(model_job_uuid=model_job.uuid, need_metrics=False)
+ # to support backward compatibility, since peer system may not have metric_is_public
+ # TODO(hangweiqiang): remove code of backward compatibility
+ metric_is_public = True
+ if resp.HasField('metric_is_public'):
+ metric_is_public = resp.metric_is_public.value
+ peer_job = PeerModelJobPb(name=resp.name,
+ uuid=resp.uuid,
+ algorithm_type=resp.algorithm_type,
+ model_job_type=resp.model_job_type,
+ state=resp.state,
+ group_uuid=resp.group_uuid,
+ config=resp.config,
+ metric_is_public=metric_is_public)
+ return make_flask_response(peer_job, status=HTTPStatus.OK)
+
+
+class PeerModelJobMetricsApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, model_job_id: int, participant_id: int):
+ """Get the peer model job metrics
+ ---
+ tags:
+ - mmgr
+ description: get the peer model job metrics
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ - in: path
+ name: model_job_id
+ required: true
+ schema:
+ type: integer
+ - in: path
+ name: participant_id
+ required: true
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: detail of the model job metrics
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobMetrics'
+ 403:
+ description: the metric of peer model job is not public
+ """
+ with db.session_scope() as session:
+ model_job = get_model_job(project_id, model_job_id, session)
+ project = model_job.project
+ participant = get_participant(participant_id, project)
+ client = RpcClient.from_project_and_participant(project.name, project.token, participant.domain_name)
+ resp = client.get_model_job(model_job_uuid=model_job.uuid, need_metrics=True)
+ if resp.HasField('metric_is_public') and not resp.metric_is_public.value:
+ raise NoAccessException('peer metric is not public')
+ metrics = json.loads(resp.metrics)
+ return make_flask_response(metrics, status=HTTPStatus.OK)
+
+
+class LaunchModelJobApi(Resource):
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB, op_type=Event.OperationType.CREATE)
+ def post(self, project_id: int, group_id: int):
+ """Launch the model job
+ ---
+ tags:
+ - mmgr
+ description: launch the model job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 201:
+ description: launch the model job successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobPb'
+ 500:
+ description: error when launching the model job via 2PC
+ """
+ with db.session_scope() as session:
+ model_job = ModelJobController(session, project_id).launch_model_job(group_id=group_id)
+ return make_flask_response(model_job.to_proto(), status=HTTPStatus.CREATED)
+
+
+class NextAutoUpdateModelJobApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, group_id: int):
+ """Get the next auto update model job
+ ---
+ tags:
+ - mmgr
+ description: get the next auto update model job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the model job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobPb'
+ """
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(project_id=project_id, group_id=group_id,
+ auto_update=True).order_by(
+ ModelJob.created_at.desc()).limit(1).first()
+ if model_job is None:
+ group = session.query(ModelJobGroup).get(group_id)
+ raise NotFoundException(f'The auto update model job under the group {group.name} is not found')
+ if model_job.status in [ModelJobStatus.CONFIGURED, ModelJobStatus.RUNNING, ModelJobStatus.PENDING]:
+ raise InternalException(f'The latest auto update model job {model_job.name} is still in progress')
+ data_batch_id = 0
+ load_model_name = ''
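+ # if the latest auto update job was stopped or failed, fall back to the most recent
+ # succeeded auto update job and build the response from its model and data batch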
+ if model_job.status in [ModelJobStatus.STOPPED, ModelJobStatus.FAILED, ModelJobStatus.ERROR]:
+ previous_success_model_job = session.query(ModelJob).filter_by(
+ project_id=project_id, group_id=group_id, auto_update=True,
+ status=ModelJobStatus.SUCCEEDED).order_by(ModelJob.created_at.desc()).limit(1).first()
+ if previous_success_model_job is None:
+ return make_flask_response(model_job.to_proto())
+ model_job = previous_success_model_job
+ next_data_batch = BatchService(session).get_next_batch(model_job.data_batch)
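+ # when the reference job succeeded and a newer data batch exists, the next run trains
+ # on that batch and warm-starts from the model produced by the reference job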
+ if model_job.status in [ModelJobStatus.SUCCEEDED] and next_data_batch is not None:
+ model = session.query(Model).filter_by(model_job_id=model_job.id).first()
+ if model is None:
+ raise NotFoundException(f'The model job {model_job.name}\'s model is not found')
+ load_model_name = model.name
+ data_batch_id = next_data_batch.id
+ if model_job.model_id is None:
+ model_job.model_id = model.id
+ model_job.data_batch_id = data_batch_id
+ global_config = model_job.get_global_config()
+ if global_config is not None:
+ for config in global_config.global_config.values():
+ set_load_model_name(config, load_model_name)
+ model_job.set_global_config(global_config)
+ return make_flask_response(model_job.to_proto())
+
+
+class ModelJobDefinitionApi(Resource):
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'model_job_type': fields.Str(required=True, validate=validate.OneOf([t.name for t in ModelJobType])),
+ 'algorithm_type': fields.Str(required=True, validate=validate.OneOf([t.name for t in AlgorithmType])),
+ },
+ location='query')
+ def get(self, model_job_type: str, algorithm_type: str):
+ """Get variables of model_job
+ ---
+ tags:
+ - mmgr
+ description: get the variables for the given algorithm type and model job type
+ parameters:
+ - in: query
+ name: model_job_type
+ schema:
+ type: string
+ - in: query
+ name: algorithm_type
+ schema:
+ type: string
+ responses:
+ 200:
+ description: variables of given algorithm type and model job type
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ variables:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Variable'
+ is_federated:
+ type: boolean
+ """
+ model_job_type = ModelJobType[model_job_type]
+ algorithm_type = AlgorithmType[algorithm_type]
+ with db.session_scope() as session:
+ template_id = get_sys_template_id(session=session,
+ algorithm_type=algorithm_type,
+ model_job_type=model_job_type)
+ template: WorkflowTemplate = session.query(WorkflowTemplate).get(template_id)
+ config = template.get_config()
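+ # the sys-preset model job templates are assumed to contain a single job definition,
+ # so its variables are the ones surfaced to the frontend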
+ variables = config.job_definitions[0].variables
+ flag = is_federated(algorithm_type=algorithm_type, model_job_type=model_job_type)
+ return make_flask_response(data={'variables': list(variables), 'is_federated': flag})
+
+
+def initialize_mmgr_model_job_apis(api):
+ api.add_resource(ModelJobsApi, '/projects/<int:project_id>/model_jobs')
+ api.add_resource(ModelJobApi, '/projects/<int:project_id>/model_jobs/<int:model_job_id>')
+ api.add_resource(ModelJobMetricsApi, '/projects/<int:project_id>/model_jobs/<int:model_job_id>/metrics')
+ api.add_resource(ModelJobResultsApi, '/projects/<int:project_id>/model_jobs/<int:model_job_id>/results')
+ api.add_resource(StartModelJobApi, '/projects/<int:project_id>/model_jobs/<int:model_job_id>:start')
+ api.add_resource(StopModelJobApi, '/projects/<int:project_id>/model_jobs/<int:model_job_id>:stop')
+ api.add_resource(PeerModelJobApi,
+ '/projects/<int:project_id>/model_jobs/<int:model_job_id>/peers/<int:participant_id>')
+ api.add_resource(PeerModelJobMetricsApi,
+ '/projects/<int:project_id>/model_jobs/<int:model_job_id>/peers/<int:participant_id>/metrics')
+ api.add_resource(LaunchModelJobApi, '/projects/<int:project_id>/model_job_groups/<int:group_id>:launch')
+ api.add_resource(NextAutoUpdateModelJobApi,
+ '/projects/<int:project_id>/model_job_groups/<int:group_id>/next_auto_update_model_job')
+ api.add_resource(ModelJobDefinitionApi, '/model_job_definitions')
+
+ schema_manager.append(CreateModelJobParams)
+ schema_manager.append(ConfigModelJobParams)
+ schema_manager.append(ListModelJobsSchema)
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_apis_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_apis_test.py
new file mode 100644
index 000000000..ab7d7e3cc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_apis_test.py
@@ -0,0 +1,1291 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+import tarfile
+import tempfile
+import unittest
+import urllib.parse
+from io import BytesIO
+from pathlib import Path
+from http import HTTPStatus
+from datetime import datetime
+from unittest.mock import patch, Mock, MagicMock, call
+from envs import Envs
+from testing.common import BaseTestCase
+from testing.fake_model_job_config import get_global_config, get_workflow_config
+from google.protobuf.wrappers_pb2 import BoolValue
+from google.protobuf.empty_pb2 import Empty
+from google.protobuf.struct_pb2 import Value
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.utils.flask_utils import to_dict
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.participant.models import ProjectParticipant
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobGroup, ModelJobType, ModelJobRole, AuthStatus,\
+ ModelJobStatus, GroupAutoUpdateStatus
+from fedlearner_webconsole.algorithm.models import AlgorithmType, Algorithm
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.proto.service_pb2 import GetModelJobResponse
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, ModelJobConfig, ModelJobPb
+from fedlearner_webconsole.workflow.models import WorkflowState, TransactionState
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetType, DatasetJobState, \
+ DatasetJobStage, DataBatch
+
+
+class ModelJobsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Envs.SYSTEM_INFO = '{"domain_name": "fl-test.com"}'
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ dataset_job = DatasetJob(id=1,
+ name='datasetjob',
+ uuid='dataset-job-uuid',
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ input_dataset_id=3,
+ output_dataset_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN)
+ dataset_job_stage = DatasetJobStage(id=1,
+ name='data-join',
+ uuid='dataset-job-stage-uuid',
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ dataset_job_id=1,
+ data_batch_id=1)
+ data_batch = DataBatch(id=1,
+ name='20221213',
+ dataset_id=1,
+ path='/data/dataset/haha/batch/20221213',
+ latest_parent_dataset_job_stage_id=1,
+ event_time=datetime(2022, 12, 13, 16, 37, 37))
+ dataset = Dataset(id=1,
+ uuid='uuid',
+ name='dataset',
+ dataset_type=DatasetType.PSI,
+ path='/data/dataset/haha',
+ is_published=True)
+ project = Project(id=1, name='test-project')
+ participant = Participant(id=1, name='peer', domain_name='fl-peer.com')
+ pro_participant = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ group = ModelJobGroup(id=1, name='test-group', project_id=project.id, uuid='uuid', latest_version=2)
+ session.add_all([project, group, dataset, dataset_job, data_batch, dataset_job_stage])
+ participants_info = ParticipantsInfo()
+ w1 = Workflow(name='w1',
+ uuid='u1',
+ state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.PARTICIPANT_PREPARE)
+ mj1 = ModelJob(name='mj1',
+ workflow_uuid=w1.uuid,
+ project_id=1,
+ group_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ model_job_type=ModelJobType.TRAINING,
+ role=ModelJobRole.COORDINATOR,
+ auth_status=AuthStatus.PENDING,
+ created_at=datetime(2022, 8, 4, 0, 0, 0))
+ mj1.set_participants_info(participants_info)
+ w2 = Workflow(name='w2', uuid='u2', state=WorkflowState.READY, target_state=None)
+ w2.set_config(get_workflow_config(model_job_type=ModelJobType.EVALUATION))
+ mj2 = ModelJob(name='mj2',
+ workflow_uuid=w2.uuid,
+ project_id=1,
+ group_id=2,
+ algorithm_type=AlgorithmType.TREE_VERTICAL,
+ model_job_type=ModelJobType.EVALUATION,
+ role=ModelJobRole.PARTICIPANT,
+ auth_status=AuthStatus.AUTHORIZED,
+ created_at=datetime(2022, 8, 4, 0, 0, 1))
+ mj2.set_participants_info(participants_info)
+ w3 = Workflow(name='w3', uuid='u3', state=WorkflowState.RUNNING, target_state=None)
+ w3.set_config(get_workflow_config(model_job_type=ModelJobType.PREDICTION))
+ mj3 = ModelJob(name='mj3',
+ workflow_uuid=w3.uuid,
+ project_id=1,
+ algorithm_type=AlgorithmType.NN_HORIZONTAL,
+ model_job_type=ModelJobType.PREDICTION,
+ role=ModelJobRole.COORDINATOR,
+ auth_status=AuthStatus.PENDING,
+ created_at=datetime(2022, 8, 4, 0, 0, 2))
+ mj3.set_participants_info(participants_info)
+ w4 = Workflow(name='w4', uuid='u4', state=WorkflowState.RUNNING, target_state=None)
+ w4.set_config(get_workflow_config(model_job_type=ModelJobType.PREDICTION))
+ mj4 = ModelJob(name='mj31',
+ workflow_uuid=w4.uuid,
+ project_id=1,
+ algorithm_type=AlgorithmType.TREE_VERTICAL,
+ model_job_type=ModelJobType.PREDICTION,
+ role=ModelJobRole.PARTICIPANT,
+ auth_status=AuthStatus.AUTHORIZED,
+ created_at=datetime(2022, 8, 4, 0, 0, 3))
+ mj4.set_participants_info(participants_info)
+ w5 = Workflow(name='w5', uuid='u5', state=WorkflowState.COMPLETED, target_state=None)
+ mj5 = ModelJob(id=123,
+ project_id=1,
+ name='mj5',
+ workflow_uuid=w5.uuid,
+ role=ModelJobRole.COORDINATOR,
+ auth_status=AuthStatus.PENDING,
+ created_at=datetime(2022, 8, 4, 0, 0, 4))
+ mj5.set_participants_info(participants_info)
+ mj6 = ModelJob(id=124,
+ project_id=2,
+ name='mj6',
+ workflow_uuid=w5.uuid,
+ role=ModelJobRole.COORDINATOR,
+ auth_status=AuthStatus.PENDING,
+ created_at=datetime(2022, 8, 4, 0, 0, 4))
+ mj6.set_participants_info(participants_info)
+ model = Model(id=12, name='test', model_job_id=123, group_id=1, uuid='model-uuid', project_id=1)
+ session.add_all([w1, w2, w3, mj1, mj2, mj3, w4, mj4, w5, mj5, mj6, model, participant, pro_participant])
+ session.commit()
+
+ def test_get_model_jobs_by_project_or_group(self):
+ resp = self.get_helper('/api/v2/projects/2/model_jobs')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj6'])
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?group_id=2')
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj2'])
+
+ def test_get_model_jobs_by_type(self):
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?types=TRAINING&types=EVALUATION')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj1', 'mj2'])
+
+ def test_get_model_jobs_by_algorithm_types(self):
+ resp = self.get_helper(
+ '/api/v2/projects/1/model_jobs?algorithm_types=NN_VERTICAL&algorithm_types=TREE_VERTICAL')
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj1', 'mj2', 'mj31'])
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?algorithm_types=NN_HORIZONTAL')
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj3'])
+
+ def test_get_model_jobs_by_states(self):
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?states=PENDING_ACCEPT')
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj1'])
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?states=RUNNING&states=READY_TO_RUN')
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj2', 'mj3', 'mj31'])
+
+ def test_get_model_jobs_by_keyword(self):
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?keyword=mj3')
+ data = self.get_response_data(resp)
+ model_job_names = sorted([d['name'] for d in data])
+ self.assertEqual(model_job_names, ['mj3', 'mj31'])
+
+ def test_get_model_jobs_by_configured(self):
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?configured=false')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted([d['name'] for d in data]), ['mj1', 'mj5'])
+ resp = self.get_helper('/api/v2/projects/1/model_jobs?configured=true')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted([d['name'] for d in data]), ['mj2', 'mj3', 'mj31'])
+
+ def test_get_model_jobs_by_expression(self):
+ filter_param = urllib.parse.quote('(algorithm_type:["NN_VERTICAL","TREE_VERTICAL"])')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted([d['name'] for d in data]), ['mj1', 'mj2', 'mj31'])
+ filter_param = urllib.parse.quote('(algorithm_type:["NN_HORIZONTAL"])')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj3'])
+ filter_param = urllib.parse.quote('(role:["COORDINATOR"])')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj1', 'mj3', 'mj5'])
+ filter_param = urllib.parse.quote('(name~="1")')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj1', 'mj31'])
+ filter_param = urllib.parse.quote('(model_job_type:["TRAINING","EVALUATION"])')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj1', 'mj2'])
+ filter_param = urllib.parse.quote('(status:["RUNNING"])')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj3', 'mj31'])
+ filter_param = urllib.parse.quote('(configured=true)')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj2', 'mj3', 'mj31'])
+ filter_param = urllib.parse.quote('(auth_status:["AUTHORIZED"])')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj2', 'mj31'])
+ sorter_param = urllib.parse.quote('created_at asc')
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?order_by={sorter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['mj1', 'mj2', 'mj3', 'mj31', 'mj5'])
+ self.assertEqual(data[0]['status'], ModelJobStatus.PENDING.name)
+ self.assertEqual(data[1]['status'], ModelJobStatus.PENDING.name)
+ self.assertEqual(data[2]['status'], ModelJobStatus.RUNNING.name)
+ self.assertEqual(data[3]['status'], ModelJobStatus.RUNNING.name)
+ self.assertEqual(data[4]['status'], ModelJobStatus.SUCCEEDED.name)
+ resp = self.get_helper(f'/api/v2/projects/1/model_jobs?page=2&page_size=2&order_by={sorter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted(d['name'] for d in data), ['mj3', 'mj31'])
+
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ def test_update_auth_status_of_old_data(self, mock_get_system_info):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='test')
+ with db.session_scope() as session:
+ project = Project(id=3, name='project2')
+ participant = Participant(id=3, name='peer2', domain_name='fl-peer2.com')
+ pro_participant = ProjectParticipant(id=2, project_id=3, participant_id=3)
+ model_job6 = ModelJob(id=6, project_id=3, name='j6', participants_info=None)
+ model_job7 = ModelJob(id=7, project_id=3, name='j7', participants_info=None)
+ session.add_all([project, participant, pro_participant, model_job6, model_job7])
+ session.commit()
+ resp = self.get_helper('/api/v2/projects/3/model_jobs')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'peer2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ with db.session_scope() as session:
+ model_job6 = session.query(ModelJob).get(6)
+ self.assertEqual(model_job6.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(model_job6.get_participants_info(), participants_info)
+ model_job7 = session.query(ModelJob).get(7)
+ self.assertEqual(model_job7.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(model_job7.get_participants_info(), participants_info)
+
+ @patch('fedlearner_webconsole.rpc.v2.system_service_client.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.two_pc.model_job_creator.ModelJobCreator.prepare')
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_post_train_model_job(self, mock_remote_two_pc: Mock, mock_prepare: Mock, mock_list_flags: Mock):
+ mock_prepare.return_value = True, ''
+ config = get_workflow_config(ModelJobType.TRAINING)
+ mock_remote_two_pc.return_value = True, ''
+ mock_list_flags.return_value = {'model_job_global_config_enabled': True}
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ group.role = ModelJobRole.COORDINATOR
+ group.algorithm_type = AlgorithmType.TREE_VERTICAL
+ group.dataset_id = 1
+ group.set_config(config)
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'train-job',
+ 'group_id': 1,
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'config': to_dict(config),
+ 'comment': 'comment'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(group_id=1, version=3).first()
+ self.assertEqual(model_job.project_id, 1)
+ self.assertEqual(model_job.group_id, 1)
+ self.assertEqual(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.TREE_VERTICAL)
+ self.assertEqual(model_job.dataset_id, 1)
+ self.assertEqual(model_job.dataset_name(), 'dataset')
+ self.assertEqual(
+ model_job.workflow.get_config(),
+ WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[
+ make_variable(name='mode', typed_value='train'),
+ make_variable(name='data_source',
+ typed_value='dataset-job-stage-uuid-psi-data-join-job'),
+ make_variable(name='data_path', typed_value=''),
+ make_variable(name='file_wildcard', typed_value='*.data'),
+ ],
+ yaml_template='{}')
+ ]))
+
+ @patch('fedlearner_webconsole.two_pc.model_job_creator.ModelJobCreator.prepare')
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_post_eval_model_job(self, mock_remote_two_pc: Mock, mock_prepare: Mock):
+ mock_prepare.return_value = True, ''
+ config = get_workflow_config(ModelJobType.EVALUATION)
+ mock_remote_two_pc.return_value = True, ''
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'eval-job',
+ 'model_job_type': 'EVALUATION',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'config': to_dict(config),
+ 'eval_model_job_id': 123
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='eval-job').first()
+ self.assertEqual(model_job.project_id, 1)
+ self.assertEqual(model_job.role, ModelJobRole.COORDINATOR)
+ self.assertEqual(model_job.model_job_type, ModelJobType.EVALUATION)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.TREE_VERTICAL)
+ self.assertEqual(model_job.model_id, 12)
+ self.assertEqual(model_job.dataset_id, 1)
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ @patch('fedlearner_webconsole.rpc.v2.system_service_client.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_model_job')
+ def test_post_model_jobs_with_global_config(self, mock_create_model_job, mock_list_flags, mock_remote_do_two_pc):
+ mock_create_model_job.return_value = Empty()
+ global_config = get_global_config()
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'eval-job',
+ 'model_job_type': 'EVALUATION',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'model_id': 12,
+ 'global_config': to_dict(global_config),
+ 'comment': 'comment'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='eval-job').first()
+ self.assertEqual(model_job.model_job_type, ModelJobType.EVALUATION)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.TREE_VERTICAL)
+ self.assertEqual(model_job.dataset_id, 1)
+ self.assertEqual(model_job.model_id, 12)
+ self.assertEqual(model_job.group_id, 1)
+ self.assertEqual(model_job.get_global_config(), get_global_config())
+ self.assertEqual(model_job.comment, 'comment')
+ self.assertEqual(model_job.status, ModelJobStatus.PENDING)
+ self.assertEqual(model_job.creator_username, 'ada')
+ mock_list_flags.return_value = {'model_job_global_config_enabled': True}
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'train-job-1',
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'group_id': 1,
+ 'global_config': to_dict(global_config),
+ 'comment': 'comment'
+ })
+ # fail due to no authorization
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['test'].auth_status = AuthStatus.AUTHORIZED.name
+ participants_info.participants_map['peer'].auth_status = AuthStatus.PENDING.name
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ group.set_participants_info(participants_info)
+ group.authorized = True
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'train-job-1',
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'group_id': 1,
+ 'global_config': to_dict(global_config),
+ 'comment': 'comment'
+ })
+ # fail due to the peer not being authorized
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ participants_info = group.get_participants_info()
+ participants_info.participants_map['peer'].auth_status = AuthStatus.AUTHORIZED.name
+ group.set_participants_info(participants_info)
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'train-job-1',
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'group_id': 1,
+ 'global_config': to_dict(global_config),
+ 'comment': 'comment'
+ })
+ # creation succeeds now that all participants are authorized
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='train-job-1').first()
+ self.assertEqual(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertEqual(model_job.version, 3)
+ mock_list_flags.return_value = {'model_job_global_config_enabled': False}
+ mock_remote_do_two_pc.return_value = True, ''
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ group.role = ModelJobRole.COORDINATOR
+ group.algorithm_type = AlgorithmType.NN_VERTICAL
+ group.set_config()
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'train-job-2',
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'NN_VERTICAL',
+ 'dataset_id': 1,
+ 'group_id': 1,
+ 'global_config': to_dict(global_config),
+ 'comment': 'comment'
+ })
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(group_id=1, version=4).first()
+ self.assertIsNotNone(model_job)
+ self.assertEqual(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_VERTICAL)
+
+ @patch('fedlearner_webconsole.rpc.v2.system_service_client.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.update_model_job_group')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_model_job')
+ def test_post_auto_update_model_job(self, mock_create_model_job: MagicMock, mock_update_model_job_group: MagicMock,
+ mock_list_flags: MagicMock):
+ mock_create_model_job.return_value = Empty()
+ mock_list_flags.return_value = {'model_job_global_config_enabled': True}
+ global_config = get_global_config()
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.kind = DatasetJobKind.OT_PSI_DATA_JOIN
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['test'].auth_status = AuthStatus.AUTHORIZED.name
+ participants_info.participants_map['peer'].auth_status = AuthStatus.AUTHORIZED.name
+ group = session.query(ModelJobGroup).get(1)
+ group.set_participants_info(participants_info)
+ group.authorized = True
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'auto-update-train-job-1',
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'NN_VERTICAL',
+ 'dataset_id': 1,
+ 'data_batch_id': 1,
+ 'group_id': 1,
+ 'global_config': to_dict(global_config),
+ 'comment': 'comment'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ self.assertEqual(mock_update_model_job_group.call_args_list, [
+ call(auto_update_status=GroupAutoUpdateStatus.ACTIVE,
+ start_dataset_job_stage_uuid='dataset-job-stage-uuid',
+ uuid='uuid')
+ ])
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(group_id=1, version=3).first()
+ self.assertIsNotNone(model_job)
+ self.assertEqual(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(model_job.auto_update, True)
+ self.assertEqual(model_job.data_batch_id, 1)
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ self.assertEqual(group.start_data_batch_id, 1)
+ self.assertEqual(group.auto_update_status, GroupAutoUpdateStatus.ACTIVE)
+
+ def test_post_model_jobs_failed(self):
+ # fail due to missing model_id for eval job
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'eval-job',
+ 'model_job_type': 'EVALUATION',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'config': {},
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # fail due to model_id existence for train job
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'train-job',
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 1,
+ 'config': {},
+ 'model_id': 1
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ def test_post_horizontal_eval_model_job(self):
+ with db.session_scope() as session:
+ model = Model(id=1, name='train-model', project_id=1)
+ session.add(model)
+ session.commit()
+ config = get_workflow_config(ModelJobType.EVALUATION)
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'eval-job',
+ 'model_job_type': 'EVALUATION',
+ 'algorithm_type': 'NN_HORIZONTAL',
+ 'algorithm_id': 3,
+ 'model_id': 1,
+ 'config': to_dict(config),
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='eval-job').first()
+ self.assertEqual(model_job.project_id, 1)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_HORIZONTAL)
+ self.assertEqual(model_job.role, ModelJobRole.COORDINATOR)
+ self.assertEqual(model_job.algorithm_id, 3)
+ self.assertEqual(model_job.model_id, 1)
+ self.assertEqual(model_job.model_job_type, ModelJobType.EVALUATION)
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_post_model_job_failed_due_to_dataset(self, mock_remote_two_pc):
+ config = get_workflow_config(ModelJobType.EVALUATION)
+ mock_remote_two_pc.return_value = True, ''
+ # fail due to the dataset not being found
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'eval-job',
+ 'model_job_type': 'EVALUATION',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 3,
+ 'config': to_dict(config),
+ 'eval_model_job_id': 123
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).get(1)
+ dataset.is_published = False
+ session.add(dataset)
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/model_jobs',
+ data={
+ 'name': 'eval-job',
+ 'model_job_type': 'EVALUATION',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'dataset_id': 2,
+ 'config': to_dict(config),
+ 'eval_model_job_id': 123
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+class ModelJobApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Envs.SYSTEM_INFO = '{"domain_name": "fl-test.com"}'
+ with db.session_scope() as session:
+ project = Project(id=1, name='test-project')
+ participant = Participant(id=1, name='part', domain_name='fl-demo1.com')
+ pro_part = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ group = ModelJobGroup(id=1, name='test-group', project_id=project.id, uuid='uuid')
+ workflow_uuid = 'uuid'
+ workflow = Workflow(id=1,
+ name='test-workflow-1',
+ project_id=1,
+ state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.PARTICIPANT_PREPARE,
+ uuid=workflow_uuid)
+ dataset_job = DatasetJob(id=1,
+ name='datasetjob',
+ uuid='uuid',
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=3,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN)
+ dataset = Dataset(id=3,
+ uuid='uuid',
+ name='dataset',
+ dataset_type=DatasetType.PSI,
+ path='/data/dataset/haha')
+ model_job = ModelJob(id=1,
+ name='test-model-job',
+ group_id=1,
+ project_id=1,
+ dataset_id=3,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.TREE_VERTICAL,
+ workflow_id=1,
+ workflow_uuid=workflow_uuid,
+ job_id=2,
+ job_name='uuid-train-job',
+ created_at=datetime(2022, 5, 10, 0, 0, 0))
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ model_job.set_participants_info(participants_info)
+ session.add_all([project, group, workflow, model_job, dataset, dataset_job, participant, pro_part])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_model_job')
+ def test_get_model_job(self, mock_get_model_job):
+ mock_get_model_job.side_effect = [ModelJobPb(auth_status=AuthStatus.AUTHORIZED.name)]
+ with db.session_scope() as session:
+ workflow: Workflow = session.query(Workflow).filter_by(uuid='uuid').first()
+ config = get_workflow_config(model_job_type=ModelJobType.TRAINING)
+ workflow.set_config(config)
+ workflow.state = WorkflowState.READY
+ workflow.target_state = None
+ workflow.start_at = 1
+ workflow.stop_at = 2
+ model_job: ModelJob = session.query(ModelJob).get(1)
+ model = Model(id=1,
+ name='test-model',
+ model_job_id=model_job.id,
+ group_id=model_job.group_id,
+ created_at=datetime(2022, 5, 10, 0, 0, 0),
+ updated_at=datetime(2022, 5, 10, 0, 0, 0))
+ session.add(model)
+ session.commit()
+ resp = self.get_helper('/api/v2/projects/1/model_jobs/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.maxDiff = None
+ self.assertPartiallyEqual(data, {
+ 'id': 1,
+ 'name': 'test-model-job',
+ 'role': 'PARTICIPANT',
+ 'model_job_type': 'TRAINING',
+ 'algorithm_type': 'TREE_VERTICAL',
+ 'auth_status': 'PENDING',
+ 'auto_update': False,
+ 'status': 'PENDING',
+ 'error_message': '',
+ 'group_id': 1,
+ 'project_id': 1,
+ 'state': 'READY_TO_RUN',
+ 'configured': True,
+ 'dataset_id': 3,
+ 'dataset_name': 'dataset',
+ 'output_model_name': 'test-model',
+ 'created_at': 1652140800,
+ 'started_at': 1,
+ 'stopped_at': 2,
+ 'uuid': '',
+ 'algorithm_id': 0,
+ 'model_id': 0,
+ 'model_name': '',
+ 'workflow_id': 1,
+ 'job_id': 2,
+ 'job_name': 'uuid-train-job',
+ 'creator_username': '',
+ 'coordinator_id': 0,
+ 'comment': '',
+ 'version': 0,
+ 'metric_is_public': False,
+ 'auth_frontend_status': 'SELF_AUTH_PENDING',
+ 'participants_info': {
+ 'participants_map': {
+ 'demo1': {
+ 'auth_status': 'AUTHORIZED',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': ''
+ },
+ 'test': {
+ 'auth_status': 'PENDING',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': ''
+ }
+ }
+ }
+ },
+ ignore_fields=['config', 'output_models', 'updated_at', 'data_batch_id'])
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_model_job')
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.scheduler.scheduler.Scheduler.wakeup')
+ def test_put_model_job(self, mock_wake_up, mock_get_system_info, mock_inform_model_job):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='test')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ model_job.uuid = 'uuid'
+ session.commit()
+ config = get_workflow_config(ModelJobType.TRAINING)
+ data = {'algorithm_id': 1, 'config': to_dict(config)}
+ resp = self.put_helper('/api/v2/projects/1/model_jobs/1', data=data)
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['configured'], True)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ self.assertEqual(model_job.role, ModelJobRole.PARTICIPANT)
+ workflow = session.query(Workflow).filter_by(uuid='uuid').first()
+ self.assertEqual(workflow.template.name, 'sys-preset-tree-model')
+ self.assertEqual(
+ workflow.get_config(),
+ WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[
+ make_variable(name='mode', typed_value='train'),
+ make_variable(name='data_source', typed_value=''),
+ make_variable(name='data_path', typed_value='/data/dataset/haha/batch'),
+ make_variable(name='file_wildcard', typed_value='**/part*')
+ ],
+ yaml_template='{}')
+ ]))
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ self.assertEqual(model_job.get_participants_info(), participants_info)
+ self.assertEqual(mock_inform_model_job.call_args_list, [(('uuid', AuthStatus.AUTHORIZED),)])
+ mock_wake_up.assert_called_with(model_job.workflow_id)
+
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_model_job')
+ def test_patch_model_job(self, mock_inform_model_job, mock_get_system_info):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='test')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ model_job.uuid = 'uuid'
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ model_job.set_participants_info(participants_info)
+ session.commit()
+ resp = self.patch_helper('/api/v2/projects/1/model_jobs/1',
+ data={
+ 'metric_is_public': False,
+ 'auth_status': 'HAHA'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ resp = self.patch_helper('/api/v2/projects/1/model_jobs/1',
+ data={
+ 'metric_is_public': False,
+ 'auth_status': 'PENDING',
+ 'comment': 'hahahaha'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ self.assertFalse(model_job.metric_is_public)
+ self.assertEqual(model_job.auth_status, AuthStatus.PENDING)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ self.assertEqual(model_job.get_participants_info(), participants_info)
+ self.assertEqual(mock_inform_model_job.call_args_list, [(('uuid', AuthStatus.PENDING),)])
+ self.assertEqual(model_job.creator_username, 'ada')
+ self.assertEqual(model_job.comment, 'hahahaha')
+ mock_inform_model_job.reset_mock()
+ self.patch_helper('/api/v2/projects/1/model_jobs/1',
+ data={
+ 'metric_is_public': True,
+ 'auth_status': 'AUTHORIZED'
+ })
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ self.assertTrue(model_job.metric_is_public)
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ self.assertEqual(model_job.get_participants_info(), participants_info)
+ self.assertEqual(mock_inform_model_job.call_args_list, [(('uuid', AuthStatus.AUTHORIZED),)])
+
+ @patch('fedlearner_webconsole.mmgr.model_job_configer.ModelJobConfiger.get_config')
+ def test_put_model_job_with_global_config(self, mock_get_config):
+ mock_get_config.return_value = get_workflow_config(ModelJobType.TRAINING)
+ global_config = get_global_config()
+ resp = self.put_helper('/api/v2/projects/1/model_jobs/1',
+ data={
+ 'dataset_id': 3,
+ 'global_config': to_dict(global_config),
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_get_config.assert_called_with(dataset_id=3,
+ model_id=None,
+ model_job_config=global_config.global_config['test'])
+ with db.session_scope() as session:
+ self.maxDiff = None
+ model_job: ModelJob = session.query(ModelJob).get(1)
+ self.assertEqual(model_job.dataset_id, 3)
+ self.assertEqual(
+ model_job.config(),
+ WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[
+ make_variable('mode', typed_value='train'),
+ make_variable('data_source', typed_value=''),
+ make_variable('data_path', typed_value='/data/dataset/haha/batch'),
+ make_variable('file_wildcard', typed_value='**/part*')
+ ],
+ yaml_template='{}')
+ ]))
+
+ def test_delete_model_job(self):
+ resp = self.delete_helper('/api/v2/projects/1/model_jobs/1')
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ model_job.workflow.state = WorkflowState.STOPPED
+ session.commit()
+ resp = self.delete_helper('/api/v2/projects/1/model_jobs/1')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).execution_options(include_deleted=True).get(1)
+ self.assertIsNotNone(model_job.deleted_at)
+
+
+class ModelJobResultsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=123, name='test-project')
+ session.add(project)
+ model_job = ModelJob(id=123, name='test-model', project_id=project.id, job_name='test-job')
+ session.add(model_job)
+ session.commit()
+
+ @patch('fedlearner_webconsole.mmgr.models.ModelJob.get_job_path')
+ def test_get_results(self, mock_get_job_path):
+ with tempfile.TemporaryDirectory() as file:
+ mock_get_job_path.return_value = file
+ Path(os.path.join(file, 'outputs')).mkdir()
+ Path(os.path.join(file, 'outputs', '1.output')).write_text('output_1', encoding='utf-8')
+ Path(os.path.join(file, 'outputs', '2.output')).write_text('output_2', encoding='utf-8')
+ resp = self.get_helper('/api/v2/projects/123/model_jobs/123/results')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(resp.content_type, 'application/x-tar')
+ with tarfile.TarFile(fileobj=BytesIO(resp.data)) as tar:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ tar.extractall(temp_dir)
+ self.assertEqual(['1.output', '2.output'], sorted(os.listdir(os.path.join(temp_dir, 'outputs'))))
+ with open(os.path.join(temp_dir, 'outputs', '1.output'), encoding='utf-8') as f:
+ self.assertEqual(f.read(), 'output_1')
+
+
+class StartModelJobApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.mmgr.model_job_apis.start_model_job')
+ def test_start_model_job(self, mock_start_model_job: MagicMock):
+ with db.session_scope() as session:
+ model_job = ModelJob(id=1, name='train-job', project_id=1)
+ session.add(model_job)
+ session.commit()
+ resp = self.post_helper(f'/api/v2/projects/1/model_jobs/{model_job.id}:start')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_start_model_job.assert_called_with(model_job_id=1)
+
+
+class StopModelJobApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.mmgr.model_job_apis.stop_model_job')
+ def test_stop_model_job(self, mock_stop_model_job: MagicMock):
+ with db.session_scope() as session:
+ model_job = ModelJob(id=1, name='train_job', workflow_id=1, project_id=1)
+ session.add(model_job)
+ session.commit()
+ resp = self.post_helper(f'/api/v2/projects/1/model_jobs/{model_job.id}:stop')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_stop_model_job.assert_called_with(model_job_id=1)
+
+
+class PeerModelJobTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='party', domain_name='fl-test.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ model_job = ModelJob(id=1, project_id=1, name='model-job', uuid='uuid', workflow_uuid='workflow_uuid')
+ workflow = Workflow(name='workflow', uuid='workflow_uuid')
+ workflow.set_config(WorkflowDefinition(group_alias='haha'))
+ with db.session_scope() as session:
+ session.add_all([project, participant, relationship, model_job])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.get_model_job')
+ def test_get_peer_model_job(self, mock_get_model_job):
+ mock_get_model_job.return_value = GetModelJobResponse(name='name',
+ uuid='uuid',
+ group_uuid='uuid',
+ algorithm_type='NN_VERTICAL',
+ model_job_type='TRAINING',
+ state='COMPLETED',
+ metrics='12',
+ metric_is_public=BoolValue(value=False))
+ resp = self.get_helper('/api/v2/projects/1/model_jobs/1/peers/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_get_model_job.assert_called_with(model_job_uuid='uuid', need_metrics=False)
+ self.assertResponseDataEqual(
+ resp, {
+ 'name': 'name',
+ 'uuid': 'uuid',
+ 'algorithm_type': 'NN_VERTICAL',
+ 'model_job_type': 'TRAINING',
+ 'group_uuid': 'uuid',
+ 'state': 'COMPLETED',
+ 'config': {
+ 'group_alias': '',
+ 'variables': [],
+ 'job_definitions': []
+ },
+ 'metric_is_public': False,
+ })
+
+
+class PeerModelJobMetricsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='party', domain_name='fl-test.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ model_job = ModelJob(id=1, project_id=1, name='model-job', uuid='uuid', workflow_uuid='workflow_uuid')
+ with db.session_scope() as session:
+ session.add_all([project, participant, relationship, model_job])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.get_model_job')
+ def test_get_peer_model_job(self, mock_get_model_job):
+ metrics = {'auc': 0.5}
+ mock_get_model_job.return_value = GetModelJobResponse(name='name', uuid='uuid', metrics=json.dumps(metrics))
+ resp = self.get_helper('/api/v2/projects/1/model_jobs/1/peers/1/metrics')
+ mock_get_model_job.assert_called_with(model_job_uuid='uuid', need_metrics=True)
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'auc': 0.5})
+ mock_get_model_job.assert_called_with(model_job_uuid='uuid', need_metrics=True)
+ self.assertEqual(self.get_response_data(resp), metrics)
+ mock_get_model_job.return_value = GetModelJobResponse(name='name',
+ uuid='uuid',
+ metric_is_public=BoolValue(value=False))
+ resp = self.get_helper('/api/v2/projects/1/model_jobs/1/peers/1/metrics')
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+ mock_get_model_job.return_value = GetModelJobResponse(name='name',
+ uuid='uuid',
+ metric_is_public=BoolValue(value=True))
+ resp = self.get_helper('/api/v2/projects/1/model_jobs/1/peers/1/metrics')
+ # internal error since the metrics field is empty and cannot be parsed as JSON
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+class LaunchModelJobApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ project = Project(id=1, name='test-project')
+ dataset_job = DatasetJob(id=1,
+ name='datasetjob',
+ uuid='uuid',
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=3,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN)
+ dataset = Dataset(id=3,
+ uuid='uuid',
+ name='datasetjob',
+ dataset_type=DatasetType.PSI,
+ path='/data/dataset/haha')
+ algorithm = Algorithm(id=2, name='algorithm')
+ group = ModelJobGroup(name='group',
+ uuid='uuid',
+ project_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_id=2,
+ role=ModelJobRole.COORDINATOR,
+ dataset_id=3)
+ group.set_config(get_workflow_config(ModelJobType.TRAINING))
+ session.add_all([dataset_job, dataset, project, group, algorithm])
+ session.commit()
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_launch_model_job(self, mock_remote_do_two_pc):
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).filter_by(uuid='uuid').first()
+ mock_remote_do_two_pc.return_value = True, ''
+ resp = self.post_helper(f'/api/v2/projects/1/model_job_groups/{group.id}:launch')
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(name='group').first()
+ model_job = group.model_jobs[0]
+ self.assertEqual(model_job.group_id, group.id)
+ self.assertTrue(model_job.project_id, group.project_id)
+ self.assertEqual(model_job.version, 1)
+ self.assertEqual(group.latest_version, 1)
+ self.assertTrue(model_job.algorithm_type, group.algorithm_type)
+ self.assertTrue(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertTrue(model_job.dataset_id, group.dataset_id)
+ self.assertTrue(model_job.workflow.get_config(), group.get_config())
+
+
+class NextAutoUpdateModelJobApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ data_batch = DataBatch(id=1,
+ name='20220101-08',
+ dataset_id=1,
+ event_time=datetime(year=2000, month=1, day=1, hour=8),
+ latest_parent_dataset_job_stage_id=1)
+ group1 = ModelJobGroup(id=1, name='group1', project_id=1, auto_update_status=GroupAutoUpdateStatus.ACTIVE)
+ group2 = ModelJobGroup(id=2, name='group2', project_id=1, auto_update_status=GroupAutoUpdateStatus.ACTIVE)
+ group3 = ModelJobGroup(id=3, name='group3', project_id=1, auto_update_status=GroupAutoUpdateStatus.ACTIVE)
+ model_job1 = ModelJob(id=1,
+ group_id=1,
+ auto_update=False,
+ project_id=1,
+ created_at=datetime(2022, 12, 16, 1, 0, 0),
+ status=ModelJobStatus.SUCCEEDED,
+ data_batch_id=1)
+ global_config2 = ModelJobGlobalConfig(
+ dataset_uuid='uuid',
+ global_config={
+ 'test1': ModelJobConfig(algorithm_uuid='uuid1', variables=[Variable(name='load_model_name')]),
+ 'test2': ModelJobConfig(algorithm_uuid='uuid2', variables=[Variable(name='load_model_name')])
+ })
+ model_job2 = ModelJob(id=2,
+ group_id=2,
+ auto_update=True,
+ project_id=1,
+ data_batch_id=5,
+ created_at=datetime(2022, 12, 16, 2, 0, 0),
+ status=ModelJobStatus.RUNNING)
+ model_job2.set_global_config(global_config2)
+ global_config3 = ModelJobGlobalConfig(
+ dataset_uuid='uuid',
+ global_config={
+ 'test1': ModelJobConfig(algorithm_uuid='uuid1', variables=[Variable(name='load_model_name')]),
+ 'test2': ModelJobConfig(algorithm_uuid='uuid2', variables=[Variable(name='load_model_name')])
+ })
+ model_job3 = ModelJob(id=3,
+ name='test-model',
+ group_id=3,
+ auto_update=True,
+ created_at=datetime(2022, 12, 16, 3, 0, 0),
+ status=ModelJobStatus.SUCCEEDED,
+ data_batch_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ project_id=1,
+ comment='comment',
+ version=3)
+ model_job3.set_global_config(global_config3)
+ session.add_all([project, group1, group2, group3, model_job1, model_job2, model_job3, data_batch])
+ session.commit()
+
+ @patch('fedlearner_webconsole.dataset.services.BatchService.get_next_batch')
+ def test_get_next_auto_update_model_job(self, mock_get_next_batch: MagicMock):
+ # fail because the model job group has no auto update model jobs
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/1/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ # fail because the latest auto update model job is running
+ mock_get_next_batch.return_value = DataBatch(id=2)
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/2/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(2)
+ model_job.status = ModelJobStatus.CONFIGURED
+ session.commit()
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/2/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(2)
+ model_job.status = ModelJobStatus.PENDING
+ session.commit()
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/2/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+ # when the latest auto update model job is stopped and there is no previous successful model job
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(2)
+ model_job.status = ModelJobStatus.STOPPED
+ session.commit()
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/2/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['data_batch_id'], 5)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={
+ 'test1':
+ ModelJobConfig(algorithm_uuid='uuid1',
+ variables=[
+ Variable(name='load_model_name',
+ value='',
+ value_type=Variable.ValueType.STRING)
+ ]),
+ 'test2':
+ ModelJobConfig(algorithm_uuid='uuid2',
+ variables=[
+ Variable(name='load_model_name',
+ value='',
+ value_type=Variable.ValueType.STRING)
+ ])
+ })
+ self.assertEqual(data['global_config'], to_dict(global_config))
+ # when the latest auto update model job has failed and there is a previous successful auto update model job
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(2)
+ model_job.status = ModelJobStatus.FAILED
+ global_config = ModelJobGlobalConfig(
+ dataset_uuid='uuid',
+ global_config={
+ 'test1': ModelJobConfig(algorithm_uuid='uuid1', variables=[Variable(name='load_model_name')]),
+ 'test2': ModelJobConfig(algorithm_uuid='uuid2', variables=[Variable(name='load_model_name')])
+ })
+ model_job = ModelJob(id=4,
+ group_id=2,
+ auto_update=True,
+ project_id=1,
+ data_batch_id=3,
+ created_at=datetime(2022, 12, 16, 1, 0, 0),
+ status=ModelJobStatus.SUCCEEDED,
+ model_id=2)
+ model = Model(id=2, model_job_id=4, name='test-previous-model')
+ model_job.set_global_config(global_config)
+ session.add_all([model, model_job])
+ session.commit()
+ mock_get_next_batch.return_value = DataBatch(id=4)
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/2/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['data_batch_id'], 4)
+ global_config = ModelJobGlobalConfig(
+ dataset_uuid='uuid',
+ global_config={
+ 'test1':
+ ModelJobConfig(algorithm_uuid='uuid1',
+ variables=[
+ Variable(name='load_model_name',
+ value='test-previous-model',
+ typed_value=Value(string_value='test-previous-model'),
+ value_type=Variable.ValueType.STRING)
+ ]),
+ 'test2':
+ ModelJobConfig(algorithm_uuid='uuid2',
+ variables=[
+ Variable(name='load_model_name',
+ value='test-previous-model',
+ typed_value=Value(string_value='test-previous-model'),
+ value_type=Variable.ValueType.STRING)
+ ])
+ })
+ self.assertEqual(data['global_config'], to_dict(global_config))
+ # when the latest auto update model job has succeeded and the next batch is None
+ mock_get_next_batch.return_value = None
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/3/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['data_batch_id'], 0)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={
+ 'test1':
+ ModelJobConfig(algorithm_uuid='uuid1',
+ variables=[
+ Variable(name='load_model_name',
+ value='',
+ typed_value=Value(string_value=''),
+ value_type=Variable.ValueType.STRING)
+ ]),
+ 'test2':
+ ModelJobConfig(algorithm_uuid='uuid2',
+ variables=[
+ Variable(name='load_model_name',
+ value='',
+ typed_value=Value(string_value=''),
+ value_type=Variable.ValueType.STRING)
+ ])
+ })
+ self.assertEqual(data['global_config'], to_dict(global_config))
+ # when the latest auto update model job has succeeded and there is a next data batch, but there is no model
+ mock_get_next_batch.return_value = DataBatch(id=3)
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/3/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ # when the latest auto update model job has succeeded and there is a next data batch, and there is a model
+ with db.session_scope() as session:
+ model = Model(id=1, name='test-model', model_job_id=3, uuid='uuid')
+ session.add(model)
+ session.commit()
+ mock_get_next_batch.return_value = DataBatch(id=3)
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/3/next_auto_update_model_job')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ global_config = ModelJobGlobalConfig(
+ dataset_uuid='uuid',
+ global_config={
+ 'test1':
+ ModelJobConfig(algorithm_uuid='uuid1',
+ variables=[
+ Variable(name='load_model_name',
+ value='test-model',
+ typed_value=Value(string_value='test-model'),
+ value_type=Variable.ValueType.STRING)
+ ]),
+ 'test2':
+ ModelJobConfig(algorithm_uuid='uuid2',
+ variables=[
+ Variable(name='load_model_name',
+ value='test-model',
+ typed_value=Value(string_value='test-model'),
+ value_type=Variable.ValueType.STRING)
+ ])
+ })
+ self.assertEqual(data['data_batch_id'], 3)
+ self.assertEqual(data['global_config'], to_dict(global_config))
+ self.assertEqual(data['model_id'], 1)
+
+
+class ModelJobDefinitionApiTest(BaseTestCase):
+
+ def test_get_definitions(self):
+ resp = self.get_helper('/api/v2/model_job_definitions?algorithm_type=NN_VERTICAL&model_job_type=TRAINING')
+ data = self.get_response_data(resp)
+ self.assertEqual(data['is_federated'], True)
+ self.assertEqual(len(data['variables']), 32)
+ resp = self.get_helper('/api/v2/model_job_definitions?algorithm_type=NN_HORIZONTAL&model_job_type=EVALUATION')
+ data = self.get_response_data(resp)
+ self.assertEqual(data['is_federated'], False)
+ self.assertEqual(len(data['variables']), 8)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_configer.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_configer.py
new file mode 100644
index 000000000..29d40c3d7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_configer.py
@@ -0,0 +1,179 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import logging
+from typing import Optional, List
+from sqlalchemy.orm.session import Session
+from google.protobuf.struct_pb2 import Value
+
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DataBatch
+from fedlearner_webconsole.mmgr.models import Model, ModelJobType
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from fedlearner_webconsole.workflow_template.utils import make_variable, set_value
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobConfig
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.utils.const import SYS_PRESET_TREE_TEMPLATE, SYS_PRESET_VERTICAL_NN_TEMPLATE, \
+ SYS_PRESET_HORIZONTAL_NN_TEMPLATE, SYS_PRESET_HORIZONTAL_NN_EVAL_TEMPLATE
+
+LOAD_MODEL_NAME = 'load_model_name'
+
+
+def set_load_model_name(config: ModelJobConfig, model_name: str):
+ """Set variable of load_model_name inplace"""
+ for variable in config.variables:
+ if variable.name == LOAD_MODEL_NAME:
+ assert variable.value_type == Variable.ValueType.STRING
+ variable.value = model_name
+ variable.typed_value.MergeFrom(Value(string_value=model_name))
+
+
+def get_sys_template_id(session: Session, algorithm_type: AlgorithmType, model_job_type: ModelJobType) -> Optional[int]:
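+ # maps the (algorithm_type, model_job_type) pair to the name of a system preset workflow template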
+ template_name = None
+ if algorithm_type == AlgorithmType.NN_VERTICAL:
+ template_name = SYS_PRESET_VERTICAL_NN_TEMPLATE
+ if algorithm_type == AlgorithmType.NN_HORIZONTAL:
+ if model_job_type == ModelJobType.TRAINING:
+ template_name = SYS_PRESET_HORIZONTAL_NN_TEMPLATE
+ else:
+ template_name = SYS_PRESET_HORIZONTAL_NN_EVAL_TEMPLATE
+ if algorithm_type == AlgorithmType.TREE_VERTICAL:
+ template_name = SYS_PRESET_TREE_TEMPLATE
+ if template_name:
+ template_id = session.query(WorkflowTemplate.id).filter_by(name=template_name).first()
+ if template_id is not None:
+ return template_id[0]
+ return None
+
+
+def _set_variable(variables: List[Variable], new_variable: Variable):
+ for variable in variables:
+ if variable.name == new_variable.name:
+ variable.CopyFrom(new_variable)
+ return
+ raise Exception(f'variable {new_variable.name} is not found')
+
+
+class ModelJobConfiger:
+
+ def __init__(self, session: Session, model_job_type: ModelJobType, algorithm_type: AlgorithmType, project_id: int):
+ self._session = session
+ self.model_job_type = model_job_type
+ self.algorithm_type = algorithm_type
+ self.project_id = project_id
+
+ @staticmethod
+ def _init_config(config: WorkflowDefinition, variables: List[Variable]):
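+ # override the template's job variables with the values supplied by the model job config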
+ assert len(config.job_definitions) == 1
+ new_dict = {i.name: i for i in variables}
+ for var in config.job_definitions[0].variables:
+ if var.name in new_dict:
+ var.typed_value.CopyFrom(new_dict[var.name].typed_value)
+ var.value = new_dict[var.name].value
+
+ def _get_config(self) -> WorkflowDefinition:
+ template_id = get_sys_template_id(session=self._session,
+ algorithm_type=self.algorithm_type,
+ model_job_type=self.model_job_type)
+ if template_id is None:
+ raise InternalException('preset template is not found')
+ template: WorkflowTemplate = self._session.query(WorkflowTemplate).get(template_id)
+ return template.get_config()
+
+ def get_dataset_variables(self, dataset_id: Optional[int], data_batch_id: Optional[int] = None) -> List[Variable]:
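+ # build the dataset-related variables (data_source / data_path / file_wildcard) according to the
+ # dataset job kind and the algorithm type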
+ if dataset_id is None:
+ return []
+ dataset: Dataset = self._session.query(Dataset).get(dataset_id)
+ dataset_job: DatasetJob = self._session.query(DatasetJob).filter_by(output_dataset_id=dataset_id).first()
+ if dataset_job is None:
+ raise InvalidArgumentException(f'dataset job for dataset {dataset_id} is not found')
+ data_source = dataset.get_data_source()
+ data_path = os.path.join(dataset.path, 'batch')
+ if data_batch_id is not None:
+ data_batch = self._session.query(DataBatch).get(data_batch_id)
+ data_path = data_batch.path
+ # TODO(hangweiqiang): use data_path for all kinds, and set file_wildcard for nn
+ variables = []
+ if dataset_job.kind == DatasetJobKind.RSA_PSI_DATA_JOIN:
+ # there is no data_source in nn horizontal preset template
+ if self.algorithm_type != AlgorithmType.NN_HORIZONTAL:
+ variables.append(make_variable(name='data_source', typed_value=data_source))
+ variables.append(make_variable(name='data_path', typed_value=''))
+ if self.algorithm_type == AlgorithmType.TREE_VERTICAL:
+ variables.append(make_variable(name='file_wildcard', typed_value='*.data'))
+ if dataset_job.kind in [
+ DatasetJobKind.OT_PSI_DATA_JOIN, DatasetJobKind.HASH_DATA_JOIN, DatasetJobKind.DATA_ALIGNMENT,
+ DatasetJobKind.IMPORT_SOURCE
+ ]:
+ # there is no data_source in nn horizontal preset template
+ if self.algorithm_type != AlgorithmType.NN_HORIZONTAL:
+ variables.append(make_variable(name='data_source', typed_value=''))
+ variables.append(make_variable(name='data_path', typed_value=data_path))
+ if self.algorithm_type == AlgorithmType.TREE_VERTICAL:
+ variables.append(make_variable(name='file_wildcard', typed_value='**/part*'))
+ return variables
+
+ def get_config(self, dataset_id: int, model_id: Optional[int],
+ model_job_config: ModelJobConfig) -> WorkflowDefinition:
+ """get local workflow config from model_job_config"""
+ config = self._get_config()
+ self._init_config(config=config, variables=model_job_config.variables)
+ mode = 'train' if self.model_job_type == ModelJobType.TRAINING else 'eval'
+ variables = config.job_definitions[0].variables
+ # there is no mode variable in nn horizontal preset template
+ if self.algorithm_type != AlgorithmType.NN_HORIZONTAL:
+ _set_variable(variables=variables, new_variable=make_variable(name='mode', typed_value=mode))
+ dataset_variables = self.get_dataset_variables(dataset_id=dataset_id)
+ for var in dataset_variables:
+ _set_variable(variables=variables, new_variable=var)
+ if model_job_config.algorithm_uuid:
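+ # resolve the algorithm, which may be fetched from a participant, and fill it into the 'algorithm' variable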
+ algorithm = AlgorithmFetcher(self.project_id).get_algorithm(model_job_config.algorithm_uuid)
+ parameter = model_job_config.algorithm_parameter
+ algo_dict = {
+ 'algorithmId': algorithm.id,
+ 'algorithmUuid': algorithm.uuid,
+ 'algorithmProjectId': algorithm.algorithm_project_id,
+ 'algorithmProjectUuid': algorithm.algorithm_project_uuid,
+ 'participantId': algorithm.participant_id,
+ 'path': algorithm.path,
+ 'config': to_dict(parameter)['variables']
+ }
+ variables = config.job_definitions[0].variables
+ for variable in variables:
+ if variable.name == 'algorithm':
+ set_value(variable=variable, typed_value=algo_dict)
+ if model_id is not None:
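+ # point load_model_name at the job that produced the given model so the model can be loaded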
+ model: Model = self._session.query(Model).get(model_id)
+ _set_variable(variables=variables,
+ new_variable=make_variable(name='load_model_name', typed_value=model.job_name()))
+ return config
+
+ # TODO(hangweiqiang): remove this function after ModelJobConfig is used
+ def set_dataset(self, config: WorkflowDefinition, dataset_id: Optional[int], data_batch_id: Optional[int] = None):
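+ # legacy helper: patch the dataset-related variables directly into an existing workflow config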
+ variables = config.job_definitions[0].variables
+ dataset_variables = self.get_dataset_variables(dataset_id=dataset_id, data_batch_id=data_batch_id)
+ names = {variable.name for variable in variables}
+ for variable in dataset_variables:
+ # check existence of variable in config for backward compatibility
+ if variable.name in names:
+ _set_variable(variables=variables, new_variable=variable)
+ else:
+ logging.info(f'[set_dataset] variable {variable.name} is not found in config')
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_configer_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_configer_test.py
new file mode 100644
index 000000000..187da1a86
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_configer_test.py
@@ -0,0 +1,324 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import tempfile
+import unittest
+from datetime import datetime
+from envs import Envs
+from unittest.mock import patch
+from google.protobuf.struct_pb2 import Value
+
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.mmgr.models import Model, ModelJobType
+from fedlearner_webconsole.mmgr.model_job_configer import ModelJobConfiger, get_sys_template_id, set_load_model_name
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, AlgorithmType
+from fedlearner_webconsole.algorithm.utils import algorithm_cache_path
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobState, DatasetJobKind, DatasetType, \
+ DatasetJobStage, DataBatch
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobConfig
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmParameter, AlgorithmVariable
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.utils.proto import to_dict, remove_secrets
+
+
+def _get_config() -> WorkflowDefinition:
+ return WorkflowDefinition(job_definitions=[
+ JobDefinition(variables=[
+ Variable(name='data_source'),
+ Variable(name='data_path'),
+ Variable(name='file_wildcard'),
+ ])
+ ])
+
+
+def _set_config(config: WorkflowDefinition, name: str, value: str):
+ for var in config.job_definitions[0].variables:
+ if var.name == name:
+ var.value = value
+ var.typed_value.MergeFrom(Value(string_value=value))
+ var.value_type = Variable.ValueType.STRING
+
+
+class UtilsTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ session.commit()
+
+ def test_get_template(self):
+ with db.session_scope() as session:
+ template_id = get_sys_template_id(session,
+ AlgorithmType.TREE_VERTICAL,
+ model_job_type=ModelJobType.TRAINING)
+ self.assertEqual(session.query(WorkflowTemplate).get(template_id).name, 'sys-preset-tree-model')
+ template_id = get_sys_template_id(session, AlgorithmType.NN_VERTICAL, model_job_type=ModelJobType.TRAINING)
+ self.assertEqual(session.query(WorkflowTemplate).get(template_id).name, 'sys-preset-nn-model')
+ template_id = get_sys_template_id(session,
+ AlgorithmType.NN_HORIZONTAL,
+ model_job_type=ModelJobType.TRAINING)
+ self.assertEqual(session.query(WorkflowTemplate).get(template_id).name, 'sys-preset-nn-horizontal-model')
+ template_id = get_sys_template_id(session,
+ AlgorithmType.NN_HORIZONTAL,
+ model_job_type=ModelJobType.EVALUATION)
+ self.assertEqual(
+ session.query(WorkflowTemplate).get(template_id).name, 'sys-preset-nn-horizontal-eval-model')
+
+ def test_set_load_model_name(self):
+ config = ModelJobConfig(algorithm_uuid='uuid', variables=[Variable(name='load_model_name')])
+ set_load_model_name(config, 'test-model')
+ expected_config = ModelJobConfig(algorithm_uuid='uuid',
+ variables=[
+ Variable(name='load_model_name',
+ value='test-model',
+ typed_value=Value(string_value='test-model'),
+ value_type=Variable.ValueType.STRING)
+ ])
+ self.assertEqual(config, expected_config)
+ config = ModelJobConfig(algorithm_uuid='uuid', variables=[Variable(name='test')])
+ set_load_model_name(config, 'test-model')
+ expected_config = ModelJobConfig(algorithm_uuid='uuid', variables=[Variable(name='test')])
+ self.assertEqual(config, expected_config)
+
+
+class ModelJobConfigerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='part', domain_name='test')
+ project.participants = [participant]
+ dataset_job = DatasetJob(id=1,
+ name='data-join',
+ uuid='dataset-job-uuid',
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ workflow_id=1)
+ dataset_job_stage = DatasetJobStage(id=1,
+ name='data-join',
+ uuid='dataset-job-stage-uuid',
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ dataset_job_id=1,
+ data_batch_id=1)
+ workflow = Workflow(id=1, uuid='workflow-uuid', name='workflow')
+ dataset = Dataset(id=2, uuid='uuid', name='datasetjob', dataset_type=DatasetType.PSI, path='/data/dataset/haha')
+ data_batch = DataBatch(id=1,
+ name='20221213',
+ dataset_id=1,
+ path='/data/dataset/haha/batch/20221213',
+ event_time=datetime(2022, 12, 13, 16, 37, 37))
+ algorithm_project = AlgorithmProject(id=1,
+ name='algo-project',
+ uuid='uuid',
+ type=AlgorithmType.NN_VERTICAL,
+ path='/data/algorithm_project/uuid')
+ algorithm = Algorithm(id=1,
+ name='algo',
+ uuid='uuid',
+ type=AlgorithmType.NN_VERTICAL,
+ path='/data/algorithm/uuid',
+ algorithm_project_id=1)
+ parameter = AlgorithmParameter()
+ parameter.variables.extend([AlgorithmVariable(name='EMBED_SIZE', value='128')])
+ algorithm.set_parameter(parameter=parameter)
+ job = Job(id=1,
+ name='uuid-train-job',
+ workflow_id=1,
+ project_id=1,
+ job_type=JobType.NN_MODEL_TRANINING,
+ state=JobState.COMPLETED)
+ model = Model(id=2, name='model', job_id=1)
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ session.add_all([
+ project, participant, dataset_job, dataset_job_stage, dataset, algorithm, algorithm_project, model, job,
+ workflow, data_batch
+ ])
+ session.commit()
+
+ def test_get_config(self):
+ with db.session_scope() as session:
+ parameter = AlgorithmParameter(variables=[AlgorithmVariable(name='EMBED_SIZE', value='256')])
+ model_job_config = ModelJobConfig(
+ algorithm_uuid='uuid',
+ algorithm_parameter=parameter,
+ variables=[Variable(name='sparse_estimator', typed_value=Value(string_value='true'))])
+ configer = ModelJobConfiger(session, ModelJobType.TRAINING, AlgorithmType.NN_VERTICAL, 1)
+ config = configer.get_config(dataset_id=2, model_id=2, model_job_config=model_job_config)
+ self.assertEqual(config.job_definitions[0].job_type, JobDefinition.JobType.NN_MODEL_TRANINING)
+ self.assertEqual(len(config.job_definitions), 1)
+ var_dict = {var.name: var for var in config.job_definitions[0].variables}
+ self.assertEqual(var_dict['load_model_name'].typed_value, Value(string_value='uuid-train-job'))
+ self.assertEqual(var_dict['data_source'].typed_value,
+ Value(string_value='dataset-job-stage-uuid-psi-data-join-job'))
+ self.assertEqual(var_dict['mode'].typed_value, Value(string_value='train'))
+ self.assertEqual(var_dict['sparse_estimator'].typed_value, Value(string_value='true'))
+ self.assertEqual(
+ to_dict(var_dict['algorithm'].typed_value), {
+ 'algorithmId': 1.0,
+ 'algorithmUuid': 'uuid',
+ 'algorithmProjectUuid': 'uuid',
+ 'config': [{
+ 'comment': '',
+ 'display_name': '',
+ 'name': 'EMBED_SIZE',
+ 'required': False,
+ 'value': '256',
+ 'value_type': 'STRING'
+ }],
+ 'participantId': 0.0,
+ 'path': '/data/algorithm/uuid',
+ 'algorithmProjectId': 1.0
+ })
+ self.assertEqual(json.loads(var_dict['algorithm'].widget_schema), {
+ 'component': 'AlgorithmSelect',
+ 'required': True,
+ 'tag': 'OPERATING_PARAM'
+ })
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm_files')
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_client.ResourceServiceClient.get_algorithm')
+ def test_get_config_when_algorithm_from_participant(self, mock_get_algorithm, mock_get_algorithm_files):
+ with db.session_scope() as session:
+ algo = session.query(Algorithm).get(1).to_proto()
+ algo.uuid = 'uuid-from-participant'
+ mock_get_algorithm.return_value = remove_secrets(algo)
+ parameter = AlgorithmParameter(variables=[AlgorithmVariable(name='EMBED_SIZE', value='256')])
+ model_job_config = ModelJobConfig(
+ algorithm_uuid='uuid-from-participant',
+ algorithm_parameter=parameter,
+ variables=[Variable(name='sparse_estimator', typed_value=Value(string_value='true'))])
+ configer = ModelJobConfiger(session, ModelJobType.TRAINING, AlgorithmType.NN_VERTICAL, 1)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ Envs.STORAGE_ROOT = temp_dir
+ config = configer.get_config(dataset_id=2, model_id=2, model_job_config=model_job_config)
+ var_dict = {var.name: var for var in config.job_definitions[0].variables}
+ self.assertEqual(
+ to_dict(var_dict['algorithm'].typed_value), {
+ 'algorithmId': 0,
+ 'algorithmUuid': 'uuid-from-participant',
+ 'algorithmProjectUuid': 'uuid',
+ 'config': [{
+ 'comment': '',
+ 'display_name': '',
+ 'name': 'EMBED_SIZE',
+ 'required': False,
+ 'value': '256',
+ 'value_type': 'STRING'
+ }],
+ 'participantId': 1.0,
+ 'path': algorithm_cache_path(Envs.STORAGE_ROOT, 'uuid-from-participant'),
+ 'algorithmProjectId': 0
+ })
+
+ def test_get_dataset_variables(self):
+ with db.session_scope() as session:
+ # test configuring an RSA dataset for tree
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.TREE_VERTICAL,
+ project_id=1)
+ variables = configer.get_dataset_variables(dataset_id=2)
+ expected_variables = [
+ make_variable(name='data_source', typed_value='dataset-job-stage-uuid-psi-data-join-job'),
+ make_variable(name='data_path', typed_value=''),
+ make_variable(name='file_wildcard', typed_value='*.data'),
+ ]
+ self.assertEqual(variables, expected_variables)
+ # test configuring an RSA dataset for NN
+ config = _get_config()
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ project_id=1)
+ variables = configer.get_dataset_variables(dataset_id=2)
+ expected_variables = [
+ make_variable(name='data_source', typed_value='dataset-job-stage-uuid-psi-data-join-job'),
+ make_variable(name='data_path', typed_value='')
+ ]
+ self.assertEqual(variables, expected_variables)
+ # test configuring an RSA dataset when the dataset_job_stage is not found
+ dataset_job_stage = session.query(DatasetJobStage).get(1)
+ dataset_job_stage.dataset_job_id = 2
+ session.flush()
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ project_id=1)
+ variables = configer.get_dataset_variables(dataset_id=2)
+ expected_variables = [
+ make_variable(name='data_source', typed_value='workflow-uuid-psi-data-join-job'),
+ make_variable(name='data_path', typed_value='')
+ ]
+ self.assertEqual(variables, expected_variables)
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.kind = DatasetJobKind.OT_PSI_DATA_JOIN
+ session.commit()
+ with db.session_scope() as session:
+ # test configuring an OT dataset for tree
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.TREE_VERTICAL,
+ project_id=1)
+ variables = configer.get_dataset_variables(dataset_id=2)
+ expected_variables = [
+ make_variable(name='data_source', typed_value=''),
+ make_variable(name='data_path', typed_value='/data/dataset/haha/batch'),
+ make_variable(name='file_wildcard', typed_value='**/part*'),
+ ]
+ self.assertEqual(variables, expected_variables)
+ # test configuring an OT dataset for NN
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ project_id=1)
+ variables = configer.get_dataset_variables(dataset_id=2)
+ expected_variables = [
+ make_variable(name='data_source', typed_value=''),
+ make_variable(name='data_path', typed_value='/data/dataset/haha/batch'),
+ ]
+ self.assertEqual(variables, expected_variables)
+ # test when data_batch_id is set
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ project_id=1)
+ variables = configer.get_dataset_variables(dataset_id=2, data_batch_id=1)
+ expected_variables = [
+ make_variable(name='data_source', typed_value=''),
+ make_variable(name='data_path', typed_value='/data/dataset/haha/batch/20221213')
+ ]
+ self.assertEqual(variables, expected_variables)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_group_apis.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_group_apis.py
new file mode 100644
index 000000000..83f2daa19
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_group_apis.py
@@ -0,0 +1,687 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from http import HTTPStatus
+from flask_restful import Resource
+from typing import Optional
+from webargs.flaskparser import use_args, use_kwargs
+from marshmallow import Schema, post_load, fields, validate
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import ResourceConflictException, InternalException, InvalidArgumentException
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterOp, SimpleExpression
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, AlgorithmProjectList
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.proto.review_pb2 import TicketDetails, TicketType
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder
+from fedlearner_webconsole.utils.sorting import SorterBuilder, SortExpression, parse_expression
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.workflow_template.service import dict_to_workflow_definition
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.mmgr.controller import CreateModelJobGroup, ModelJobGroupController
+from fedlearner_webconsole.mmgr.models import ModelJobGroup, ModelJobType, ModelJobRole, GroupCreateStatus, \
+ GroupAutoUpdateStatus
+from fedlearner_webconsole.mmgr.service import ModelJobGroupService, get_model_job_group, get_participant,\
+ get_dataset, get_project, get_algorithm, ModelJobService
+from fedlearner_webconsole.mmgr.model_job_configer import ModelJobConfiger
+from fedlearner_webconsole.algorithm.models import AlgorithmType, AlgorithmProject
+from fedlearner_webconsole.utils.decorators.pp_flask import input_validator
+from fedlearner_webconsole.utils.flask_utils import FilterExpField, make_flask_response, get_current_user
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+from fedlearner_webconsole.swagger.models import schema_manager
+
+
+class CreateModelJobGroupParams(Schema):
+ name = fields.Str(required=True)
+ dataset_id = fields.Integer(required=False, load_default=None)
+ algorithm_type = fields.Str(required=True,
+ validate=validate.OneOf([
+ AlgorithmType.TREE_VERTICAL.name, AlgorithmType.NN_VERTICAL.name,
+ AlgorithmType.NN_HORIZONTAL.name
+ ]))
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['algorithm_type'] = AlgorithmType[data['algorithm_type']]
+ return data
+
+
+class CreateModelJobGroupParamsV2(Schema):
+ name = fields.Str(required=True)
+ dataset_id = fields.Integer(required=True)
+ algorithm_type = fields.Str(required=True,
+ validate=validate.OneOf([
+ AlgorithmType.TREE_VERTICAL.name, AlgorithmType.NN_VERTICAL.name,
+ AlgorithmType.NN_HORIZONTAL.name
+ ]))
+ algorithm_project_list = fields.Dict(required=False, load_default=None)
+ comment = fields.Str(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['algorithm_type'] = AlgorithmType[data['algorithm_type']]
+ if data['algorithm_project_list'] is not None:
+ data['algorithm_project_list'] = ParseDict(data['algorithm_project_list'], AlgorithmProjectList())
+ return data
+
+
+class ConfigModelJobGroupParams(Schema):
+ authorized = fields.Boolean(required=False, load_default=None)
+ algorithm_id = fields.Integer(required=False, load_default=None)
+ config = fields.Dict(required=False, load_default=None)
+ cron_config = fields.String(required=False, load_default=None)
+ comment = fields.Str(required=False, load_default=None)
+ # TODO(gezhengqiang): delete dataset_id
+ dataset_id = fields.Integer(required=False, load_default=None)
+ global_config = fields.Dict(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ if data['config'] is not None:
+ data['config'] = dict_to_workflow_definition(data['config'])
+ if data['global_config'] is not None:
+ data['global_config'] = ParseDict(data['global_config'], ModelJobGlobalConfig())
+ return data
+
+
+class ConfigPeerModelJobGroup(Schema):
+ config = fields.Dict(required=False, load_default=None)
+ global_config = fields.Dict(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ if data['config'] is None and data['global_config'] is None:
+ raise InvalidArgumentException('either config or global config must be set')
+ if data['config'] is not None:
+ data['config'] = dict_to_workflow_definition(data['config'])
+ if data['global_config'] is not None:
+ data['global_config'] = ParseDict(data['global_config'], ModelJobGlobalConfig())
+ return data
+
+
+def _build_group_configured_query(exp: SimpleExpression):
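+ # a group is treated as configured once its workflow config column is non-null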
+ if exp.bool_value:
+ return ModelJobGroup.config.isnot(None)
+ return ModelJobGroup.config.is_(None)
+
+
+class ModelJobGroupsApi(Resource):
+
+ FILTER_FIELDS = {
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'configured': SupportedField(
+ type=FieldType.BOOL,
+ ops={
+ FilterOp.EQUAL: _build_group_configured_query,
+ },
+ ),
+ 'role': SupportedField(type=FieldType.STRING, ops={
+ FilterOp.IN: None,
+ }),
+ 'algorithm_type': SupportedField(type=FieldType.STRING, ops={
+ FilterOp.IN: None,
+ }),
+ }
+
+ SORTER_FIELDS = ['created_at']
+
+ def __init__(self):
+ self._filter_builder = FilterBuilder(model_class=ModelJobGroup, supported_fields=self.FILTER_FIELDS)
+ self._sorter_builder = SorterBuilder(model_class=ModelJobGroup, supported_fields=self.SORTER_FIELDS)
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'page': fields.Integer(required=False, load_default=None),
+ 'page_size': fields.Integer(required=False, load_default=None),
+ 'filter_exp': FilterExpField(data_key='filter', required=False, load_default=None),
+ 'sorter_exp': fields.String(required=False, load_default=None, data_key='order_by'),
+ },
+ location='query')
+ def get(
+ self,
+ page: Optional[int],
+ page_size: Optional[int],
+ filter_exp: Optional[FilterExpression],
+ sorter_exp: Optional[str],
+ project_id: int,
+ ):
+ """Get the list of model job groups
+ ---
+ tags:
+ - mmgr
+ description: get the list of model job groups
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: filter
+ schema:
+ type: string
+ responses:
+ 200:
+ description: the list of model job groups
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobGroupRef'
+ """
+ with db.session_scope() as session:
+ # filter out groups created by the old API, which have no uuid
+ query = session.query(ModelJobGroup).filter(ModelJobGroup.uuid.isnot(None))
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ if filter_exp:
+ try:
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ try:
+ if sorter_exp is not None:
+ sorter_exp = parse_expression(sorter_exp)
+ else:
+ sorter_exp = SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ pagination = paginate(query, page, page_size)
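+ # refresh the status of each group's model job so the returned refs are up to date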
+ for group in pagination.get_items():
+ if len(group.model_jobs) != 0:
+ ModelJobService(session).update_model_job_status(group.model_jobs[0])
+ data = [d.to_ref() for d in pagination.get_items()]
+ session.commit()
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP, op_type=Event.OperationType.CREATE)
+ @use_args(CreateModelJobGroupParams(), location='json')
+ def post(self, params: dict, project_id: int):
+ """Create the model job group
+ ---
+ tags:
+ - mmgr
+ description: create the model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/CreateModelJobGroupParams'
+ responses:
+ 201:
+ description: the detail of the model job group
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobGroupPb'
+ 409:
+ description: the group already exists
+ 500:
+ description: error when creating the model job group by 2PC
+ """
+ name = params['name']
+ dataset_id = params['dataset_id']
+ with db.session_scope() as session:
+ get_project(project_id, session)
+ if dataset_id:
+ get_dataset(dataset_id, session)
+ group = session.query(ModelJobGroup).filter_by(name=name).first()
+ if group is not None:
+ raise ResourceConflictException(f'group {name} already exists')
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ model_job_group_uuid = resource_uuid()
+ group = ModelJobGroup(name=name,
+ uuid=model_job_group_uuid,
+ project_id=project_id,
+ dataset_id=dataset_id,
+ algorithm_type=params['algorithm_type'],
+ role=ModelJobRole.COORDINATOR,
+ creator_username=get_current_user().username,
+ authorized=True,
+ auth_status=AuthStatus.AUTHORIZED)
+ participants = ParticipantService(session).get_participants_by_project(project_id)
+ participants_info = ParticipantsInfo(participants_map={
+ p.pure_domain_name(): ParticipantInfo(auth_status=AuthStatus.PENDING.name) for p in participants
+ })
+ participants_info.participants_map[pure_domain_name].auth_status = AuthStatus.AUTHORIZED.name
+ group.set_participants_info(participants_info)
+ session.add(group)
+ ticket_helper = get_ticket_helper(session)
+ ticket_helper.create_ticket(TicketType.CREATE_MODELJOB_GROUP, TicketDetails(uuid=group.uuid))
+ session.commit()
+ succeeded, msg = CreateModelJobGroup().run(project_id=project_id,
+ name=name,
+ algorithm_type=params['algorithm_type'],
+ dataset_id=dataset_id,
+ coordinator_pure_domain_name=pure_domain_name,
+ model_job_group_uuid=model_job_group_uuid)
+ if not succeeded:
+ raise InternalException(f'failed to create model job group by 2PC with message: {msg}')
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(name=name).first()
+ group.status = GroupCreateStatus.SUCCEEDED
+ session.commit()
+ return make_flask_response(data=group.to_proto(), status=HTTPStatus.CREATED)
+
+
+class ModelJobGroupsApiV2(Resource):
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP, op_type=Event.OperationType.CREATE)
+ @use_args(CreateModelJobGroupParamsV2(), location='json')
+ def post(self, params: dict, project_id: int):
+ """Create the model job group
+ ---
+ tags:
+ - mmgr
+ description: create the model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/CreateModelJobGroupParamsV2'
+ responses:
+ 201:
+ description: the detail of the model job group
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobGroupPb'
+ 409:
+ description: the group already exists
+ """
+ name = params['name']
+ dataset_id = params['dataset_id']
+ algorithm_type = params['algorithm_type']
+ algorithm_project_list = AlgorithmProjectList()
+ if params['algorithm_project_list']:
+ algorithm_project_list = params['algorithm_project_list']
+ with db.session_scope() as session:
+ get_project(project_id, session)
+ get_dataset(dataset_id, session)
+ group = session.query(ModelJobGroup).filter_by(name=name).first()
+ if group is not None:
+ raise ResourceConflictException(f'group {name} already exists')
+ model_job_group_uuid = resource_uuid()
+ group = ModelJobGroup(name=name,
+ uuid=model_job_group_uuid,
+ project_id=project_id,
+ dataset_id=dataset_id,
+ algorithm_type=algorithm_type,
+ role=ModelJobRole.COORDINATOR,
+ creator_username=get_current_user().username,
+ comment=params['comment'])
+ # set a default config so that the group is treated as configured
+ group.set_config()
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ # set algorithm project uuid map
+ group.set_algorithm_project_uuid_list(algorithm_project_list)
+ # set algorithm project id
+ algorithm_project_uuid = algorithm_project_list.algorithm_projects.get(pure_domain_name)
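+ # only tree models may omit the algorithm project; other algorithm types must provide one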
+ if algorithm_project_uuid is None and algorithm_type not in [AlgorithmType.TREE_VERTICAL]:
+ raise Exception(f'algorithm project uuid must be given if algorithm type is {algorithm_type.name}')
+ if algorithm_project_uuid is not None:
+ algorithm_project = session.query(AlgorithmProject).filter_by(uuid=algorithm_project_uuid).first()
+ if algorithm_project is not None:
+ group.algorithm_project_id = algorithm_project.id
+ session.add(group)
+ ModelJobGroupService(session).initialize_auth_status(group)
+ ticket_helper = get_ticket_helper(session)
+ ticket_helper.create_ticket(TicketType.CREATE_MODELJOB_GROUP, TicketDetails(uuid=group.uuid))
+ if group.ticket_status in [TicketStatus.APPROVED]:
+ ModelJobGroupController(
+ session=session,
+ project_id=project_id).create_model_job_group_for_participants(model_job_group_id=group.id)
+ session.commit()
+ return make_flask_response(group.to_proto(), status=HTTPStatus.CREATED)
+
+
+class ModelJobGroupApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, group_id: int):
+ """Get the model job group
+ ---
+ tags:
+ - mmgr
+ description: get the model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the model job group
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobGroupPb'
+ """
+ with db.session_scope() as session:
+ group = get_model_job_group(project_id=project_id, group_id=group_id, session=session)
+ ModelJobGroupController(session, project_id).update_participants_auth_status(group)
+ return make_flask_response(group.to_proto())
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP, op_type=Event.OperationType.UPDATE)
+ @use_args(ConfigModelJobGroupParams(), location='json')
+ def put(self, params: dict, project_id: int, group_id: int):
+ """Update the model job group
+ ---
+ tags:
+ - mmgr
+ description: update the model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/ConfigModelJobGroupParams'
+ responses:
+ 200:
+ description: update the model job group successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobGroupPb'
+ 400:
+ description: algorithm is not found or algorithm type mismatch between group and algorithm
+ """
+ with db.session_scope() as session:
+ group = get_model_job_group(project_id=project_id, group_id=group_id, session=session)
+ if params['authorized'] is not None:
+ group.authorized = params['authorized']
+ if group.authorized:
+ group.auth_status = AuthStatus.AUTHORIZED
+ group.set_config()
+ else:
+ group.auth_status = AuthStatus.PENDING
+ ModelJobGroupController(session, project_id).inform_auth_status_to_participants(group)
+ if params['algorithm_id'] is not None:
+ algorithm = get_algorithm(project_id=project_id, algorithm_id=params['algorithm_id'], session=session)
+ if algorithm is None:
+ raise InvalidArgumentException(f'algorithm {params["algorithm_id"]} is not found')
+ if algorithm.type != group.algorithm_type:
+ raise InvalidArgumentException(f'algorithm type mismatch between group and algorithm: '
+ f'{group.algorithm_type.name} vs {algorithm.type.name}')
+ group.algorithm_id = params['algorithm_id']
+ group.algorithm_project_id = algorithm.algorithm_project_id
+ if params['dataset_id'] is not None:
+ group.dataset_id = params['dataset_id']
+ if params['config'] is not None:
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=group.algorithm_type,
+ project_id=project_id)
+ configer.set_dataset(config=params['config'], dataset_id=group.dataset_id)
+ group.set_config(params['config'])
+ if params['global_config'] is not None:
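+ # derive this platform's workflow config from its own entry in the global config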
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=group.algorithm_type,
+ project_id=project_id)
+ domain_name = SettingService.get_system_info().pure_domain_name
+ config = configer.get_config(dataset_id=group.dataset_id,
+ model_id=None,
+ model_job_config=params['global_config'].global_config[domain_name])
+ group.set_config(config)
+ if params['comment'] is not None:
+ group.comment = params['comment']
+ if group.creator_username is None:
+ group.creator_username = get_current_user().username
+ if params['cron_config'] is not None:
+ ModelJobGroupService(session).update_cronjob_config(group=group, cron_config=params['cron_config'])
+ session.commit()
+ return make_flask_response(data=group.to_proto())
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP, op_type=Event.OperationType.DELETE)
+ def delete(self, project_id: int, group_id: int):
+ """Delete the model job group
+ ---
+ tags:
+ - mmgr
+ description: delete the model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 204:
+ description: delete the model job group successfully
+ 409:
+ description: group cannot be deleted because some model job is ready or running
+ """
+ with db.session_scope() as session:
+ group = get_model_job_group(project_id=project_id, group_id=group_id, session=session)
+ if not group.is_deletable():
+ raise ResourceConflictException('group cannot be deleted because some model job is ready or running')
+ ModelJobGroupService(session).delete(group.id)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class PeerModelJobGroupApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, group_id: int, participant_id: int):
+ """Get the peer model job group
+ ---
+ tags:
+ - mmgr
+ description: Get the peer model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the model job group
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.GetModelJobGroupResponse'
+ """
+ with db.session_scope() as session:
+ group = get_model_job_group(project_id=project_id, group_id=group_id, session=session)
+ resp = ModelJobGroupController(session, project_id).get_model_job_group_from_participant(
+ participant_id=participant_id, model_job_group_uuid=group.uuid)
+ return make_flask_response(resp, status=HTTPStatus.OK)
+
+ @credentials_required
+ @use_args(ConfigPeerModelJobGroup(), location='json')
+ def patch(self, params: dict, project_id: int, group_id: int, participant_id: int):
+ """Patch a peer model job group
+ ---
+ tags:
+ - mmgr
+ description: patch a peer model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/ConfigPeerModelJobGroup'
+ responses:
+ 200:
+ description: update the peer model job group successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.UpdateModelJobGroupResponse'
+ """
+ config = params['config']
+ global_config = params['global_config']
+ with db.session_scope() as session:
+ group = get_model_job_group(project_id=project_id, group_id=group_id, session=session)
+ project = group.project
+ participant = get_participant(participant_id, project)
+ client = RpcClient.from_project_and_participant(project.name, project.token, participant.domain_name)
+ if global_config is not None:
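+ # translate the participant's entry in the global config into a full workflow config before sending it over RPC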
+ configer = ModelJobConfiger(session=session,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=group.algorithm_type,
+ project_id=project_id)
+ domain_name = participant.pure_domain_name()
+ config = configer.get_config(dataset_id=group.dataset_id,
+ model_id=None,
+ model_job_config=global_config.global_config[domain_name])
+ resp = client.update_model_job_group(model_job_group_uuid=group.uuid, config=config)
+ return make_flask_response(resp, status=HTTPStatus.OK)
+
+
+class ModelJobGroupStopAutoUpdateApi(Resource):
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP, op_type=Event.OperationType.STOP)
+ def post(self, project_id: int, group_id: int):
+ """Stop trigger auto update model job in this model job group
+ ---
+ tags:
+ - mmgr
+ description: stop triggering the auto update model job in this model job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: stop the auto update model job successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ModelJobGroupPb'
+ """
+ with db.session_scope() as session:
+ group: ModelJobGroup = get_model_job_group(project_id=project_id, group_id=group_id, session=session)
+ group.auto_update_status = GroupAutoUpdateStatus.STOPPED
+ ModelJobGroupController(session=session, project_id=project_id).update_participants_model_job_group(
+ uuid=group.uuid, auto_update_status=group.auto_update_status)
+ session.commit()
+ return make_flask_response(data=group.to_proto(), status=HTTPStatus.OK)
+
+
+def initialize_mmgr_model_job_group_apis(api):
+ api.add_resource(ModelJobGroupsApi, '/projects/<int:project_id>/model_job_groups')
+ api.add_resource(ModelJobGroupsApiV2, '/projects/<int:project_id>/model_job_groups_v2')
+ api.add_resource(ModelJobGroupApi, '/projects/<int:project_id>/model_job_groups/<int:group_id>')
+ api.add_resource(ModelJobGroupStopAutoUpdateApi,
+ '/projects/<int:project_id>/model_job_groups/<int:group_id>:stop_auto_update')
+ api.add_resource(PeerModelJobGroupApi,
+ '/projects/<int:project_id>/model_job_groups/<int:group_id>/peers/<int:participant_id>')
+
+ schema_manager.append(CreateModelJobGroupParams)
+ schema_manager.append(CreateModelJobGroupParamsV2)
+ schema_manager.append(ConfigModelJobGroupParams)
+ schema_manager.append(ConfigPeerModelJobGroup)
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_group_apis_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_group_apis_test.py
new file mode 100644
index 000000000..1dcd4757b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/model_job_group_apis_test.py
@@ -0,0 +1,551 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import urllib.parse
+from http import HTTPStatus
+from datetime import datetime
+from unittest.mock import patch, Mock, ANY, MagicMock, call
+from google.protobuf.empty_pb2 import Empty
+from envs import Envs
+from testing.common import BaseTestCase
+from testing.fake_model_job_config import get_global_config, get_workflow_config
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.flask_utils import to_dict
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.participant.models import ProjectParticipant
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobGroup, ModelJobType, ModelJobRole, \
+ GroupAuthFrontendStatus, GroupAutoUpdateStatus, ModelJobStatus
+from fedlearner_webconsole.algorithm.models import AlgorithmType, Algorithm, AlgorithmProject
+from fedlearner_webconsole.proto.service_pb2 import UpdateModelJobGroupResponse
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo, ParticipantsInfo
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGroupPb, AlgorithmProjectList
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+
+
+class ModelJobGroupsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='party', domain_name='fl-peer.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ algo_project = AlgorithmProject(id=1, name='algo')
+ algo = Algorithm(id=2, name='algo', algorithm_project_id=1)
+ session.add_all([project, algo, algo_project, participant, relationship])
+ g1 = ModelJobGroup(id=1,
+ name='g1',
+ uuid='u1',
+ role=ModelJobRole.COORDINATOR,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_project_id=1,
+ algorithm_id=2,
+ project_id=1,
+ created_at=datetime(2021, 1, 1, 0, 0, 0))
+ g1.set_config(get_workflow_config(ModelJobType.TRAINING))
+ g2 = ModelJobGroup(name='g2',
+ uuid='u2',
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ algorithm_type=AlgorithmType.NN_HORIZONTAL,
+ created_at=datetime(2021, 1, 1, 0, 0, 1))
+ g3 = ModelJobGroup(name='g3',
+ uuid='u3',
+ project_id=2,
+ role=ModelJobRole.PARTICIPANT,
+ algorithm_type=AlgorithmType.TREE_VERTICAL,
+ created_at=datetime(2021, 1, 1, 0, 0, 1))
+ workflow = Workflow(id=1, name='workflow', state=WorkflowState.RUNNING)
+ model_job = ModelJob(id=1, group_id=1, status=ModelJobStatus.PENDING, workflow_id=1)
+ dataset = Dataset(name='dataset', uuid='dataset_uuid', is_published=True)
+ session.add_all([g1, g2, g3, dataset, workflow, model_job])
+ session.commit()
+
+ def test_get_groups(self):
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['name'], 'g2')
+ self.assertEqual(data[0]['configured'], False)
+ self.assertEqual(data[1]['name'], 'g1')
+ self.assertEqual(data[1]['configured'], True)
+ self.assertEqual(data[1]['latest_job_state'], 'RUNNING')
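+        # project id 0 lists model job groups across all projects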
+ resp = self.get_helper('/api/v2/projects/0/model_job_groups')
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data), 3)
+ resp = self.get_helper('/api/v2/projects/0/model_job_groups?filter=(configured%3Dfalse)')
+ data = self.get_response_data(resp)
+ self.assertEqual(sorted([d['name'] for d in data]), ['g2', 'g3'])
+
+ def test_get_groups_by_filtering_expression(self):
+ filter_param = urllib.parse.quote('(algorithm_type:["NN_VERTICAL","NN_HORIZONTAL"])')
+ resp = self.get_helper(f'/api/v2/projects/0/model_job_groups?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g2', 'g1'])
+ filter_param = urllib.parse.quote('(algorithm_type:["NN_VERTICAL"])')
+ resp = self.get_helper(f'/api/v2/projects/0/model_job_groups?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g1'])
+ filter_param = urllib.parse.quote('(role:["COORDINATOR"])')
+ resp = self.get_helper(f'/api/v2/projects/0/model_job_groups?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g2', 'g1'])
+ filter_param = urllib.parse.quote('(name~="1")')
+ resp = self.get_helper(f'/api/v2/projects/0/model_job_groups?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g1'])
+ filter_param = urllib.parse.quote('created_at asc')
+ resp = self.get_helper(f'/api/v2/projects/0/model_job_groups?order_by={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g1', 'g2', 'g3'])
+
+ # TODO(linfan): refactor transaction manager
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.two_pc.model_job_group_creator.ModelJobGroupCreator.prepare')
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+    def test_post_model_job_group(self, mock_remote_two_pc, mock_prepare, mock_system_info):
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test', name='name')
+ mock_prepare.return_value = True, ''
+ with db.session_scope() as session:
+ dataset_id = session.query(Dataset).filter_by(uuid='dataset_uuid').first().id
+        mock_remote_two_pc.return_value = True, ''
+ resp = self.post_helper('/api/v2/projects/1/model_job_groups',
+ data={
+ 'name': 'group',
+ 'algorithm_type': AlgorithmType.NN_VERTICAL.name,
+ 'dataset_id': dataset_id
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(name='group').first()
+ self.assertEqual(group.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(group.role, ModelJobRole.COORDINATOR)
+ self.assertIsNone(group.coordinator_id)
+ self.assertEqual(group.creator_username, 'ada')
+ self.assertEqual(
+ group.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'peer': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ }))
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.PART_AUTH_PENDING)
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+    def test_post_model_job_group_failed(self, mock_remote_two_pc):
+        mock_remote_two_pc.return_value = True, ''
+ resp = self.post_helper('/api/v2/projects/1/model_job_groups',
+ data={
+ 'name': 'group',
+ 'algorithm_type': AlgorithmType.NN_VERTICAL.name,
+ 'dataset_id': -1
+ })
+        # fails because the dataset is not found
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).filter_by(uuid='dataset_uuid').first()
+ dataset.is_published = False
+ session.add(dataset)
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/model_job_groups',
+ data={
+ 'name': 'group',
+ 'algorithm_type': AlgorithmType.NN_VERTICAL.name,
+ 'dataset_id': dataset.id
+ })
+        # fails because the dataset is not published
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+class ModelJobGroupsApiV2Test(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Envs.SYSTEM_INFO = '{"domain_name": "fl-test.com"}'
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='party', domain_name='fl-peer.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ algo_project = AlgorithmProject(id=1, name='algo')
+ algo = Algorithm(id=2, name='algo', algorithm_project_id=1)
+ dataset = Dataset(id=1, name='dataset', uuid='dataset_uuid', is_published=True)
+ session.add_all([project, algo, algo_project, participant, relationship, dataset])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_model_job_group')
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ def test_post_model_job(self, mock_system_info, mock_client):
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test', name='name')
+ algorithm_project_list = AlgorithmProjectList()
+ algorithm_project_list.algorithm_projects['test'] = 'uuid-test'
+ algorithm_project_list.algorithm_projects['peer'] = 'uuid-peer'
+ resp = self.post_helper('/api/v2/projects/1/model_job_groups_v2',
+ data={
+ 'name': 'group',
+ 'dataset_id': 1,
+ 'algorithm_type': AlgorithmType.NN_VERTICAL.name,
+ 'algorithm_project_list': to_dict(algorithm_project_list),
+ 'comment': 'comment'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ mock_client.assert_called()
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(name='group').first()
+ self.assertEqual(group.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(group.role, ModelJobRole.COORDINATOR)
+ self.assertEqual(group.dataset.uuid, 'dataset_uuid')
+ self.assertIsNone(group.coordinator_id)
+ self.assertEqual(group.creator_username, 'ada')
+ self.assertEqual(group.comment, 'comment')
+ self.assertEqual(
+ group.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'peer': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ }))
+ self.assertEqual(group.get_algorithm_project_uuid_list(),
+ AlgorithmProjectList(algorithm_projects={
+ 'peer': 'uuid-peer',
+ 'test': 'uuid-test'
+ }))
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.PART_AUTH_PENDING)
+ self.assertEqual(group.to_proto().configured, True)
+ resp = self.post_helper('/api/v2/projects/1/model_job_groups_v2',
+ data={
+ 'name': 'new_group',
+ 'dataset_id': 1,
+ 'algorithm_type': AlgorithmType.TREE_VERTICAL.name,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).filter_by(name='new_group').first()
+ self.assertEqual(group.dataset.uuid, 'dataset_uuid')
+ algorithm_project_list = AlgorithmProjectList()
+ self.assertEqual(group.get_algorithm_project_uuid_list(), algorithm_project_list)
+ resp = self.post_helper('/api/v2/projects/1/model_job_groups_v2',
+ data={
+ 'name': 'new_group',
+ 'dataset_id': 1,
+ 'algorithm_type': AlgorithmType.NN_VERTICAL.name,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+
+
+class ModelJobGroupApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Envs.SYSTEM_INFO = '{"domain_name": "fl-test.com"}'
+ with db.session_scope() as session:
+ algo_project = AlgorithmProject(id=123, name='algo_project', project_id=1)
+ dataset = Dataset(id=2, name='dataset')
+ algorithm = Algorithm(id=1,
+ name='algo',
+ algorithm_project_id=123,
+ type=AlgorithmType.NN_VERTICAL,
+ project_id=1)
+ group = ModelJobGroup(id=1,
+ name='group',
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ uuid='uuid',
+ creator_username='ada',
+ project_id=1,
+ created_at=datetime(2022, 5, 6, 0, 0, 0),
+ updated_at=datetime(2022, 5, 6, 0, 0, 0))
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part1', domain_name='fl-demo1.com')
+ participant2 = Participant(id=2, name='part2', domain_name='fl-demo2.com')
+ pro_part1 = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ pro_part2 = ProjectParticipant(id=2, project_id=1, participant_id=2)
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['test'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['demo1'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['demo2'].auth_status = AuthStatus.PENDING.name
+ group.set_participants_info(participants_info)
+ session.add_all(
+ [algo_project, algorithm, group, dataset, project, participant1, participant2, pro_part1, pro_part2])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_model_job_group')
+ def test_get_group(self, mock_client: MagicMock):
+ mock_client.side_effect = [
+ ModelJobGroupPb(auth_status=AuthStatus.AUTHORIZED.name),
+ ModelJobGroupPb(auth_status=AuthStatus.AUTHORIZED.name)
+ ]
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ group.algorithm_project_id = 1
+ group.algorithm_id = 2
+ group.dataset_id = 2
+ group.set_config(get_workflow_config(ModelJobType.TRAINING))
+ algorithm_project_list = AlgorithmProjectList()
+ algorithm_project_list.algorithm_projects['test'] = 'uuid-test'
+ algorithm_project_list.algorithm_projects['demo1'] = 'uuid-demo1'
+ algorithm_project_list.algorithm_projects['demo2'] = 'uuid-demo2'
+ group.set_algorithm_project_uuid_list(algorithm_project_list)
+ group.comment = 'comment'
+ group.latest_version = 1
+ model_job = ModelJob(name='job-1', group_id=1)
+ session.add(model_job)
+ session.commit()
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.maxDiff = None
+ self.assertResponseDataEqual(resp, {
+ 'id': 1,
+ 'uuid': 'uuid',
+ 'name': 'group',
+ 'project_id': 1,
+ 'role': 'PARTICIPANT',
+ 'creator_username': 'ada',
+ 'coordinator_id': 0,
+ 'authorized': False,
+ 'auto_update_status': 'INITIAL',
+ 'dataset_id': 2,
+ 'algorithm_type': 'NN_VERTICAL',
+ 'algorithm_project_id': 1,
+ 'algorithm_id': 2,
+ 'comment': 'comment',
+ 'cron_config': '',
+ 'configured': True,
+ 'latest_version': 1,
+ 'config': to_dict(get_workflow_config(ModelJobType.TRAINING)),
+ 'latest_job_state': 'PENDING',
+ 'auth_frontend_status': 'SELF_AUTH_PENDING',
+ 'auth_status': 'PENDING',
+ 'created_at': 1651795200,
+ 'algorithm_project_uuid_list': {
+ 'algorithm_projects': {
+ 'test': 'uuid-test',
+ 'demo1': 'uuid-demo1',
+ 'demo2': 'uuid-demo2'
+ }
+ },
+ 'participants_info': {
+ 'participants_map': {
+ 'test': {
+ 'auth_status': 'PENDING',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': ''
+ },
+ 'demo1': {
+ 'auth_status': 'AUTHORIZED',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': ''
+ },
+ 'demo2': {
+ 'auth_status': 'AUTHORIZED',
+ 'name': '',
+ 'role': '',
+ 'state': '',
+ 'type': ''
+ }
+ }
+ },
+ },
+ ignore_fields=['model_jobs', 'updated_at', 'start_data_batch_id'])
+ data = self.get_response_data(resp)
+ self.assertEqual(len(data['model_jobs']), 1)
+ self.assertPartiallyEqual(data['model_jobs'][0], {
+ 'id': 1,
+ 'name': 'job-1',
+ 'role': 'PARTICIPANT',
+ 'model_job_type': 'UNSPECIFIED',
+ 'algorithm_type': 'UNSPECIFIED',
+ 'state': 'PENDING_ACCEPT',
+ 'group_id': 1,
+ 'status': 'PENDING',
+ 'uuid': '',
+ 'configured': False,
+ 'creator_username': '',
+ 'coordinator_id': 0,
+ 'version': 0,
+ 'project_id': 0,
+ 'started_at': 0,
+ 'stopped_at': 0,
+ 'metric_is_public': False,
+ 'algorithm_id': 0,
+ 'auth_status': 'PENDING',
+ 'auto_update': False,
+ 'auth_frontend_status': 'SELF_AUTH_PENDING',
+ 'participants_info': {
+ 'participants_map': {}
+ }
+ },
+ ignore_fields=['created_at', 'updated_at'])
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_model_job_group')
+ def test_put_model_job_group(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ config = get_workflow_config(ModelJobType.TRAINING)
+ resp = self.put_helper('/api/v2/projects/1/model_job_groups/1',
+ data={
+ 'authorized': True,
+ 'algorithm_id': 1,
+ 'config': to_dict(config),
+ 'comment': 'comment'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ self.assertTrue(group.authorized)
+ self.assertEqual(group.algorithm_id, 1)
+ self.assertEqual(group.algorithm_project_id, 123)
+ self.assertEqual(group.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(group.get_config(), config)
+ self.assertEqual(group.comment, 'comment')
+ self.assertEqual(group.to_proto().configured, True)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map['test'].auth_status, AuthStatus.AUTHORIZED.name)
+ self.assertEqual(mock_client.call_args_list, [(('uuid', AuthStatus.AUTHORIZED),),
+ (('uuid', AuthStatus.AUTHORIZED),)])
+
+ @patch('fedlearner_webconsole.mmgr.model_job_configer.ModelJobConfiger.get_config')
+ def test_put_model_job_group_with_global_config(self, mock_get_config):
+ mock_get_config.return_value = get_workflow_config(ModelJobType.EVALUATION)
+ global_config = get_global_config()
+ resp = self.put_helper('/api/v2/projects/1/model_job_groups/1',
+ data={
+ 'dataset_id': 1,
+ 'global_config': to_dict(global_config)
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_get_config.assert_called_with(dataset_id=1,
+ model_id=None,
+ model_job_config=global_config.global_config['test'])
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ self.assertEqual(group.dataset_id, 1)
+ self.assertEqual(group.get_config(), get_workflow_config(ModelJobType.EVALUATION))
+
+ @patch('fedlearner_webconsole.mmgr.service.ModelJobGroupService.update_cronjob_config')
+ def test_put_model_job_group_with_cron_config(self, mock_cronjob_config: Mock):
+ resp = self.put_helper('/api/v2/projects/1/model_job_groups/1', data={})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_cronjob_config.assert_not_called()
+ self.put_helper('/api/v2/projects/1/model_job_groups/1', data={
+ 'cron_config': '*/10 * * * *',
+ })
+ mock_cronjob_config.assert_called_once_with(group=ANY, cron_config='*/10 * * * *')
+ self.put_helper('/api/v2/projects/1/model_job_groups/1', data={
+ 'cron_config': '',
+ })
+ mock_cronjob_config.assert_called_with(group=ANY, cron_config='')
+
+ def test_delete_model_job_group(self):
+ resp = self.delete_helper('/api/v2/projects/1/model_job_groups/1')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).execution_options(include_deleted=True).get(1)
+ self.assertIsNotNone(group.deleted_at)
+ for job in group.model_jobs:
+ self.assertIsNotNone(job.deleted_at)
+
+
+class PeerModelJobGroupApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='party', domain_name='fl-peer.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ group = ModelJobGroup(id=1, name='group', uuid='uuid', project_id=1, dataset_id=1)
+ session.add_all([project, participant, relationship, group])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.system_service_client.SystemServiceClient.list_flags')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_model_job_group')
+ def test_get_peer_model_job_group(self, mock_get_group, mock_list_flags):
+ config = WorkflowDefinition(job_definitions=[JobDefinition(variables=[Variable(name='test')])])
+ mock_get_group.return_value = ModelJobGroupPb(name='group', uuid='uuid', config=config)
+ mock_list_flags.return_value = {'model_job_global_config_enabled': True}
+ resp = self.get_helper('/api/v2/projects/1/model_job_groups/1/peers/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_get_group.assert_called()
+ data = self.get_response_data(resp)
+ self.assertEqual(data['name'], 'group')
+ self.assertEqual(data['uuid'], 'uuid')
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.update_model_job_group')
+ def test_patch_peer_model_job_group(self, mock_update_group):
+ config = get_workflow_config(ModelJobType.TRAINING)
+ mock_update_group.return_value = UpdateModelJobGroupResponse(uuid='uuid', config=config)
+ resp = self.patch_helper('/api/v2/projects/1/model_job_groups/1/peers/1', data={'config': to_dict(config)})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_update_group.assert_called_with(model_job_group_uuid='uuid', config=config)
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.update_model_job_group')
+ @patch('fedlearner_webconsole.mmgr.model_job_configer.ModelJobConfiger.get_config')
+ def test_patch_peer_model_job_group_with_global_config(self, mock_get_config, mock_update_group):
+ config = get_workflow_config(ModelJobType.TRAINING)
+ mock_get_config.return_value = config
+ mock_update_group.return_value = UpdateModelJobGroupResponse(uuid='uuid', config=config)
+ global_config = get_global_config()
+ resp = self.patch_helper('/api/v2/projects/1/model_job_groups/1/peers/1',
+ data={'global_config': to_dict(global_config)})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_get_config.assert_called_with(dataset_id=1,
+ model_id=None,
+ model_job_config=global_config.global_config['peer'])
+ mock_update_group.assert_called_with(model_job_group_uuid='uuid', config=config)
+
+
+class StopAutoUpdateApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ participant = Participant(id=1, name='party', domain_name='fl-peer.com', host='127.0.0.1', port=32443)
+ group = ModelJobGroup(id=1,
+ name='group',
+ uuid='uuid',
+ project_id=1,
+ dataset_id=1,
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE,
+ start_data_batch_id=1)
+ session.add_all([project, participant, relationship, group])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.update_model_job_group')
+ def test_post_stop(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ resp = self.post_helper('/api/v2/projects/1/model_job_groups/1:stop_auto_update')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['auto_update_status'], GroupAutoUpdateStatus.STOPPED.name)
+ self.assertEqual(
+ mock_client.call_args_list,
+ [call(uuid='uuid', auto_update_status=GroupAutoUpdateStatus.STOPPED, start_dataset_job_stage_uuid=None)])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/models.py b/web_console_v2/api/fedlearner_webconsole/mmgr/models.py
index 6db0ea885..4eda2e302 100644
--- a/web_console_v2/api/fedlearner_webconsole/mmgr/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/models.py
@@ -1,108 +1,648 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
-# Licensed under the Apache License, Version 2.0 (the 'License');
+# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
+# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-# coding: utf-8
+#
import enum
+import logging
+from typing import Optional
+from google.protobuf import text_format
from sqlalchemy.sql import func
-from sqlalchemy.orm import remote, foreign
from sqlalchemy.sql.schema import Index, UniqueConstraint
-from fedlearner_webconsole.utils.mixins import to_dict_mixin
from fedlearner_webconsole.db import db, default_table_args
-from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmType
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.mmgr.utils import get_job_path, get_exported_model_path, get_checkpoint_path, \
+ get_output_path
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.base_model import auth_model
+from fedlearner_webconsole.utils.base_model.softdelete_model import SoftDeleteModel
+from fedlearner_webconsole.utils.base_model.review_ticket_and_auth_model import ReviewTicketAndAuthModel
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowExternalState
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobPb, ModelJobGroupPb, ModelJobRef, ModelJobGroupRef, ModelPb, \
+ ModelJobGlobalConfig, AlgorithmProjectList
class ModelType(enum.Enum):
- NN_MODEL = 0
- NN_EVALUATION = 1
+ UNSPECIFIED = 0
+ NN_MODEL = 1
TREE_MODEL = 2
- TREE_EVALUATION = 3
-class ModelState(enum.Enum):
- NEW = -1 # before workflow has synced both party
- COMMITTING = 0 # (transient) after workflow has synced both party, before committing to k8s
- COMMITTED = 1 # after committed to k8s but before running
- WAITING = 2 # k8s is queueing the related job(s)
- RUNNING = 3 # k8s is running the related job(s)
- PAUSED = 4 # related workflow has been paused by end-user
- SUCCEEDED = 5
- FAILED = 6
- # DROPPING = 7 # (transient) removing model and its related resources
- DROPPED = 8 # model has been removed
+class ModelJobType(enum.Enum):
+ UNSPECIFIED = 0
+ NN_TRAINING = 1
+ NN_EVALUATION = 2
+ NN_PREDICTION = 3
+ TREE_TRAINING = 4
+ TREE_EVALUATION = 5
+ TREE_PREDICTION = 6
+ TRAINING = 7
+ EVALUATION = 8
+ PREDICTION = 9
-# TODO transaction
-@to_dict_mixin()
-class Model(db.Model):
- __tablename__ = 'models_v2'
- __table_args__ = (Index('idx_job_name', 'job_name'),
- UniqueConstraint('job_name', name='uniq_job_name'),
- default_table_args('model'))
-
- id = db.Column(db.Integer, primary_key=True, comment='id')
- name = db.Column(db.String(255),
- comment='name') # can be modified by end-user
- version = db.Column(db.Integer, default=0, comment='version')
- type = db.Column(db.Integer, comment='type')
- state = db.Column(db.Integer, comment='state')
+class ModelJobRole(enum.Enum):
+ PARTICIPANT = 0
+ COORDINATOR = 1
+
+
+class ModelJobStatus(enum.Enum):
+    PENDING = 'PENDING'  # all model jobs are created; the local algorithm files and the local workflow are still pending
+    CONFIGURED = 'CONFIGURED'  # the local algorithm files are available and the local workflow is created
+    ERROR = 'ERROR'  # an error occurred while creating the model job
+ RUNNING = 'RUNNING'
+ STOPPED = 'STOPPED'
+ SUCCEEDED = 'SUCCEEDED'
+    FAILED = 'FAILED'  # the job failed while running
+
+
+class AuthStatus(enum.Enum):
+ PENDING = 'PENDING'
+ AUTHORIZED = 'AUTHORIZED'
+
+
+class GroupCreateStatus(enum.Enum):
+ PENDING = 'PENDING'
+ FAILED = 'FAILED'
+ SUCCEEDED = 'SUCCEEDED'
+
+
+class GroupAuthFrontendStatus(enum.Enum):
+ TICKET_PENDING = 'TICKET_PENDING'
+ TICKET_DECLINED = 'TICKET_DECLINED'
+ CREATE_PENDING = 'CREATE_PENDING'
+ CREATE_FAILED = 'CREATE_FAILED'
+ SELF_AUTH_PENDING = 'SELF_AUTH_PENDING'
+ PART_AUTH_PENDING = 'PART_AUTH_PENDING'
+ ALL_AUTHORIZED = 'ALL_AUTHORIZED'
+
+
+class GroupAutoUpdateStatus(enum.Enum):
+ INITIAL = 'INITIAL'
+ ACTIVE = 'ACTIVE'
+ STOPPED = 'STOPPED'
+
+
+class ModelJobCreateStatus(enum.Enum):
+ PENDING = 'PENDING'
+ FAILED = 'FAILED'
+ SUCCEEDED = 'SUCCEEDED'
+
+
+class ModelJobAuthFrontendStatus(enum.Enum):
+ TICKET_PENDING = 'TICKET_PENDING'
+ TICKET_DECLINED = 'TICKET_DECLINED'
+ CREATE_PENDING = 'CREATE_PENDING'
+ CREATE_FAILED = 'CREATE_FAILED'
+ SELF_AUTH_PENDING = 'SELF_AUTH_PENDING'
+ PART_AUTH_PENDING = 'PART_AUTH_PENDING'
+ ALL_AUTHORIZED = 'ALL_AUTHORIZED'
+
+
+class ModelJob(db.Model, SoftDeleteModel, ReviewTicketAndAuthModel):
+ __tablename__ = 'model_jobs_v2'
+ __table_args__ = (Index('idx_uuid',
+ 'uuid'), UniqueConstraint('job_name',
+ name='uniq_job_name'), default_table_args('model_jobs_v2'))
+
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ name = db.Column(db.String(255), comment='name')
+ uuid = db.Column(db.String(64), comment='uuid')
+ role = db.Column(db.Enum(ModelJobRole, native_enum=False, length=32, create_constraint=False),
+ default=ModelJobRole.PARTICIPANT,
+ comment='role')
+ model_job_type = db.Column(db.Enum(ModelJobType, native_enum=False, length=32, create_constraint=False),
+ default=ModelJobType.UNSPECIFIED,
+ comment='type')
job_name = db.Column(db.String(255), comment='job_name')
- parent_id = db.Column(db.Integer, comment='parent_id')
+ job_id = db.Column(db.Integer, comment='job id')
+ # the model id used for prediction or evaluation
+ model_id = db.Column(db.Integer, comment='model_id')
+ group_id = db.Column(db.Integer, comment='group_id')
+ project_id = db.Column(db.Integer, comment='project id')
+ workflow_id = db.Column(db.Integer, comment='workflow id')
+ workflow_uuid = db.Column(db.String(64), comment='workflow uuid')
+ algorithm_type = db.Column(db.Enum(AlgorithmType, native_enum=False, length=32, create_constraint=False),
+ default=AlgorithmType.UNSPECIFIED,
+ comment='algorithm type')
+ algorithm_id = db.Column(db.Integer, comment='algorithm id')
+ dataset_id = db.Column(db.Integer, comment='dataset id')
params = db.Column(db.Text(), comment='params')
metrics = db.Column(db.Text(), comment='metrics')
- created_at = db.Column(db.DateTime(timezone=True),
- comment='created_at',
- server_default=func.now())
+ extra = db.Column(db.Text(), comment='extra')
+ favorite = db.Column(db.Boolean, default=False, comment='favorite')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
+ version = db.Column(db.Integer, comment='version')
+ creator_username = db.Column(db.String(255), comment='creator username')
+ coordinator_id = db.Column(db.Integer, comment='coordinator participant id')
+ path = db.Column('fspath', db.String(512), key='path', comment='model job path')
+ metric_is_public = db.Column(db.Boolean(), default=False, comment='is metric public')
+ global_config = db.Column(db.Text(16777215), comment='global_config')
+ status = db.Column(db.Enum(ModelJobStatus, native_enum=False, length=32, create_constraint=False),
+ default=ModelJobStatus.PENDING,
+ comment='model job status')
+ create_status = db.Column(db.Enum(ModelJobCreateStatus, native_enum=False, length=32, create_constraint=False),
+ default=ModelJobCreateStatus.PENDING,
+ comment='create status')
+ auth_status = db.Column(db.Enum(AuthStatus, native_enum=False, length=32, create_constraint=False),
+ default=AuthStatus.PENDING,
+ comment='authorization status')
+ auto_update = db.Column(db.Boolean(), server_default=db.text('0'), comment='is auto update')
+ data_batch_id = db.Column(db.Integer, comment='data_batches id for auto update job')
+ error_message = db.Column(db.Text(), comment='error message')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created_at', server_default=func.now())
updated_at = db.Column(db.DateTime(timezone=True),
comment='updated_at',
server_default=func.now(),
onupdate=func.now())
deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted_at')
+    # the model used for prediction or evaluation
+ model = db.relationship('Model', primaryjoin='Model.id == foreign(ModelJob.model_id)')
+ group = db.relationship('ModelJobGroup', primaryjoin='ModelJobGroup.id == foreign(ModelJob.group_id)')
+ project = db.relationship(Project.__name__, primaryjoin='Project.id == foreign(ModelJob.project_id)')
+ # job_name is the foreign key, job_id is unknown when creating
+ job = db.relationship('Job', primaryjoin='Job.name == foreign(ModelJob.job_name)')
+ # workflow_uuid is the foreign key, workflow_id is unknown when creating
+ workflow = db.relationship(Workflow.__name__, primaryjoin='Workflow.uuid == foreign(ModelJob.workflow_uuid)')
+ algorithm = db.relationship(Algorithm.__name__, primaryjoin='Algorithm.id == foreign(ModelJob.algorithm_id)')
+ dataset = db.relationship(Dataset.__name__, primaryjoin='Dataset.id == foreign(ModelJob.dataset_id)')
+ data_batch = db.relationship('DataBatch', primaryjoin='DataBatch.id == foreign(ModelJob.data_batch_id)')
- group_id = db.Column(db.Integer, default=0, comment='group_id')
- # TODO https://code.byted.org/data/fedlearner_web_console_v2/issues/289
- extra = db.Column(db.Text(), comment='extra') # json string
+ output_model = db.relationship(
+ 'Model',
+ uselist=False,
+ primaryjoin='ModelJob.id == foreign(Model.model_job_id)',
+ # To disable the warning of back_populates
+ overlaps='model_job')
+
+ def to_proto(self) -> ModelJobPb:
+ config = self.config()
+ model_job = ModelJobPb(
+ id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ role=self.role.name,
+ model_job_type=self.model_job_type.name,
+ algorithm_type=self.algorithm_type.name if self.algorithm_type else AlgorithmType.UNSPECIFIED.name,
+ algorithm_id=self.algorithm_id,
+ group_id=self.group_id,
+ project_id=self.project_id,
+ state=self.state.name,
+ configured=config is not None,
+ model_id=self.model_id,
+ model_name=self.model_name(),
+ job_id=self.job_id,
+ job_name=self.job_name,
+ workflow_id=self.workflow_id,
+ dataset_id=self.dataset_id,
+ dataset_name=self.dataset_name(),
+ creator_username=self.creator_username,
+ coordinator_id=self.coordinator_id,
+ auth_status=self.auth_status.name if self.auth_status else '',
+ status=self.status.name if self.status else '',
+ error_message=self.error_message,
+ auto_update=self.auto_update,
+ data_batch_id=self.data_batch_id,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ started_at=self.started_at(),
+ stopped_at=self.stopped_at(),
+ version=self.version,
+ comment=self.comment,
+ metric_is_public=self.metric_is_public,
+ global_config=self.get_global_config(),
+ participants_info=self.get_participants_info(),
+ auth_frontend_status=self.get_model_job_auth_frontend_status().name)
+ if config is not None:
+ model_job.config.MergeFrom(config)
+ if self.output_model is not None:
+ model_job.output_model_name = self.output_model.name
+ model_job.output_models.append(self.output_model.to_proto())
+ return model_job
+
+ def to_ref(self) -> ModelJobRef:
+ return ModelJobRef(
+ id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ group_id=self.group_id,
+ project_id=self.project_id,
+ role=self.role.name,
+ model_job_type=self.model_job_type.name,
+ algorithm_type=self.algorithm_type.name if self.algorithm_type else AlgorithmType.UNSPECIFIED.name,
+ algorithm_id=self.algorithm_id,
+ state=self.state.name,
+ configured=self.config() is not None,
+ creator_username=self.creator_username,
+ coordinator_id=self.coordinator_id,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ started_at=self.started_at(),
+ stopped_at=self.stopped_at(),
+ version=self.version,
+ metric_is_public=self.metric_is_public,
+ status=self.status.name if self.status else '',
+ auto_update=self.auto_update,
+ auth_status=self.auth_status.name if self.auth_status else '',
+ participants_info=self.get_participants_info(),
+ auth_frontend_status=self.get_model_job_auth_frontend_status().name)
+
+ @property
+ def state(self) -> WorkflowExternalState:
+ # TODO(hangweiqiang): design model job state
+ if self.workflow is None:
+ return WorkflowExternalState.PENDING_ACCEPT
+ return self.workflow.get_state_for_frontend()
+
+ def get_model_job_auth_frontend_status(self) -> ModelJobAuthFrontendStatus:
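+        # precedence: ticket status, then self authorization, then participant authorization, then peer create status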
+ if self.ticket_status == TicketStatus.PENDING:
+ if self.ticket_uuid is not None:
+ return ModelJobAuthFrontendStatus.TICKET_PENDING
+ # Update old data that is set to PENDING by default when ticket is disabled
+ self.ticket_status = TicketStatus.APPROVED
+ if self.ticket_status == TicketStatus.DECLINED:
+ return ModelJobAuthFrontendStatus.TICKET_DECLINED
+ if self.auth_status not in [AuthStatus.AUTHORIZED]:
+ return ModelJobAuthFrontendStatus.SELF_AUTH_PENDING
+ if self.is_all_participants_authorized():
+ return ModelJobAuthFrontendStatus.ALL_AUTHORIZED
+ if self.create_status in [ModelJobCreateStatus.PENDING]:
+ return ModelJobAuthFrontendStatus.CREATE_PENDING
+ if self.create_status in [ModelJobCreateStatus.FAILED]:
+ return ModelJobAuthFrontendStatus.CREATE_FAILED
+ return ModelJobAuthFrontendStatus.PART_AUTH_PENDING
+
+ def get_job_path(self):
+ path = self.project.get_storage_root_path(None)
+ if path is None:
+ logging.warning('cannot find storage_root_path')
+ return None
+ return get_job_path(path, self.job_name)
+
+ def get_exported_model_path(self) -> Optional[str]:
+ """Get the path of the exported models.
+
+ Returns:
+            The path of the exported models, or None if the path cannot be
+            found. There may be multiple checkpoints under the path. The file
+            structure of an nn_model under the exported_models path is
+ - exported_models:
+ - ${terminated time, e.g. 1619769879}
+ - _SUCCESS
+ - saved_model.pb
+ - variables
+ - variables.data-00000-of-00001
+ - variables.index
+ """
+ job_path = self.get_job_path()
+ if job_path is None:
+ return None
+ return get_exported_model_path(job_path)
+
+ def get_checkpoint_path(self):
+ job_path = self.get_job_path()
+ if job_path is None:
+ return None
+ return get_checkpoint_path(job_path=job_path)
+
+ def get_output_path(self):
+ job_path = self.get_job_path()
+ if job_path is None:
+ return None
+ return get_output_path(job_path)
+
+ def model_name(self) -> Optional[str]:
+ if self.model_id is not None:
+ return self.model.name
+ return None
+
+ def dataset_name(self) -> Optional[str]:
+        # check through the relationship rather than the id, since the dataset may have been deleted
+ if self.dataset is not None:
+ return self.dataset.name
+ return None
+
+ def started_at(self) -> Optional[int]:
+ if self.workflow:
+ return self.workflow.start_at
+ return None
- parent = db.relationship('Model',
- primaryjoin=remote(id) == foreign(parent_id),
- backref='children')
- job = db.relationship('Job', primaryjoin=Job.name == foreign(job_name))
+ def stopped_at(self) -> Optional[int]:
+ if self.workflow:
+ return self.workflow.stop_at
+ return None
- def get_eval_model(self):
- return [
- child for child in self.children if child.type in
- [ModelType.NN_EVALUATION.value, ModelType.TREE_EVALUATION.value]
+ def config(self) -> Optional[WorkflowDefinition]:
+ if self.workflow:
+ return self.workflow.get_config()
+ return None
+
+ def is_deletable(self) -> bool:
+ return self.state in [
+ WorkflowExternalState.FAILED, WorkflowExternalState.STOPPED, WorkflowExternalState.COMPLETED
]
+ def set_global_config(self, proto: ModelJobGlobalConfig):
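+        # the config is persisted as a text-format serialized ModelJobGlobalConfig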
+ self.global_config = text_format.MessageToString(proto)
-@to_dict_mixin()
-class ModelGroup(db.Model):
- __tablename__ = 'model_groups_v2'
- __table_args__ = (default_table_args('model_groups_v2'))
+ def get_global_config(self) -> Optional[ModelJobGlobalConfig]:
+ if self.global_config is not None:
+ return text_format.Parse(self.global_config, ModelJobGlobalConfig())
+ return None
- id = db.Column(db.Integer, primary_key=True, comment='id')
- name = db.Column(db.String(255),
- comment='name') # can be modified by end-user
- created_at = db.Column(db.DateTime(timezone=True),
- comment='created_at',
- server_default=func.now())
+class Model(db.Model, SoftDeleteModel):
+ __tablename__ = 'models_v2'
+ __table_args__ = (UniqueConstraint('name', name='uniq_name'), UniqueConstraint('uuid', name='uniq_uuid'),
+ default_table_args('models_v2'))
+
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ name = db.Column(db.String(255), comment='name')
+ uuid = db.Column(db.String(64), comment='uuid')
+ algorithm_type = db.Column(db.Enum(AlgorithmType, native_enum=False, length=32, create_constraint=False),
+ default=AlgorithmType.UNSPECIFIED,
+ comment='algorithm type')
+    # TODO(hangweiqiang): remove the model_type column
+ model_type = db.Column(db.Enum(ModelType, native_enum=False, length=32, create_constraint=False),
+ default=ModelType.UNSPECIFIED,
+ comment='type')
+ model_path = db.Column(db.String(512), comment='model path')
+ favorite = db.Column(db.Boolean, default=False, comment='favorite model')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
+ group_id = db.Column(db.Integer, comment='group_id')
+ project_id = db.Column(db.Integer, comment='project_id')
+ job_id = db.Column(db.Integer, comment='job id')
+ model_job_id = db.Column(db.Integer, comment='model job id')
+ version = db.Column(db.Integer, comment='version')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created_at', server_default=func.now())
updated_at = db.Column(db.DateTime(timezone=True),
comment='updated_at',
server_default=func.now(),
onupdate=func.now())
deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted_at')
+ group = db.relationship('ModelJobGroup', primaryjoin='ModelJobGroup.id == foreign(Model.group_id)')
+ project = db.relationship('Project', primaryjoin='Project.id == foreign(Model.project_id)')
+ job = db.relationship('Job', primaryjoin='Job.id == foreign(Model.job_id)')
+ # the model_job generating this model
+ model_job = db.relationship('ModelJob', primaryjoin='ModelJob.id == foreign(Model.model_job_id)')
+ # the model_jobs inheriting this model
+ derived_model_jobs = db.relationship(
+ 'ModelJob',
+ primaryjoin='foreign(ModelJob.model_id) == Model.id',
+ # To disable the warning of back_populates
+ overlaps='model')
+
+ def to_proto(self) -> ModelPb:
+ return ModelPb(
+ id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ algorithm_type=self.algorithm_type.name if self.algorithm_type else AlgorithmType.UNSPECIFIED.name,
+ group_id=self.group_id,
+ project_id=self.project_id,
+ model_job_id=self.model_job_id,
+ model_job_name=self.model_job_name(),
+ job_id=self.job_id,
+ job_name=self.job_name(),
+ workflow_id=self.workflow_id(),
+ workflow_name=self.workflow_name(),
+ version=self.version,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ comment=self.comment,
+ model_path=self.model_path)
+
+ def workflow_id(self):
+ if self.job is not None:
+ return self.job.workflow_id
+ return None
+
+ def job_name(self):
+ if self.job is not None:
+ return self.job.name
+ return None
+
+ def workflow_name(self):
+ if self.job is not None:
+ return self.job.workflow.name
+ return None
+
+ def model_job_name(self):
+ if self.model_job is not None:
+ return self.model_job.name
+ return None
- # TODO https://code.byted.org/data/fedlearner_web_console_v2/issues/289
+ def get_exported_model_path(self):
+        """Get the path of the exported models.
+        Same as ModelJob.get_exported_model_path.
+ """
+ return get_exported_model_path(self.model_path)
+
+ def get_checkpoint_path(self):
+ return get_checkpoint_path(self.model_path)
+
+
+class ModelJobGroup(db.Model, SoftDeleteModel, ReviewTicketAndAuthModel):
+    # the table name and the class name are inconsistent for historical reasons
+ __tablename__ = 'model_groups_v2'
+ __table_args__ = (UniqueConstraint('name', name='uniq_name'), default_table_args('model_groups_v2'))
+
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ uuid = db.Column(db.String(64), comment='uuid')
+ name = db.Column(db.String(255), comment='name')
+ project_id = db.Column(db.Integer, comment='project_id')
+ role = db.Column(db.Enum(ModelJobRole, native_enum=False, length=32, create_constraint=False),
+ default=ModelJobRole.PARTICIPANT,
+ comment='role')
+ authorized = db.Column(db.Boolean, default=False, comment='authorized to participants in project')
+ dataset_id = db.Column(db.Integer, comment='dataset id')
+ algorithm_type = db.Column(db.Enum(AlgorithmType, native_enum=False, length=32, create_constraint=False),
+ default=AlgorithmType.UNSPECIFIED,
+ comment='algorithm type')
+ algorithm_project_id = db.Column(db.Integer, comment='algorithm project id')
+ algorithm_id = db.Column(db.Integer, comment='algorithm id')
+ config = db.Column(db.Text(16777215), comment='config')
+ cron_job_global_config = db.Column(db.Text(16777215), comment='global config for cron job')
+ # use proto.AlgorithmProjectList to store the algorithm project uuid of each participant
+ algorithm_project_uuid_list = db.Column('algorithm_uuid_list',
+ db.Text(16777215),
+ key='algorithm_uuid_list',
+ comment='algorithm project uuid for all participants')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
+ creator_username = db.Column(db.String(255), comment='creator username')
+ coordinator_id = db.Column(db.Integer, comment='coordinator participant id')
+ cron_config = db.Column(db.String(255), comment='cron expression in UTC timezone')
+ path = db.Column('fspath', db.String(512), key='path', comment='model job group path')
+ _auth_status = db.Column('auth_status',
+ db.Enum(auth_model.AuthStatus, native_enum=False, length=32, create_constraint=False),
+ default=auth_model.AuthStatus.PENDING,
+ comment='auth status')
+ auto_update_status = db.Column(db.Enum(GroupAutoUpdateStatus, native_enum=False, length=32,
+ create_constraint=False),
+ default=GroupAutoUpdateStatus.INITIAL,
+ comment='auto update status')
+ start_data_batch_id = db.Column(db.Integer, comment='start data_batches id for auto update job')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created_at', server_default=func.now())
+ updated_at = db.Column(db.DateTime(timezone=True),
+ comment='updated_at',
+ server_default=func.now(),
+ onupdate=func.now())
+ deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted_at')
extra = db.Column(db.Text(), comment='extra') # json string
+ latest_version = db.Column(db.Integer, default=0, comment='latest version')
+ status = db.Column(db.Enum(GroupCreateStatus, native_enum=False, length=32, create_constraint=False),
+ default=GroupCreateStatus.PENDING,
+ comment='create status')
+ project = db.relationship('Project', primaryjoin='Project.id == foreign(ModelJobGroup.project_id)')
+ algorithm = db.relationship('Algorithm', primaryjoin='Algorithm.id == foreign(ModelJobGroup.algorithm_id)')
+ algorithm_project = db.relationship(
+ 'AlgorithmProject', primaryjoin='AlgorithmProject.id == foreign(ModelJobGroup.algorithm_project_id)')
+ dataset = db.relationship('Dataset', primaryjoin='Dataset.id == foreign(ModelJobGroup.dataset_id)')
+ model_jobs = db.relationship(
+ 'ModelJob',
+ order_by='desc(ModelJob.version)',
+ primaryjoin='ModelJobGroup.id == foreign(ModelJob.group_id)',
+ # To disable the warning of back_populates
+ overlaps='group')
+ start_data_batch = db.relationship('DataBatch',
+ primaryjoin='DataBatch.id == foreign(ModelJobGroup.start_data_batch_id)')
+
+ @property
+ def auth_status(self):
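+        # when the auth_status column is unset, derive the status from the authorized flag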
+ if self._auth_status is not None:
+ return self._auth_status
+ if self.authorized:
+ return auth_model.AuthStatus.AUTHORIZED
+ return auth_model.AuthStatus.PENDING
+
+ @auth_status.setter
+ def auth_status(self, auth_status: auth_model.AuthStatus):
+ self._auth_status = auth_status
+
+ def to_ref(self) -> ModelJobGroupRef:
+ group = ModelJobGroupRef(id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ role=self.role.name,
+ project_id=self.project_id,
+ authorized=self.authorized,
+ algorithm_type=self.algorithm_type.name,
+ configured=self.config is not None,
+ creator_username=self.creator_username,
+ coordinator_id=self.coordinator_id,
+ latest_version=self.latest_version,
+ participants_info=self.get_participants_info(),
+ auth_status=self.auth_status.name,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at))
+ latest_job_state = self.latest_job_state()
+ if latest_job_state is not None:
+ group.latest_job_state = latest_job_state.name
+ group.auth_frontend_status = self.get_group_auth_frontend_status().name
+ return group
+
+ def to_proto(self) -> ModelJobGroupPb:
+ group = ModelJobGroupPb(id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ role=self.role.name,
+ project_id=self.project_id,
+ authorized=self.authorized,
+ dataset_id=self.dataset_id,
+ algorithm_type=self.algorithm_type.name,
+ algorithm_project_id=self.algorithm_project_id,
+ algorithm_id=self.algorithm_id,
+ configured=self.config is not None,
+ creator_username=self.creator_username,
+ coordinator_id=self.coordinator_id,
+ cron_config=self.cron_config,
+ latest_version=self.latest_version,
+ participants_info=self.get_participants_info(),
+ algorithm_project_uuid_list=self.get_algorithm_project_uuid_list(),
+ auth_status=self.auth_status.name,
+ auto_update_status=self.auto_update_status.name if self.auto_update_status else '',
+ start_data_batch_id=self.start_data_batch_id,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ comment=self.comment)
+ latest_job_state = self.latest_job_state()
+ if latest_job_state is not None:
+ group.latest_job_state = latest_job_state.name
+ group.auth_frontend_status = self.get_group_auth_frontend_status().name
+ if self.config is not None:
+ group.config.MergeFrom(self.get_config())
+ group.model_jobs.extend([mj.to_ref() for mj in self.model_jobs])
+ return group
+
+ def latest_job_state(self) -> Optional[ModelJobStatus]:
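+        # model_jobs is ordered by descending version, so the first entry is the latest job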
+ if len(self.model_jobs) == 0:
+ return None
+ return self.model_jobs[0].status
+
+ def get_group_auth_frontend_status(self) -> GroupAuthFrontendStatus:
+ if self.ticket_status == TicketStatus.PENDING:
+ if self.ticket_uuid is not None:
+ return GroupAuthFrontendStatus.TICKET_PENDING
+ # Update old data that is set to PENDING by default when ticket is disabled
+ self.ticket_status = TicketStatus.APPROVED
+ if self.ticket_status == TicketStatus.DECLINED:
+ return GroupAuthFrontendStatus.TICKET_DECLINED
+ if not self.authorized:
+ return GroupAuthFrontendStatus.SELF_AUTH_PENDING
+ if self.is_all_participants_authorized():
+ return GroupAuthFrontendStatus.ALL_AUTHORIZED
+ if self.status == GroupCreateStatus.PENDING:
+ return GroupAuthFrontendStatus.CREATE_PENDING
+ if self.status == GroupCreateStatus.FAILED:
+ return GroupAuthFrontendStatus.CREATE_FAILED
+ return GroupAuthFrontendStatus.PART_AUTH_PENDING
+
+ def get_config(self) -> Optional[WorkflowDefinition]:
+ if self.config is not None:
+ return text_format.Parse(self.config, WorkflowDefinition())
+ return None
+
+ def set_config(self, config: Optional[WorkflowDefinition] = None):
+ if config is None:
+ config = WorkflowDefinition()
+ self.config = text_format.MessageToString(config)
+
+ def is_deletable(self) -> bool:
+ for model_job in self.model_jobs:
+ if not model_job.is_deletable():
+ return False
+ return True
+
+ def latest_completed_job(self) -> Optional[ModelJob]:
+ for job in self.model_jobs:
+ if job.state == WorkflowExternalState.COMPLETED:
+ return job
+ return None
+
+ def set_algorithm_project_uuid_list(self, proto: AlgorithmProjectList):
+ self.algorithm_project_uuid_list = text_format.MessageToString(proto)
+
+ def get_algorithm_project_uuid_list(self) -> AlgorithmProjectList:
+ algorithm_project_uuid_list = AlgorithmProjectList()
+ if self.algorithm_project_uuid_list is not None:
+ algorithm_project_uuid_list = text_format.Parse(self.algorithm_project_uuid_list, AlgorithmProjectList())
+ return algorithm_project_uuid_list
+
+
+def is_federated(algorithm_type: AlgorithmType, model_job_type: ModelJobType) -> bool:
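+    # only horizontal NN jobs other than training (e.g. evaluation or prediction) are non-federated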
+ return algorithm_type != AlgorithmType.NN_HORIZONTAL or model_job_type == ModelJobType.TRAINING
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/models_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/models_test.py
new file mode 100644
index 000000000..bcbff920a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/models_test.py
@@ -0,0 +1,452 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, PropertyMock
+from datetime import datetime
+from google.protobuf.json_format import MessageToDict
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.workflow.models import WorkflowExternalState
+from fedlearner_webconsole.mmgr.models import ModelJob, Model, ModelJobType, ModelJobGroup, ModelJobRole, \
+ GroupCreateStatus, GroupAuthFrontendStatus, AlgorithmProjectList, ModelJobAuthFrontendStatus, \
+ ModelJobCreateStatus, AuthStatus as ModelJobAuthStatus
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobRef, ModelPb, ModelJobGroupRef, ModelJobGroupPb, \
+ ModelJobGlobalConfig, ModelJobConfig
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+
+
+class ModelTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ model = Model(id=1,
+ name='model',
+ uuid='uuid',
+ group_id=2,
+ project_id=3,
+ job_id=4,
+ model_job_id=5,
+ version=1,
+ created_at=datetime(2022, 5, 10, 0, 0, 0),
+ updated_at=datetime(2022, 5, 10, 0, 0, 0))
+ session.add(model)
+ session.commit()
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ model: Model = session.query(Model).get(1)
+ pb = ModelPb(id=1,
+ name='model',
+ uuid='uuid',
+ group_id=2,
+ project_id=3,
+ algorithm_type='UNSPECIFIED',
+ job_id=4,
+ model_job_id=5,
+ version=1,
+ created_at=1652140800,
+ updated_at=1652140800)
+ self.assertEqual(model.to_proto(), pb)
+ model.algorithm_type = AlgorithmType.NN_VERTICAL
+ pb.algorithm_type = 'NN_VERTICAL'
+ self.assertEqual(model.to_proto(), pb)
+
+
+class ModelJobTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='test-project')
+ model_job = ModelJob(id=1,
+ name='job',
+ uuid='uuid',
+ project_id=1,
+ group_id=2,
+ model_job_type=ModelJobType.TRAINING,
+ role=ModelJobRole.COORDINATOR,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ ticket_status=TicketStatus.PENDING,
+ ticket_uuid='ticket_uuid',
+ job_name='uuid-train-job',
+ job_id=3,
+ workflow_uuid='uuid',
+ workflow_id=5,
+ algorithm_id=6,
+ dataset_id=7,
+ creator_username='ada',
+ coordinator_id=8,
+ version=1,
+ created_at=datetime(2022, 5, 10, 0, 0, 0),
+ updated_at=datetime(2022, 5, 10, 0, 0, 0),
+ metric_is_public=True,
+ auto_update=True,
+ auth_status=ModelJobAuthStatus.AUTHORIZED,
+ error_message='error_message')
+ session.add_all([project, model_job])
+ session.commit()
+
+ @patch('fedlearner_webconsole.project.models.Project.get_storage_root_path')
+ def test_exported_model_path(self, mock_get_storage_root_path):
+ mock_get_storage_root_path.return_value = '/data/'
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ session.add(model_job)
+ session.flush()
+ # test for model_job.get_exported_model_path
+ expected_path = f'/data/job_output/{model_job.job_name}/exported_models'
+ self.assertEqual(model_job.get_exported_model_path(), expected_path)
+ # test for model.get_exported_model_path
+ model_path = '/data/model_output/uuid'
+ model = Model(model_path=model_path)
+ expected_path = '/data/model_output/uuid/exported_models'
+ self.assertEqual(model.get_exported_model_path(), expected_path)
+
+ @patch('fedlearner_webconsole.mmgr.models.ModelJob.state', new_callable=PropertyMock)
+ def test_is_deletable(self, mock_state):
+ mock_state.return_value = WorkflowExternalState.RUNNING
+ model_job = ModelJob(name='model_job')
+ self.assertEqual(model_job.is_deletable(), False)
+ mock_state.return_value = WorkflowExternalState.FAILED
+ self.assertEqual(model_job.is_deletable(), True)
+
+ def test_to_ref(self):
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(1)
+ ref = ModelJobRef(id=1,
+ name='job',
+ uuid='uuid',
+ project_id=1,
+ group_id=2,
+ model_job_type='TRAINING',
+ role='COORDINATOR',
+ algorithm_type='NN_VERTICAL',
+ algorithm_id=6,
+ state='PENDING_ACCEPT',
+ configured=False,
+ creator_username='ada',
+ coordinator_id=8,
+ version=1,
+ created_at=1652140800,
+ updated_at=1652140800,
+ metric_is_public=True,
+ status='PENDING',
+ auth_frontend_status='TICKET_PENDING',
+ auth_status='AUTHORIZED',
+ auto_update=True,
+ participants_info=ParticipantsInfo())
+ self.assertEqual(model_job.to_ref(), ref)
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(1)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={'test': ModelJobConfig(algorithm_uuid='uuid')})
+ model_job.set_global_config(global_config)
+ self.assertPartiallyEqual(MessageToDict(model_job.to_proto()), {
+ 'id': '1',
+ 'name': 'job',
+ 'uuid': 'uuid',
+ 'role': 'COORDINATOR',
+ 'modelJobType': 'TRAINING',
+ 'algorithmType': 'NN_VERTICAL',
+ 'algorithmId': '6',
+ 'groupId': '2',
+ 'projectId': '1',
+ 'state': 'PENDING_ACCEPT',
+ 'jobId': '3',
+ 'workflowId': '5',
+ 'datasetId': '7',
+ 'creatorUsername': 'ada',
+ 'coordinatorId': '8',
+ 'version': '1',
+ 'jobName': 'uuid-train-job',
+ 'metricIsPublic': True,
+ 'status': 'PENDING',
+ 'authStatus': 'AUTHORIZED',
+ 'autoUpdate': True,
+ 'errorMessage': 'error_message',
+ 'globalConfig': {
+ 'globalConfig': {
+ 'test': {
+ 'algorithmUuid': 'uuid'
+ }
+ },
+ 'datasetUuid': 'uuid'
+ },
+ 'authFrontendStatus': 'TICKET_PENDING',
+ 'participantsInfo': {},
+ },
+ ignore_fields=['createdAt', 'updatedAt'])
+
+ def test_set_and_get_global_config(self):
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={'test': ModelJobConfig(algorithm_uuid='uuid')})
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(1)
+ self.assertIsNone(model_job.get_global_config())
+ model_job.set_global_config(proto=global_config)
+ session.commit()
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(1)
+ self.assertEqual(model_job.get_global_config(), global_config)
+
+ def test_get_model_job_auth_frontend_status(self):
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+            # case 1: ticket is still pending
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(), ModelJobAuthFrontendStatus.TICKET_PENDING)
+            # case 2: ticket was declined
+ model_job.ticket_status = TicketStatus.DECLINED
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(), ModelJobAuthFrontendStatus.TICKET_DECLINED)
+            # case 3: self authorization is still pending
+ model_job.ticket_status = TicketStatus.APPROVED
+ model_job.auth_status = ModelJobAuthStatus.PENDING
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(),
+ ModelJobAuthFrontendStatus.SELF_AUTH_PENDING)
+ # case 4
+ model_job.auth_status = ModelJobAuthStatus.AUTHORIZED
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test_1': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test_2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ model_job.set_participants_info(participants_info)
+ model_job.create_status = ModelJobCreateStatus.PENDING
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(), ModelJobAuthFrontendStatus.CREATE_PENDING)
+ # case 5
+ model_job.create_status = ModelJobCreateStatus.FAILED
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(), ModelJobAuthFrontendStatus.CREATE_FAILED)
+ # case 6
+ model_job.create_status = ModelJobCreateStatus.SUCCEEDED
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(),
+ ModelJobAuthFrontendStatus.PART_AUTH_PENDING)
+ # case 7
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test_1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ model_job.set_participants_info(participants_info)
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(), ModelJobAuthFrontendStatus.ALL_AUTHORIZED)
+ # case 8
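+        # when ticket_uuid is None the ticket status is treated as APPROVED and the result stays ALL_AUTHORIZED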
+ model_job.ticket_uuid = None
+ model_job.ticket_status = TicketStatus.PENDING
+ self.assertEqual(model_job.get_model_job_auth_frontend_status(), ModelJobAuthFrontendStatus.ALL_AUTHORIZED)
+ self.assertEqual(model_job.ticket_status, TicketStatus.APPROVED)
+
+
+class ModelJobGroupTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ config = WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[Variable(name='mode', value='train')])
+ ])
+ job = ModelJob(id=1,
+ name='job',
+ uuid='uuid',
+ project_id=2,
+ group_id=1,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ auth_status=ModelJobAuthStatus.AUTHORIZED,
+ auto_update=True,
+ created_at=datetime(2022, 5, 10, 0, 0, 0),
+ updated_at=datetime(2022, 5, 10, 0, 0, 0),
+ version=1)
+ group = ModelJobGroup(id=1,
+ name='group',
+ uuid='uuid',
+ project_id=2,
+ role=ModelJobRole.COORDINATOR,
+ authorized=False,
+ ticket_status=TicketStatus.PENDING,
+ ticket_uuid='ticket_uuid',
+ dataset_id=3,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_project_id=4,
+ algorithm_id=5,
+ creator_username='ada',
+ coordinator_id=6,
+ created_at=datetime(2022, 5, 10, 0, 0, 0),
+ updated_at=datetime(2022, 5, 10, 0, 0, 0))
+ algorithm_project_list = AlgorithmProjectList()
+ algorithm_project_list.algorithm_projects['test'] = 'uuid-test'
+ algorithm_project_list.algorithm_projects['peer'] = 'uuid-peer'
+ group.set_algorithm_project_uuid_list(algorithm_project_list)
+ group.set_config(config)
+ session.add_all([job, group])
+ session.commit()
+
+ def test_get_config(self):
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ self.assertEqual(
+ group.get_config(),
+ WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[Variable(name='mode', value='train')])
+ ]))
+
+ @patch('fedlearner_webconsole.mmgr.models.ModelJob.state', new_callable=PropertyMock)
+ def test_is_deletable(self, mock_state):
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ mock_state.return_value = WorkflowExternalState.RUNNING
+ self.assertEqual(group.is_deletable(), False)
+ mock_state.return_value = WorkflowExternalState.STOPPED
+ self.assertEqual(group.is_deletable(), True)
+
+ def test_auth_status(self):
+ group = ModelJobGroup(id=2, auth_status=None)
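+        # with the auth_status column unset, the value is derived from the legacy authorized flag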
+ self.assertEqual(group.auth_status, AuthStatus.PENDING)
+ group.authorized = True
+ self.assertEqual(group.auth_status, AuthStatus.AUTHORIZED)
+ group.authorized = False
+ group.auth_status = AuthStatus.AUTHORIZED
+ self.assertEqual(group.auth_status, AuthStatus.AUTHORIZED)
+ with db.session_scope() as session:
+ session.add(group)
+ session.commit()
+ self.assertEqual(group._auth_status, AuthStatus.AUTHORIZED) # pylint: disable=protected-access
+
+ def test_get_group_auth_frontend_status(self):
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ # case 1
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.TICKET_PENDING)
+ # case 2
+ group.ticket_status = TicketStatus.DECLINED
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.TICKET_DECLINED)
+ # case 3
+ group.ticket_status = TicketStatus.APPROVED
+ group.authorized = False
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.SELF_AUTH_PENDING)
+ # case 4
+ group.authorized = True
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test_1': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test_2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ group.set_participants_info(participants_info)
+ group.status = GroupCreateStatus.PENDING
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.CREATE_PENDING)
+ # case 5
+ group.status = GroupCreateStatus.FAILED
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.CREATE_FAILED)
+ # case 6
+ group.status = GroupCreateStatus.SUCCEEDED
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.PART_AUTH_PENDING)
+ # case 7
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test_1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ group.set_participants_info(participants_info)
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.ALL_AUTHORIZED)
+ # case 8
+ group.ticket_uuid = None
+ group.ticket_status = TicketStatus.PENDING
+ self.assertEqual(group.get_group_auth_frontend_status(), GroupAuthFrontendStatus.ALL_AUTHORIZED)
+ self.assertEqual(group.ticket_status, TicketStatus.APPROVED)
+
+ def test_to_ref(self):
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ ref = ModelJobGroupRef(id=1,
+ name='group',
+ uuid='uuid',
+ role='COORDINATOR',
+ project_id=2,
+ algorithm_type='NN_VERTICAL',
+ configured=True,
+ creator_username='ada',
+ coordinator_id=6,
+ latest_job_state='PENDING',
+ auth_frontend_status='TICKET_PENDING',
+ auth_status='PENDING',
+ participants_info=group.get_participants_info(),
+ created_at=1652140800,
+ updated_at=1652140800)
+ self.assertEqual(group.to_ref(), ref)
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ proto = ModelJobGroupPb(id=1,
+ name='group',
+ uuid='uuid',
+ role='COORDINATOR',
+ project_id=2,
+ dataset_id=3,
+ algorithm_type='NN_VERTICAL',
+ algorithm_project_id=4,
+ algorithm_id=5,
+ configured=True,
+ creator_username='ada',
+ coordinator_id=6,
+ latest_job_state='PENDING',
+ auth_frontend_status='TICKET_PENDING',
+ auth_status='PENDING',
+ auto_update_status='INITIAL',
+ participants_info=group.get_participants_info(),
+ algorithm_project_uuid_list=group.get_algorithm_project_uuid_list(),
+ created_at=1652140800,
+ updated_at=1652140800)
+ config = WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[Variable(name='mode', value='train')])
+ ])
+ proto.config.MergeFrom(config)
+ proto.model_jobs.append(
+ ModelJobRef(id=1,
+ name='job',
+ uuid='uuid',
+ group_id=1,
+ project_id=2,
+ role='PARTICIPANT',
+ model_job_type='TRAINING',
+ algorithm_type='NN_VERTICAL',
+ state='PENDING_ACCEPT',
+ created_at=1652140800,
+ updated_at=1652140800,
+ version=1,
+ status='PENDING',
+ auto_update=True,
+ auth_status='AUTHORIZED',
+ auth_frontend_status='ALL_AUTHORIZED',
+ participants_info=ParticipantsInfo()))
+ self.assertEqual(group.to_proto(), proto)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/scheduler.py b/web_console_v2/api/fedlearner_webconsole/mmgr/scheduler.py
new file mode 100644
index 000000000..05f36cf7f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/scheduler.py
@@ -0,0 +1,206 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from sqlalchemy import or_
+from typing import List, Tuple
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.services import BatchService
+from fedlearner_webconsole.composer.interface import IRunnerV2, RunnerContext, RunnerOutput
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobRole, ModelJobStatus, ModelJobGroup, \
+ GroupCreateStatus, GroupAutoUpdateStatus, Model, ModelJobType, ModelJobAuthFrontendStatus
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobConfig
+from fedlearner_webconsole.mmgr.model_job_configer import ModelJobConfiger, set_load_model_name
+from fedlearner_webconsole.mmgr.service import ModelJobService, ModelJobGroupService
+from fedlearner_webconsole.mmgr.controller import ModelJobGroupController
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+
+
+class ModelJobSchedulerRunner(IRunnerV2):
+
+ @staticmethod
+ def _check_model_job(model_job_id: int):
+ """check workflow state and update model job status"""
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(model_job_id)
+ ModelJobService(session).update_model_job_status(model_job)
+ session.commit()
+ logging.info(f'[ModelJobScheduler] model_job {model_job.name} updates status to {model_job.status}')
+
+ @staticmethod
+ def _config_model_job(model_job_id: int):
+ """config model job by calling model job configer and model job service"""
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(model_job_id)
+ global_config = model_job.get_global_config()
+ if global_config is None:
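+                # nothing to configure without a global config; only sync the status from the workflow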
+ ModelJobService(session).update_model_job_status(model_job)
+ session.commit()
+ return
+ domain_name = SettingService(session).get_system_info().pure_domain_name
+ model_job_config: ModelJobConfig = global_config.global_config.get(domain_name)
+ try:
+ configer = ModelJobConfiger(session=session,
+ model_job_type=model_job.model_job_type,
+ algorithm_type=model_job.algorithm_type,
+ project_id=model_job.project_id)
+ config = configer.get_config(dataset_id=model_job.dataset_id,
+ model_id=model_job.model_id,
+ model_job_config=model_job_config)
+ ModelJobService(session).config_model_job(model_job=model_job,
+ config=config,
+ create_workflow=False,
+ need_to_create_ready_workflow=True,
+ workflow_uuid=model_job.uuid)
+ except Exception as e: # pylint: disable=broad-except
+ logging.exception(f'[ModelJobScheduler] config model job {model_job_id} failed')
+ model_job.error_message = str(e)
+ model_job.status = ModelJobStatus.ERROR
+ finally:
+ session.commit()
+ logging.info(f'[ModelJobScheduler] model_job {model_job_id} is CONFIGURED')
+
+ def schedule_model_job(self):
+        # 1. for training model jobs with status PENDING, and for evaluation or prediction model jobs with status
+        #    PENDING whose ModelJobAuthFrontendStatus is ALL_AUTHORIZED, pull the algorithm and create the workflow;
+        #    the status then becomes CONFIGURED
+ with db.session_scope() as session:
+ training_model_job_ids: List[Tuple[int]] = session.query(ModelJob.id).filter_by(
+ status=ModelJobStatus.PENDING, model_job_type=ModelJobType.TRAINING).all()
+ non_training_model_job_ids: List[Tuple[int]] = session.query(
+ ModelJob.id).filter(ModelJob.model_job_type != ModelJobType.TRAINING).filter(
+ ModelJob.status == ModelJobStatus.PENDING).all()
+ for training_model_job_id, *_ in training_model_job_ids:
+ self._config_model_job(model_job_id=training_model_job_id)
+ for non_training_model_job_id, *_ in non_training_model_job_ids:
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(non_training_model_job_id)
+ if model_job.get_model_job_auth_frontend_status() in [ModelJobAuthFrontendStatus.ALL_AUTHORIZED]:
+ self._config_model_job(model_job_id=non_training_model_job_id)
+        # 2. filter model jobs with status CONFIGURED or RUNNING, and update the model job status
+        with db.session_scope() as session:
+            model_job_ids: List[Tuple[int]] = session.query(ModelJob.id).filter(
+                or_(ModelJob.status == ModelJobStatus.CONFIGURED, ModelJob.status == ModelJobStatus.RUNNING)).all()
+ for model_job_id, *_ in model_job_ids:
+ self._check_model_job(model_job_id=model_job_id)
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ try:
+ self.schedule_model_job()
+ except Exception as e: # pylint: disable=broad-except
+ logging.exception('[ModelJobScheduler] schedule model job failed')
+ return RunnerStatus.FAILED, RunnerOutput(error_message=str(e))
+
+ return RunnerStatus.DONE, RunnerOutput()
+
+
+class ModelJobGroupSchedulerRunner(IRunnerV2):
+
+ @staticmethod
+ def _create_model_job_group_for_participants(model_job_group_id: int):
+ """create model job group for the participants"""
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(model_job_group_id)
+ ModelJobGroupController(session=session,
+ project_id=group.project_id).create_model_job_group_for_participants(
+ model_job_group_id=model_job_group_id)
+ session.commit()
+
+ def _schedule_model_job_group(self):
+        # for coordinator model job groups that are still PENDING and whose ticket status is APPROVED,
+        # create the model job group for the participants
+ with db.session_scope() as session:
+ model_job_group_ids: List[Tuple[int]] = session.query(
+ ModelJobGroup.id).filter_by(role=ModelJobRole.COORDINATOR,
+ status=GroupCreateStatus.PENDING,
+ ticket_status=TicketStatus.APPROVED).all()
+ for model_job_group_id, *_ in model_job_group_ids:
+ self._create_model_job_group_for_participants(model_job_group_id=model_job_group_id)
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ try:
+ self._schedule_model_job_group()
+ except Exception as e: # pylint: disable=broad-except
+ logging.exception('[ModelJobGroupScheduler] schedule model job group failed')
+ return RunnerStatus.FAILED, RunnerOutput(error_message=str(e))
+
+ return RunnerStatus.DONE, RunnerOutput()
+
+
+class ModelJobGroupLongPeriodScheduler(IRunnerV2):
+
+ @staticmethod
+ def _create_auto_update_model_job(model_job_group_id: int):
+ with db.session_scope() as session:
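+            # take the most recently created auto update model job in the group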
+ model_job = session.query(ModelJob).filter_by(auto_update=True, group_id=model_job_group_id).order_by(
+ ModelJob.created_at.desc()).limit(1).first()
+ if model_job is None:
+ group = session.query(ModelJobGroup).get(model_job_group_id)
+                logging.warning(f'There are no auto update model jobs in the model job group {group.name}')
+ return
+ if model_job.status not in [ModelJobStatus.SUCCEEDED]:
+ logging.warning(f'The status of the latest auto update model job {model_job.name} is not SUCCEEDED')
+ return
+ next_data_batch = BatchService(session).get_next_batch(model_job.data_batch)
+ if next_data_batch is None:
+ logging.warning(
+ f'There is no next data batch after the data batch with name: {model_job.data_batch.name}')
+ return
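+            # lock the group and bump its latest_version so the new auto update model job gets a unique version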
+ group: ModelJobGroup = ModelJobGroupService(session).lock_and_update_version(model_job_group_id)
+ version = group.latest_version
+ model_job_name = f'{group.name}-v{version}'
+ # Load the model of the previous model job for the new training job
+ model_name = model_job.model_name()
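+            # when the job does not record a model name, fall back to the Model row named after the job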
+ if model_name is None:
+ model = session.query(Model).filter_by(name=model_job.name).first()
+ if model is None:
+ raise Exception(f'model_job {model_job.name}\'s model is not found')
+ model_job.model_id = model.id
+ model_name = model.name
+ global_config = model_job.get_global_config()
+ if global_config is not None:
+ for config in global_config.global_config.values():
+ set_load_model_name(config, model_name)
+ ModelJobService(session).create_model_job(name=model_job_name,
+ uuid=resource_uuid(),
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=model_job.model_job_type,
+ algorithm_type=model_job.algorithm_type,
+ global_config=global_config,
+ group_id=model_job_group_id,
+ project_id=model_job.project_id,
+ data_batch_id=next_data_batch.id,
+ comment=model_job.comment,
+ version=version)
+ session.commit()
+
+ def _schedule_model_job_group(self):
+        # filter coordinator model job groups with auto update status ACTIVE,
+        # and create the next auto update model job for all participants
+ with db.session_scope() as session:
+ model_job_group_ids: List[Tuple[int]] = session.query(ModelJobGroup.id).filter_by(
+ role=ModelJobRole.COORDINATOR, auto_update_status=GroupAutoUpdateStatus.ACTIVE).all()
+ for model_job_group_id, *_ in model_job_group_ids:
+ try:
+ self._create_auto_update_model_job(model_job_group_id=model_job_group_id)
+ except Exception as e: # pylint: disable=broad-except
+                logging.exception(f'[ModelJobGroupLongPeriodScheduler] failed to create auto update model job for '
+ f'group id: {model_job_group_id}')
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ self._schedule_model_job_group()
+ return RunnerStatus.DONE, RunnerOutput()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/scheduler_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/scheduler_test.py
new file mode 100644
index 000000000..deca67711
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/scheduler_test.py
@@ -0,0 +1,511 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=protected-access
+import grpc
+import unittest
+from datetime import datetime
+from unittest.mock import patch, MagicMock, call, ANY
+from google.protobuf.struct_pb2 import Value
+
+from testing.common import NoWebServerTestCase
+from testing.rpc.client import FakeRpcError
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, DataBatch
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobRole, ModelJobStatus, ModelJobType, ModelJobGroup, \
+ GroupCreateStatus, GroupAutoUpdateStatus, Model, AuthStatus
+from fedlearner_webconsole.mmgr.scheduler import ModelJobSchedulerRunner, ModelJobGroupSchedulerRunner, \
+ ModelJobGroupLongPeriodScheduler
+from fedlearner_webconsole.composer.interface import RunnerStatus, RunnerContext
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, ModelJobConfig, AlgorithmProjectList
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, RunnerOutput
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+
+
+class ModelJobSchedulerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='test')
+ participant = Participant(id=1, name='party', domain_name='fl-peer.com')
+ project_participant = ProjectParticipant(project_id=1, participant_id=1)
+ _insert_or_update_templates(session)
+ g1 = ModelJobGroup(id=1, name='g1', uuid='group-uuid')
+ m1 = ModelJob(id=1,
+ name='j1',
+ role=ModelJobRole.COORDINATOR,
+ status=ModelJobStatus.PENDING,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ dataset_id=1,
+ project_id=1)
+ m1.set_global_config(ModelJobGlobalConfig(global_config={'test': ModelJobConfig(algorithm_uuid='uuid')}))
+ m2 = ModelJob(id=2,
+ name='j2',
+ model_job_type=ModelJobType.TRAINING,
+ role=ModelJobRole.PARTICIPANT,
+ status=ModelJobStatus.PENDING)
+ m3 = ModelJob(id=3,
+ name='j3',
+ role=ModelJobRole.COORDINATOR,
+ status=ModelJobStatus.CONFIGURED,
+ project_id=1,
+ uuid='uuid',
+ group_id=1,
+ version=3,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL)
+ m3.set_global_config(ModelJobGlobalConfig(global_config={'test': ModelJobConfig(algorithm_uuid='uuid')}))
+ m4 = ModelJob(id=4, name='j4', role=ModelJobRole.PARTICIPANT, status=ModelJobStatus.CONFIGURED)
+ m5 = ModelJob(id=5, name='j5', role=ModelJobRole.PARTICIPANT, status=ModelJobStatus.RUNNING)
+ m6 = ModelJob(id=6,
+ name='j6',
+ role=ModelJobRole.COORDINATOR,
+ status=ModelJobStatus.PENDING,
+ model_job_type=ModelJobType.EVALUATION,
+ auth_status=AuthStatus.AUTHORIZED)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ m6.set_participants_info(participants_info)
+ m7 = ModelJob(id=7,
+ name='j7',
+ role=ModelJobRole.PARTICIPANT,
+ status=ModelJobStatus.RUNNING,
+ model_job_type=ModelJobType.PREDICTION,
+ auth_status=AuthStatus.AUTHORIZED)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ m7.set_participants_info(participants_info)
+ m8 = ModelJob(id=8,
+ name='j8',
+ role=ModelJobRole.PARTICIPANT,
+ status=ModelJobStatus.PENDING,
+ model_job_type=ModelJobType.EVALUATION,
+ auth_status=AuthStatus.PENDING)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ m8.set_participants_info(participants_info)
+ m9 = ModelJob(id=9,
+ name='j9',
+ role=ModelJobRole.PARTICIPANT,
+ status=ModelJobStatus.PENDING,
+ model_job_type=ModelJobType.PREDICTION,
+ auth_status=AuthStatus.AUTHORIZED)
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ m9.set_participants_info(participants_info)
+ session.add_all([project, participant, project_participant, g1, m1, m2, m3, m4, m5, m6, m7, m8, m9])
+ session.commit()
+
+ @patch('fedlearner_webconsole.mmgr.model_job_configer.ModelJobConfiger.set_dataset')
+ @patch('fedlearner_webconsole.mmgr.service.ModelJobService._get_job')
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobConfiger')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_config_model_job(self, mock_system_info: MagicMock, mock_configer: MagicMock, mock_get_job: MagicMock,
+ mock_set_dataset: MagicMock):
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test')
+ mock_get_job.return_value = Job(id=1, name='job')
+ instance = mock_configer.return_value
+ instance.get_config.return_value = WorkflowDefinition(job_definitions=[JobDefinition(name='nn-model')])
+ scheduler = ModelJobSchedulerRunner()
+ scheduler._config_model_job(model_job_id=1)
+ mock_configer.assert_called_with(session=ANY,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ project_id=1)
+ instance.get_config.assert_called_with(dataset_id=1,
+ model_id=None,
+ model_job_config=ModelJobConfig(algorithm_uuid='uuid'))
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ self.assertEqual(model_job.status, ModelJobStatus.CONFIGURED)
+
+ def test_config_model_job_with_no_global_config(self):
+ with db.session_scope() as session:
+ workflow = Workflow(id=1, name='workflow', uuid='uuid', state=WorkflowState.RUNNING)
+ model_job = session.query(ModelJob).get(2)
+ model_job.workflow_uuid = 'uuid'
+ session.add(workflow)
+ session.commit()
+ scheduler = ModelJobSchedulerRunner()
+ scheduler._config_model_job(model_job_id=2)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(2)
+ self.assertEqual(model_job.status, ModelJobStatus.RUNNING)
+
+ def test_check_model_job(self):
+ ModelJobSchedulerRunner._check_model_job(model_job_id=3)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(3)
+ self.assertEqual(model_job.status, ModelJobStatus.CONFIGURED)
+ workflow = Workflow(id=1, state=WorkflowState.READY)
+ model_job.workflow_id = 1
+ session.add(workflow)
+ session.commit()
+ ModelJobSchedulerRunner._check_model_job(model_job_id=3)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(3)
+ self.assertEqual(model_job.status, ModelJobStatus.CONFIGURED)
+ workflow = session.query(Workflow).get(1)
+ workflow.state = WorkflowState.RUNNING
+ session.commit()
+ ModelJobSchedulerRunner._check_model_job(model_job_id=3)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(3)
+ self.assertEqual(model_job.status, ModelJobStatus.RUNNING)
+ workflow = session.query(Workflow).get(1)
+ workflow.state = WorkflowState.FAILED
+ session.commit()
+ ModelJobSchedulerRunner._check_model_job(model_job_id=3)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(3)
+ self.assertEqual(model_job.status, ModelJobStatus.FAILED)
+
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobSchedulerRunner._check_model_job')
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobSchedulerRunner._config_model_job')
+ def test_schedule_model_job(self, mock_config: MagicMock, mock_check_job: MagicMock):
+ scheduler = ModelJobSchedulerRunner()
+ scheduler.schedule_model_job()
+        mock_config.assert_has_calls(
+            calls=[call(model_job_id=1),
+                   call(model_job_id=2),
+                   call(model_job_id=6),
+                   call(model_job_id=9)])
+ mock_check_job.assert_has_calls(calls=[call(model_job_id=3), call(model_job_id=4), call(model_job_id=5)])
+
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobSchedulerRunner.schedule_model_job')
+ def test_run(self, mock_schedule_model_job: MagicMock):
+ scheduler = ModelJobSchedulerRunner()
+ runner_input = RunnerInput()
+ runner_context = RunnerContext(index=0, input=runner_input)
+ runner_status, runner_output = scheduler.run(runner_context)
+ mock_schedule_model_job.assert_called()
+ mock_schedule_model_job.reset_mock()
+ self.assertEqual(runner_output, RunnerOutput())
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+
+ def side_effect():
+ raise Exception('haha')
+
+ mock_schedule_model_job.side_effect = side_effect
+ scheduler = ModelJobSchedulerRunner()
+ runner_status, runner_output = scheduler.run(runner_context)
+ mock_schedule_model_job.assert_called()
+ self.assertEqual(runner_output, RunnerOutput(error_message='haha'))
+ self.assertEqual(runner_status, RunnerStatus.FAILED)
+
+
+class ModelJobGroupSchedulerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='peer', domain_name='fl-peer.com')
+ project_participant = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ dataset = Dataset(id=1, uuid='dataset_uuid', name='dataset')
+ algorithm_project_list = AlgorithmProjectList()
+ algorithm_project_list.algorithm_projects['test'] = 'algorithm-project-uuid1'
+ algorithm_project_list.algorithm_projects['peer'] = 'algorithm-project-uuid2'
+ group1 = ModelJobGroup(id=1,
+ project_id=1,
+ uuid='uuid1',
+ name='group1',
+ status=GroupCreateStatus.PENDING,
+ ticket_status=TicketStatus.PENDING)
+ group2 = ModelJobGroup(id=2,
+ project_id=1,
+ uuid='uuid2',
+ name='group2',
+ status=GroupCreateStatus.PENDING,
+ ticket_status=TicketStatus.APPROVED,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ dataset_id=1,
+ role=ModelJobRole.COORDINATOR)
+ group2.set_algorithm_project_uuid_list(algorithm_project_list)
+ group3 = ModelJobGroup(id=3,
+ project_id=1,
+ uuid='uuid3',
+ name='group3',
+ status=GroupCreateStatus.PENDING,
+ ticket_status=TicketStatus.APPROVED,
+ algorithm_type=AlgorithmType.NN_HORIZONTAL,
+ dataset_id=1,
+ role=ModelJobRole.COORDINATOR)
+ group4 = ModelJobGroup(id=4,
+ project_id=1,
+ uuid='uuid4',
+ name='group',
+ status=GroupCreateStatus.PENDING,
+ ticket_status=TicketStatus.APPROVED,
+ role=ModelJobRole.PARTICIPANT)
+ group5 = ModelJobGroup(id=5,
+ project_id=1,
+ uuid='uuid5',
+ name='group5',
+ status=GroupCreateStatus.SUCCEEDED,
+ ticket_status=TicketStatus.APPROVED)
+ group6 = ModelJobGroup(id=6,
+ project_id=1,
+ uuid='uuid6',
+ name='group6',
+ status=GroupCreateStatus.FAILED,
+ ticket_status=TicketStatus.APPROVED)
+ session.add_all(
+ [project, participant, project_participant, dataset, group1, group2, group3, group4, group5, group6])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_model_job_group')
+ def test_create_model_job_group_for_participants(self, mock_client: MagicMock):
+ scheduler = ModelJobGroupSchedulerRunner()
+ scheduler._create_model_job_group_for_participants(model_job_group_id=2)
+ algorithm_project_list = AlgorithmProjectList()
+ algorithm_project_list.algorithm_projects['test'] = 'algorithm-project-uuid1'
+ algorithm_project_list.algorithm_projects['peer'] = 'algorithm-project-uuid2'
+ mock_client.assert_called_with(name='group2',
+ uuid='uuid2',
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ dataset_uuid='dataset_uuid',
+ algorithm_project_list=algorithm_project_list)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(2)
+ self.assertEqual(group.status, GroupCreateStatus.SUCCEEDED)
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT, 'dataset with uuid is not found')
+ scheduler._create_model_job_group_for_participants(model_job_group_id=3)
+ mock_client.assert_called()
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(3)
+ self.assertEqual(group.status, GroupCreateStatus.FAILED)
+
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobGroupSchedulerRunner._create_model_job_group_for_participants')
+ def test_schedule_model_job_group(self, mock_create_model_job_group):
+ scheduler = ModelJobGroupSchedulerRunner()
+ scheduler._schedule_model_job_group()
+ mock_create_model_job_group.assert_has_calls(calls=[call(model_job_group_id=2), call(model_job_group_id=3)])
+
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobGroupSchedulerRunner._schedule_model_job_group')
+ def test_run(self, mock_schedule_model_job_group: MagicMock):
+ scheduler = ModelJobGroupSchedulerRunner()
+ runner_input = RunnerInput()
+ runner_context = RunnerContext(index=0, input=runner_input)
+ runner_status, runner_output = scheduler.run(runner_context)
+ mock_schedule_model_job_group.assert_called()
+ mock_schedule_model_job_group.reset_mock()
+ self.assertEqual(runner_output, RunnerOutput())
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+
+ def side_effect():
+ raise Exception('haha')
+
+ mock_schedule_model_job_group.side_effect = side_effect
+ scheduler = ModelJobGroupSchedulerRunner()
+ runner_status, runner_output = scheduler.run(runner_context)
+ mock_schedule_model_job_group.assert_called()
+ self.assertEqual(runner_output, RunnerOutput(error_message='haha'))
+ self.assertEqual(runner_status, RunnerStatus.FAILED)
+
+
+class ModelJobGroupLongPeriodSchedulerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1,
+ name='20220101-08',
+ dataset_id=1,
+ event_time=datetime(year=2000, month=1, day=1, hour=8),
+ latest_parent_dataset_job_stage_id=1)
+ group1 = ModelJobGroup(id=1,
+ name='group1',
+ role=ModelJobRole.COORDINATOR,
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE)
+ group2 = ModelJobGroup(id=2,
+ name='group2',
+ role=ModelJobRole.COORDINATOR,
+ auto_update_status=GroupAutoUpdateStatus.INITIAL)
+ group3 = ModelJobGroup(id=3,
+ name='group3',
+ role=ModelJobRole.PARTICIPANT,
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE)
+ group4 = ModelJobGroup(id=4,
+ name='group4',
+ role=ModelJobRole.COORDINATOR,
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE)
+ group5 = ModelJobGroup(id=5,
+ name='group5',
+ role=ModelJobRole.COORDINATOR,
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE,
+ latest_version=4)
+ model_job1 = ModelJob(id=1,
+ group_id=4,
+ auto_update=True,
+ created_at=datetime(2022, 12, 16, 1, 0, 0),
+ status=ModelJobStatus.SUCCEEDED,
+ data_batch_id=1)
+ model_job2 = ModelJob(id=2,
+ group_id=4,
+ auto_update=False,
+ created_at=datetime(2022, 12, 16, 2, 0, 0),
+ status=ModelJobStatus.SUCCEEDED)
+ model_job3 = ModelJob(id=3,
+ group_id=4,
+ auto_update=True,
+ created_at=datetime(2022, 12, 16, 3, 0, 0),
+ status=ModelJobStatus.RUNNING,
+ data_batch_id=1)
+ model_job4 = ModelJob(id=4,
+ group_id=5,
+ auto_update=True,
+ created_at=datetime(2022, 12, 16, 1, 0, 0),
+ status=ModelJobStatus.SUCCEEDED,
+ data_batch_id=1)
+ model_job5 = ModelJob(id=5,
+ group_id=5,
+ auto_update=False,
+ created_at=datetime(2022, 12, 16, 2, 0, 0),
+ status=ModelJobStatus.FAILED)
+ global_config = ModelJobGlobalConfig(
+ dataset_uuid='uuid',
+ global_config={
+ 'test1': ModelJobConfig(algorithm_uuid='uuid1', variables=[Variable(name='load_model_name')]),
+ 'test2': ModelJobConfig(algorithm_uuid='uuid2', variables=[Variable(name='load_model_name')])
+ })
+ model_job6 = ModelJob(id=6,
+ name='model-job6',
+ group_id=5,
+ auto_update=True,
+ created_at=datetime(2022, 12, 16, 3, 0, 0),
+ status=ModelJobStatus.SUCCEEDED,
+ data_batch_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ project_id=1,
+ comment='comment',
+ version=3)
+ model_job6.set_global_config(global_config)
+ session.add_all([
+ group1, group2, group3, group4, group5, model_job1, model_job2, model_job3, model_job4, model_job5,
+ model_job6, data_batch
+ ])
+ session.commit()
+
+ @patch('fedlearner_webconsole.mmgr.scheduler.resource_uuid')
+ @patch('fedlearner_webconsole.mmgr.service.ModelJobService.create_model_job')
+ @patch('fedlearner_webconsole.dataset.services.BatchService.get_next_batch')
+ def test_create_auto_update_model_job(self, mock_get_next_batch: MagicMock, mock_create_model_job: MagicMock,
+ mock_resource_uuid: MagicMock):
+ scheduler = ModelJobGroupLongPeriodScheduler()
+        # fails because the latest auto update model job in the group is None
+ scheduler._create_auto_update_model_job(model_job_group_id=1)
+ mock_create_model_job.assert_not_called()
+        # fails because the latest auto update model job's status is not SUCCEEDED
+ scheduler._create_auto_update_model_job(model_job_group_id=4)
+ mock_create_model_job.assert_not_called()
+        # fails because the next data batch is None
+ mock_get_next_batch.return_value = None
+ scheduler._create_auto_update_model_job(model_job_group_id=5)
+ mock_create_model_job.assert_not_called()
+        # creating the auto update model job raises because model_name is None and no matching Model exists yet
+ mock_get_next_batch.return_value = DataBatch(id=2)
+ mock_resource_uuid.return_value = 'uuid'
+ with self.assertRaises(Exception):
+ scheduler._create_auto_update_model_job(model_job_group_id=5)
+ with db.session_scope() as session:
+ model = Model(id=1, name='model-job6', uuid='uuid')
+ session.add(model)
+ session.commit()
+ scheduler._create_auto_update_model_job(model_job_group_id=5)
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='model-job6').first()
+ self.assertEqual(model_job.model_id, 1)
+ global_config = ModelJobGlobalConfig(
+ dataset_uuid='uuid',
+ global_config={
+ 'test1':
+ ModelJobConfig(algorithm_uuid='uuid1',
+ variables=[
+ Variable(name='load_model_name',
+ value='model-job6',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='model-job6'))
+ ]),
+ 'test2':
+ ModelJobConfig(algorithm_uuid='uuid2',
+ variables=[
+ Variable(name='load_model_name',
+ value='model-job6',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='model-job6'))
+ ])
+ })
+ mock_create_model_job.assert_called_with(name='group5-v5',
+ uuid='uuid',
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ group_id=5,
+ project_id=1,
+ data_batch_id=2,
+ comment='comment',
+ version=5)
+
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobGroupLongPeriodScheduler._create_auto_update_model_job')
+ def test_schedule_model_job_group(self, mock_create_auto_update_model_job: MagicMock):
+ scheduler = ModelJobGroupLongPeriodScheduler()
+ scheduler._schedule_model_job_group()
+ mock_create_auto_update_model_job.assert_has_calls(
+ calls=[call(model_job_group_id=1),
+ call(model_job_group_id=4),
+ call(model_job_group_id=5)])
+
+ @patch('fedlearner_webconsole.mmgr.scheduler.ModelJobGroupLongPeriodScheduler._schedule_model_job_group')
+ def test_run(self, mock_schedule_model_job_group: MagicMock):
+ scheduler = ModelJobGroupLongPeriodScheduler()
+ runner_input = RunnerInput()
+ runner_context = RunnerContext(index=0, input=runner_input)
+ runner_status, runner_output = scheduler.run(runner_context)
+ mock_schedule_model_job_group.assert_called()
+ mock_schedule_model_job_group.reset_mock()
+ self.assertEqual(runner_output, RunnerOutput())
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/service.py b/web_console_v2/api/fedlearner_webconsole/mmgr/service.py
index 2f811c169..55dcc175d 100644
--- a/web_console_v2/api/fedlearner_webconsole/mmgr/service.py
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/service.py
@@ -1,161 +1,474 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
-# coding: utf-8
-
-import os
-import json
import logging
-from fedlearner_webconsole.db import make_session_context
+from typing import Optional
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.exceptions import InvalidArgumentException, NotFoundException
+from fedlearner_webconsole.mmgr.metrics.metrics_inquirer import tree_metrics_inquirer, nn_metrics_inquirer
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, ModelTrainingCronJobInput
+from fedlearner_webconsole.proto.metrics_pb2 import ModelJobMetrics
from fedlearner_webconsole.job.metrics import JobMetricsBuilder
-from fedlearner_webconsole.job.models import Job, JobType, JobState, JobDefinition
-from fedlearner_webconsole.job.yaml_formatter import generate_job_run_yaml
-from fedlearner_webconsole.mmgr.models import Model, ModelType, ModelState
-from fedlearner_webconsole.utils.k8s_cache import Event, EventType, ObjectType
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.algorithm.models import AlgorithmType, Algorithm, AlgorithmProject, Source
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelType, ModelJobGroup, ModelJobType, ModelJobRole, \
+ ModelJobStatus, AuthStatus
+from fedlearner_webconsole.mmgr.utils import deleted_name
+from fedlearner_webconsole.mmgr.model_job_configer import get_sys_template_id, ModelJobConfiger
+from fedlearner_webconsole.mmgr.utils import get_job_path, build_workflow_name, \
+ is_model_job
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DataBatch
+from fedlearner_webconsole.composer.composer_service import CronJobService
+from fedlearner_webconsole.workflow.workflow_controller import create_ready_workflow
+from fedlearner_webconsole.workflow.service import CreateNewWorkflowParams, WorkflowService
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.utils.const import SYSTEM_WORKFLOW_CREATOR_USERNAME
+from fedlearner_webconsole.utils.base_model import auth_model
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, ModelJobConfig, AlgorithmProjectList
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.rpc.v2.job_service_client import JobServiceClient
-class ModelService:
+def get_project(project_id: int, session: Session) -> Project:
+ project = session.query(Project).get(project_id)
+ if project is None:
+ raise NotFoundException(f'project {project_id} is not found')
+ return project
+
+
+def get_dataset(dataset_id: int, session: Session) -> Dataset:
+ dataset = session.query(Dataset).get(dataset_id)
+ if dataset is None:
+ raise InvalidArgumentException(f'dataset {dataset_id} is not found')
+ return dataset
+
+
+def get_model_job(project_id: int, model_job_id: int, session: Session) -> ModelJob:
+ model_job = session.query(ModelJob).filter_by(id=model_job_id, project_id=project_id).first()
+ if model_job is None:
+ raise NotFoundException(f'[Model]model job {model_job_id} is not found')
+ return model_job
+
+
+def get_model_job_group(project_id: int, group_id: int, session: Session) -> ModelJobGroup:
+ query = session.query(ModelJobGroup).filter_by(id=group_id, project_id=project_id)
+ group = query.first()
+ if group is None:
+ raise NotFoundException(f'[Model]model group {group_id} is not found')
+ return group
+
+
+def check_model_job_group(project_id: int, group_id: int, session: Session):
+    group = session.query(ModelJobGroup).filter_by(id=group_id, project_id=project_id).first()
+ if group is None:
+ raise NotFoundException(f'[Model]model group {group_id} is not found')
+
+
+def get_participant(participant_id: int, project: Project) -> Participant:
+ for participant in project.participants:
+ if participant.id == participant_id:
+ return participant
+ raise NotFoundException(f'participant {participant_id} is not found')
+
+
+def get_model(project_id: int, model_id: int, session: Session) -> Model:
+ model = session.query(Model).filter_by(project_id=project_id, id=model_id).first()
+ if model is None:
+ raise NotFoundException(f'[Model]model {model_id} is not found')
+ return model
+
+
+def get_algorithm(project_id: int, algorithm_id: int, session: Session) -> Optional[Algorithm]:
+ query = session.query(Algorithm)
+ if project_id:
+ # query under project and preset algorithms with project_id as null
+ query = query.filter((Algorithm.project_id == project_id) | (Algorithm.source == Source.PRESET))
+ algo = query.filter_by(id=algorithm_id).first()
+ return algo
+
+
+class ModelJobService:
def __init__(self, session):
self._session = session
- job_type_map = {
- JobType.NN_MODEL_TRANINING: ModelType.NN_MODEL.value,
- JobType.NN_MODEL_EVALUATION: ModelType.NN_EVALUATION.value,
- JobType.TREE_MODEL_TRAINING: ModelType.TREE_MODEL.value,
- JobType.TREE_MODEL_EVALUATION: ModelType.TREE_EVALUATION.value
- }
-
- job_state_map = {
- JobState.STARTED: ModelState.RUNNING.value,
- JobState.COMPLETED: ModelState.SUCCEEDED.value,
- JobState.FAILED: ModelState.FAILED.value,
- JobState.STOPPED: ModelState.PAUSED.value,
- JobState.WAITING: ModelState.WAITING.value
- }
+ @staticmethod
+ def query_metrics(model_job: ModelJob, job: Optional[Job] = None) -> ModelJobMetrics:
+ job = job or model_job.job
+ builder = JobMetricsBuilder(job)
+ if model_job.algorithm_type == AlgorithmType.TREE_VERTICAL:
+ model_job_metrics = tree_metrics_inquirer.query(job, need_feature_importance=True)
+ if len(model_job_metrics.train) == 0 and len(model_job_metrics.eval) == 0:
+ # legacy metrics support
+ logging.info(f'use legacy tree model metrics, job name = {job.name}')
+ return builder.query_tree_metrics(need_feature_importance=True)
+ return model_job_metrics
+ if model_job.algorithm_type == AlgorithmType.NN_VERTICAL:
+ model_job_metrics = nn_metrics_inquirer.query(job)
+ if len(model_job_metrics.train) == 0 and len(model_job_metrics.eval) == 0:
+ # legacy metrics support
+ logging.info(f'use legacy nn model metrics, job name = {job.name}')
+ return builder.query_nn_metrics()
+ return model_job_metrics
+ if model_job.algorithm_type == AlgorithmType.NN_HORIZONTAL:
+ return builder.query_nn_metrics()
+ raise ValueError(f'invalid algorithm type {model_job.algorithm_type}')
@staticmethod
- def is_model_related_job(job):
- job_type = job.job_type
- if isinstance(job_type, int):
- job_type = JobType(job.job_type)
- return job_type in [
- JobType.NN_MODEL_TRANINING, JobType.NN_MODEL_EVALUATION,
- JobType.TREE_MODEL_TRAINING, JobType.TREE_MODEL_EVALUATION
- ]
-
- def k8s_watcher_hook(self, event: Event):
- logging.info('[ModelService][k8s_watcher_hook] %s %s: %s', event.obj_type, event.event_type, event.flapp_name)
- if event.obj_type == ObjectType.FLAPP and event.event_type in [
- EventType.MODIFIED, EventType.DELETED
- ]:
- job = self._session.query(Job).filter_by(
- name=event.flapp_name).one_or_none()
- if not job:
- return logging.warning('[ModelService][k8s_watcher_hook] job not found: %s', event.flapp_name)
- if self.is_model_related_job(job):
- self.on_job_update(job)
-
- def workflow_hook(self, job: Job):
- if self.is_model_related_job(job):
- self.create(job)
-
- def plot_metrics(self, model, job=None):
- try:
- return JobMetricsBuilder(job or model.job).plot_metrics()
- except Exception as e:
- return repr(e)
-
- def is_model_quiescence(self, state):
- return state in [
- ModelState.SUCCEEDED.value, ModelState.FAILED.value,
- ModelState.PAUSED.value
- ]
-
- def on_job_update(self, job: Job):
- logging.info('[ModelService][on_job_update] job name: %s', job.name)
- model = self._session.query(Model).filter_by(job_name=job.name).one()
- # see also `fedlearner_webconsole.job.models.Job.stop`
- if job.state in self.job_state_map:
- state = self.job_state_map[job.state]
+ def _get_job(workflow: Workflow) -> Optional[Job]:
+ for job in workflow.owned_jobs:
+ if is_model_job(job.job_type):
+ return job
+ return None
+
+ def _create_model_job_for_participants(self, model_job: ModelJob):
+ project = self._session.query(Project).get(model_job.project_id)
+ group = self._session.query(ModelJobGroup).get(model_job.group_id)
+ global_config = model_job.get_global_config()
+ for participant in project.participants:
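+            # ask each participant to create the same model job via the job service client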
+ client = JobServiceClient.from_project_and_participant(participant.domain_name, project.name)
+ try:
+ client.create_model_job(name=model_job.name,
+ uuid=model_job.uuid,
+ group_uuid=group.uuid,
+ model_job_type=model_job.model_job_type,
+ algorithm_type=model_job.algorithm_type,
+ global_config=global_config,
+ version=model_job.version)
+ logging.info(f'[ModelJob] model job {model_job.id} is ready')
+ except Exception as e: # pylint: disable=broad-except
+ logging.exception('[ModelJob] creating model job for participants failed')
+ raise Exception(f'[ModelJob] creating model job for participants failed with detail {str(e)}') from e
+
+ # TODO(hangweiqiang): ensure version is unique for training job under model job group
+ def create_model_job(self,
+ name: str,
+ uuid: str,
+ project_id: int,
+ role: ModelJobRole,
+ model_job_type: ModelJobType,
+ algorithm_type: AlgorithmType,
+ global_config: ModelJobGlobalConfig,
+ group_id: Optional[int] = None,
+ coordinator_id: Optional[int] = 0,
+ data_batch_id: Optional[int] = None,
+ version: Optional[int] = None,
+ comment: Optional[str] = None) -> ModelJob:
+ model_job = ModelJob(name=name,
+ uuid=uuid,
+ group_id=group_id,
+ project_id=project_id,
+ role=role,
+ model_job_type=model_job_type,
+ algorithm_type=algorithm_type,
+ coordinator_id=coordinator_id,
+ version=version,
+ comment=comment)
+ assert global_config.dataset_uuid != '', 'dataset uuid must not be empty'
+ dataset = self._session.query(Dataset).filter_by(uuid=global_config.dataset_uuid).first()
+ assert dataset is not None, f'dataset with uuid {global_config.dataset_uuid} is not found'
+ model_job.dataset_id = dataset.id
+ if data_batch_id is not None: # for auto update jobs
+ assert algorithm_type in [AlgorithmType.NN_VERTICAL],\
+ 'auto update is only supported for nn vertical train'
+ dataset_job: DatasetJob = self._session.query(DatasetJob).filter_by(output_dataset_id=dataset.id).first()
+ assert dataset_job.kind != DatasetJobKind.RSA_PSI_DATA_JOIN,\
+ 'auto update is not supported for RSA-PSI dataset'
+ data_batch: DataBatch = self._session.query(DataBatch).get(data_batch_id)
+ assert data_batch is not None, f'data batch {data_batch_id} is not found'
+ assert data_batch.is_available(), f'data batch {data_batch_id} is not available'
+ assert data_batch.latest_parent_dataset_job_stage is not None, 'dataset job stage with id is not found'
+ model_job.data_batch_id = data_batch_id
+ model_job.auto_update = True
+ if role in [ModelJobRole.COORDINATOR]:
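+                # record the data batch's dataset job stage uuid in the global config shared with participants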
+ global_config.dataset_job_stage_uuid = data_batch.latest_parent_dataset_job_stage.uuid
+ model_job.set_global_config(global_config)
+ self.initialize_auth_status(model_job)
+ # when model job type is eval or predict
+ if global_config.model_uuid != '':
+ model = self._session.query(Model).filter_by(uuid=global_config.model_uuid).first()
+ assert model is not None, f'model with uuid {global_config.model_uuid} is not found'
+ model_job.model_id = model.id
+ # add model's group id to model_job when eval and predict
+ model_job.group_id = model.group_id
+ pure_domain_name = SettingService(session=self._session).get_system_info().pure_domain_name
+ model_job_config: ModelJobConfig = global_config.global_config.get(pure_domain_name)
+ assert model_job_config is not None, f'model_job_config of self domain name {pure_domain_name} must not be None'
+ if model_job_config.algorithm_uuid != '':
+ algorithm = self._session.query(Algorithm).filter_by(uuid=model_job_config.algorithm_uuid).first()
+            # algorithm is None if algorithm_uuid points to a published algorithm at the peer platform
+ if algorithm is not None:
+ model_job.algorithm_id = algorithm.id
+ # no need create model job at participants when eval or predict horizontal model
+ if model_job_type in [ModelJobType.TRAINING] and role in [ModelJobRole.COORDINATOR]:
+ self._create_model_job_for_participants(model_job)
+ if model_job_type in [ModelJobType.EVALUATION, ModelJobType.PREDICTION] and algorithm_type not in [
+ AlgorithmType.NN_HORIZONTAL
+ ] and role in [ModelJobRole.COORDINATOR]:
+ self._create_model_job_for_participants(model_job)
+ self._session.add(model_job)
+ return model_job
+
+ def config_model_job(self,
+ model_job: ModelJob,
+ config: WorkflowDefinition,
+ create_workflow: bool,
+ need_to_create_ready_workflow: Optional[bool] = False,
+ workflow_uuid: Optional[str] = None):
+ workflow_name = build_workflow_name(model_job_type=model_job.model_job_type.name,
+ algorithm_type=model_job.algorithm_type.name,
+ model_job_name=model_job.name)
+ template_id = get_sys_template_id(self._session, model_job.algorithm_type, model_job.model_job_type)
+ if template_id is None:
+ raise ValueError(f'workflow template for {model_job.algorithm_type.name} not found')
+ workflow_comment = f'created by model_job {model_job.name}'
+ configer = ModelJobConfiger(session=self._session,
+ model_job_type=model_job.model_job_type,
+ algorithm_type=model_job.algorithm_type,
+ project_id=model_job.project_id)
+ configer.set_dataset(config=config, dataset_id=model_job.dataset_id, data_batch_id=model_job.data_batch_id)
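+        # the workflow is either created directly in the READY state, created from scratch,
+        # or an existing workflow (looked up by its uuid) is configured in place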
+ if need_to_create_ready_workflow:
+ workflow = create_ready_workflow(
+ session=self._session,
+ name=workflow_name,
+ config=config,
+ project_id=model_job.project_id,
+ template_id=template_id,
+ uuid=workflow_uuid,
+ comment=workflow_comment,
+ )
+ elif create_workflow:
+ params = CreateNewWorkflowParams(project_id=model_job.project_id, template_id=template_id)
+ workflow = WorkflowService(self._session).create_workflow(name=workflow_name,
+ config=config,
+ params=params,
+ comment=workflow_comment,
+ uuid=workflow_uuid,
+ creator_username=SYSTEM_WORKFLOW_CREATOR_USERNAME)
else:
- return logging.warning(
- '[ModelService][on_job_update] job state is %s', job.state)
- if model.state != ModelState.RUNNING.value and state == ModelState.RUNNING.value:
- logging.info(
- '[ModelService][on_job_update] updating model(%d).version from %s to %s',
- model.id, model.version, model.version + 1)
- model.version += 1
- logging.info(
- '[ModelService][on_job_update] updating model(%d).state from %s to %s',
- model.id, model.state, state)
- if self.is_model_quiescence(state):
- model.metrics = json.dumps(self.plot_metrics(model, job))
- model.state = state
- self._session.add(model)
+ workflow = self._session.query(Workflow).filter_by(uuid=model_job.workflow_uuid).first()
+ if workflow is None:
+ raise ValueError(f'workflow with uuid {model_job.workflow_uuid} not found')
+ workflow = WorkflowService(self._session).config_workflow(workflow=workflow,
+ template_id=template_id,
+ config=config,
+ comment=workflow_comment,
+ creator_username=SYSTEM_WORKFLOW_CREATOR_USERNAME)
+ self._session.flush()
+ model_job.workflow_id = workflow.id
+ model_job.workflow_uuid = workflow.uuid
+ job = self._get_job(workflow)
+ assert job is not None, 'model job not found in workflow'
+ model_job.job_name = job.name
+ model_job.job_id = job.id
+ model_job.status = ModelJobStatus.CONFIGURED
+ self._session.flush()
+
+ def update_model_job_status(self, model_job: ModelJob):
+ workflow = self._session.query(Workflow).filter_by(uuid=model_job.workflow_uuid).first()
+ if workflow:
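+            # mirror the workflow state onto the model job status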
+ if workflow.state in [WorkflowState.RUNNING]:
+ model_job.status = ModelJobStatus.RUNNING
+ if workflow.state in [WorkflowState.STOPPED]:
+ model_job.status = ModelJobStatus.STOPPED
+ if workflow.state in [WorkflowState.COMPLETED]:
+ model_job.status = ModelJobStatus.SUCCEEDED
+ if workflow.state in [WorkflowState.FAILED]:
+ model_job.status = ModelJobStatus.FAILED
+
+ def initialize_auth_status(self, model_job: ModelJob):
+ pure_domain_name = SettingService(self._session).get_system_info().pure_domain_name
+ participants = ParticipantService(self._session).get_participants_by_project(model_job.project_id)
+        # 1. all participants are authorized by default when the model job type is training
+        # 2. all participants are authorized by default when the algorithm type is nn_horizontal and the model job
+        #    type is evaluation or prediction
+        # 3. only the coordinator is authorized when the algorithm type is not nn_horizontal and the model job type
+        #    is evaluation or prediction
+ participants_info = ParticipantsInfo(participants_map={
+ p.pure_domain_name(): ParticipantInfo(auth_status=AuthStatus.PENDING.name) for p in participants
+ })
+ participants_info.participants_map[pure_domain_name].auth_status = AuthStatus.PENDING.name
+ if model_job.model_job_type in [ModelJobType.TRAINING
+ ] or model_job.algorithm_type in [AlgorithmType.NN_HORIZONTAL]:
+ participants_info = ParticipantsInfo(participants_map={
+ p.pure_domain_name(): ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name) for p in participants
+ })
+ participants_info.participants_map[pure_domain_name].auth_status = AuthStatus.AUTHORIZED.name
+ model_job.auth_status = AuthStatus.AUTHORIZED
+ elif model_job.role in [ModelJobRole.COORDINATOR]:
+ participants_info.participants_map[pure_domain_name].auth_status = AuthStatus.AUTHORIZED.name
+ model_job.auth_status = AuthStatus.AUTHORIZED
+ model_job.set_participants_info(participants_info)
+
+ @staticmethod
+ def update_model_job_auth_status(model_job: ModelJob, auth_status: AuthStatus):
+ model_job.auth_status = auth_status
+ participants_info = model_job.get_participants_info()
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ participants_info.participants_map[pure_domain_name].auth_status = auth_status.name
+ model_job.set_participants_info(participants_info)
+
+ def delete(self, job_id: int):
+ model_job: ModelJob = self._session.query(ModelJob).get(job_id)
+ model_job.deleted_at = now()
+ model_job.name = deleted_name(model_job.name)
+ if model_job.output_model is not None:
+ ModelService(self._session).delete(model_job.output_model.id)
- def create(self, job: Job, parent_job_name=None, group_id=0):
- logging.info('[ModelService][create] create model %s', job.name)
- model = Model()
- model.name = job.name # TODO allow rename by end-user
- model.type = self.job_type_map[job.job_type]
- model.state = ModelState.COMMITTING.value
- model.job_name = job.name
- if parent_job_name:
- parent = self._session.query(Model).filter_by(
- job_name=parent_job_name).one_or_none()
- if not parent:
- return parent
- model.version = parent.version
- model.parent_id = parent.id
- model.params = json.dumps({})
- model.group_id = group_id
- model.state = ModelState.COMMITTED.value
- self._session.add(model)
- self._session.commit()
- return model
- # `detail_level` is a comma separated string list
- # contains `metrics` if `plot_metrics` result is
- def query(self, model_id, detail_level=''):
- model = self._session.query(Model).filter_by(id=model_id).one_or_none()
- if not model:
- return model
- detail_level = detail_level.split(',')
- model_json = model.to_dict()
- model_json['detail_level'] = detail_level
- if 'metrics' in detail_level:
- if self.is_model_quiescence(model) and model.metrics:
- model_json['metrics'] = json.loads(model.metrics)
- else: model_json['metrics'] = self.plot_metrics(model)
- return model_json
-
- def drop(self, model_id):
- model = self._session.query(Model).filter_by(id=model_id).one_or_none()
- if not model:
- return model
- if model.state not in [
- ModelState.SUCCEEDED.value, ModelState.FAILED.value
- ]: # FIXME atomicity
- raise Exception(
- f'cannot delete model when model.state is {model.state}')
- # model.state = ModelState.DROPPING.value
- # TODO remove model files from NFS et al.
- model.state = ModelState.DROPPED.value
+class ModelService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def create_model_from_model_job(self, model_job: ModelJob):
+ name = f'{model_job.group.name}-v{model_job.version}'
+ model_type = ModelType.NN_MODEL
+ if model_job.algorithm_type == AlgorithmType.TREE_VERTICAL:
+ model_type = ModelType.TREE_MODEL
+ model = Model(name=name,
+ uuid=model_job.uuid,
+ version=model_job.version,
+ model_type=model_type,
+ algorithm_type=model_job.algorithm_type,
+ project_id=model_job.project_id,
+ job_id=model_job.job_id,
+ group_id=model_job.group_id,
+ model_job_id=model_job.id)
+ storage_root_dir = model_job.project.get_storage_root_path(None)
+ if storage_root_dir is None:
+ logging.warning(f'[ModelService] storage root of project {model_job.project.name} is None')
+ raise RuntimeError(f'storage root of project {model_job.project.name} is None')
+ model.model_path = get_job_path(storage_root_dir, model_job.job.name)
self._session.add(model)
- self._session.commit()
+
+ def delete(self, model_id: int):
+ model: Model = self._session.query(Model).get(model_id)
+ model.deleted_at = now()
+ model.name = deleted_name(model.name)
+
+
+class ModelJobGroupService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def launch_model_job(self, group: ModelJobGroup, name: str, uuid: str, version: int) -> ModelJob:
+ model_job = ModelJob(
+ name=name,
+ uuid=uuid,
+ group_id=group.id,
+ project_id=group.project_id,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=group.algorithm_type,
+ algorithm_id=group.algorithm_id,
+ dataset_id=group.dataset_id,
+ version=version,
+ )
+ self._session.add(model_job)
+ self._session.flush()
+ ModelJobService(self._session).config_model_job(model_job,
+ group.get_config(),
+ create_workflow=False,
+ need_to_create_ready_workflow=True,
+ workflow_uuid=model_job.uuid)
+ group.latest_version = version
+ self._session.flush()
+ return model_job
+
+ def delete(self, group_id: int):
+ group: ModelJobGroup = self._session.query(ModelJobGroup).get(group_id)
+ group.name = deleted_name(group.name)
+ group.deleted_at = now()
+ job_service = ModelJobService(self._session)
+ for job in group.model_jobs:
+ job_service.delete(job.id)
+
+ def lock_and_update_version(self, group_id: int) -> ModelJobGroup:
+ group: ModelJobGroup = self._session.query(ModelJobGroup).populate_existing().with_for_update().get(group_id)
+        # Use an exclusive lock to ensure the version is unique and monotonically increasing.
+        # Since 2PC runs in its own db transaction and the group's latest_version must be updated in this service,
+        # the latest_version is updated and the lock released here first to avoid lock conflicts,
+        # and the new version is then passed to the 2PC transaction.
+ group.latest_version = group.latest_version + 1
+ return group
+
+ def update_cronjob_config(self, group: ModelJobGroup, cron_config: str):
+ """Update model training cron job config
+
+ Args:
+ group: group for updating cron config
+            cron_config: cancel the cron job if cron_config is an empty string;
+                create or update the cron job if cron_config is a valid cron expression
+ """
+ item_name = f'model_training_cron_job_{group.id}'
+ group.cron_config = cron_config
+ if cron_config:
+ runner_input = RunnerInput(model_training_cron_job_input=ModelTrainingCronJobInput(group_id=group.id))
+ items = [(ItemType.MODEL_TRAINING_CRON_JOB, runner_input)]
+ CronJobService(self._session).start_cronjob(item_name=item_name, items=items, cron_config=cron_config)
+ else:
+ CronJobService(self._session).stop_cronjob(item_name=item_name)
+
+ def create_group(self, name: str, uuid: str, project_id: int, role: ModelJobRole, dataset_id: int,
+ algorithm_type: AlgorithmType, algorithm_project_list: AlgorithmProjectList,
+ coordinator_id: int) -> ModelJobGroup:
+ dataset = self._session.query(Dataset).get(dataset_id)
+ assert dataset is not None, f'dataset with id {dataset_id} is not found'
+ group = ModelJobGroup(name=name,
+ uuid=uuid,
+ role=role,
+ project_id=project_id,
+ dataset_id=dataset_id,
+ algorithm_type=algorithm_type,
+ coordinator_id=coordinator_id)
+ group.set_algorithm_project_uuid_list(algorithm_project_list)
+ pure_domain_name = SettingService(session=self._session).get_system_info().pure_domain_name
+ algorithm_project_uuid = algorithm_project_list.algorithm_projects.get(pure_domain_name)
+ if algorithm_project_uuid is None and algorithm_type != AlgorithmType.TREE_VERTICAL:
+ raise Exception(f'algorithm project uuid must be given if algorithm type is {algorithm_type.name}')
+ if algorithm_project_uuid is not None:
+ algorithm_project = self._session.query(AlgorithmProject).filter_by(uuid=algorithm_project_uuid).first()
+            # the algorithm project is None when the uuid points to an algorithm published on the peer platform
+ if algorithm_project is not None:
+ group.algorithm_project_id = algorithm_project.id
+ self._session.add(group)
+ return group
+
+ def get_latest_model_from_model_group(self, model_group_id: int) -> Model:
+ model = self._session.query(Model).filter_by(group_id=model_group_id).order_by(Model.version.desc()).first()
+ if model is None:
+ raise InvalidArgumentException(f'model in group {model_group_id} is not found')
return model
- def get_checkpoint_path(self, job):
- return None
+ def initialize_auth_status(self, group: ModelJobGroup):
+ # set auth status map
+ pure_domain_name = SettingService(self._session).get_system_info().pure_domain_name
+ participants = ParticipantService(self._session).get_participants_by_project(group.project_id)
+ participants_info = ParticipantsInfo(participants_map={
+ p.pure_domain_name(): ParticipantInfo(auth_status=AuthStatus.PENDING.name) for p in participants
+ })
+ participants_info.participants_map[pure_domain_name].auth_status = AuthStatus.AUTHORIZED.name
+ group.set_participants_info(participants_info)
+ # compatible with older versions of auth status
+ group.authorized = True
+ group.auth_status = auth_model.AuthStatus.AUTHORIZED
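The two group-level helpers above are designed to compose: lock_and_update_version reserves the next version under an exclusive row lock, and launch_model_job creates and configures the training job for that version. Below is a minimal sketch of that flow, assuming the db.session_scope() helper used by the tests in this change and a uuid4-based uuid; the naming scheme and the single-session flow are illustrative only, since in the real coordinator path the reserved version is handed to the 2PC transaction between the two calls.

```python
from uuid import uuid4

from fedlearner_webconsole.db import db
from fedlearner_webconsole.mmgr.service import ModelJobGroupService


def launch_next_training_run(group_id: int) -> int:
    """Illustrative only: bump the group's version under lock, then launch a model job for it."""
    with db.session_scope() as session:
        service = ModelJobGroupService(session)
        group = service.lock_and_update_version(group_id)  # exclusive lock keeps versions unique and increasing
        version = group.latest_version
        model_job = service.launch_model_job(group=group,
                                             name=f'{group.name}-v{version}',  # hypothetical naming scheme
                                             uuid=uuid4().hex,
                                             version=version)
        session.commit()
        return model_job.id
```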
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/service_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/service_test.py
new file mode 100644
index 000000000..0bcd6fe32
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/service_test.py
@@ -0,0 +1,805 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import unittest
+from unittest.mock import patch, Mock
+from google.protobuf.struct_pb2 import Value
+from google.protobuf.empty_pb2 import Empty
+
+from testing.common import NoWebServerTestCase
+from testing.rpc.client import FakeRpcError
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobGroup, ModelJobType, Model, ModelJobRole, \
+ ModelJobStatus, AuthStatus
+from fedlearner_webconsole.mmgr.service import ModelJobService, ModelJobGroupService, ModelService
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.algorithm.models import AlgorithmType, Algorithm, AlgorithmProject
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobState, DatasetJobKind, DatasetType, \
+ DataBatch, DatasetJobStage
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, ModelJobConfig, AlgorithmProjectList
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, ModelTrainingCronJobInput
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+
+
+class ModelJobServiceTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ dataset_job = DatasetJob(id=1,
+ name='datasetjob',
+ uuid='uuid',
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.OT_PSI_DATA_JOIN)
+ dataset = Dataset(id=2,
+ uuid='uuid',
+ name='datasetjob',
+ dataset_type=DatasetType.PSI,
+ path='/data/dataset/haha')
+ dataset_rsa = Dataset(id=3,
+ uuid='uuid_rsa',
+ name='dataset_rsa',
+ dataset_type=DatasetType.PSI,
+ path='/data/dataset/haha')
+ dataset_job_rsa = DatasetJob(id=2,
+ name='dataset_job_rsa',
+ uuid='uuid_rsa',
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=3,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN)
+ dataset_job_stage = DatasetJobStage(id=1,
+ uuid='uuid',
+ name='dataset_job_stage_1',
+ project_id=1,
+ dataset_job_id=1,
+ data_batch_id=1,
+ state=DatasetJobState.SUCCEEDED)
+ data_batch = DataBatch(id=1,
+ name='data_batch_1',
+ dataset_id=dataset.id,
+ latest_parent_dataset_job_stage_id=1)
+ model_job = ModelJob(name='test-model-job',
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.TREE_VERTICAL,
+ dataset_id=2,
+ model_id=1,
+ project_id=1,
+ workflow_uuid='test-uuid')
+ algorithm = Algorithm(id=1, name='algo', uuid='uuid', project_id=1)
+ project = Project(id=1, name='project')
+ model_job_group = ModelJobGroup(id=1, name='model_job_group', uuid='uuid')
+ participant1 = Participant(id=1, name='demo1', domain_name='fl-demo1.com')
+ participant2 = Participant(id=2, name='demo2', domain_name='fl-demo2.com')
+ project_part1 = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ project_part2 = ProjectParticipant(id=2, project_id=1, participant_id=2)
+ session.add_all([
+ model_job, dataset, dataset_job, algorithm, project, participant1, participant2, project_part1,
+ project_part2, model_job_group, dataset_rsa, dataset_job_rsa, dataset_job_stage, data_batch
+ ])
+ session.commit()
+
+ @staticmethod
+ def _get_workflow_config():
+ return WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[
+ Variable(name='mode', value='train'),
+ Variable(name='data_source'),
+ Variable(name='data_path'),
+ Variable(name='file_wildcard'),
+ ],
+ yaml_template='{}')
+ ])
+
+ def test_config_model_job_create_workflow(self):
+ config = self._get_workflow_config()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ ModelJobService(session).config_model_job(model_job,
+ config=config,
+ create_workflow=True,
+ workflow_uuid='test-uuid')
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ workflow = session.query(Workflow).filter_by(uuid='test-uuid').first()
+ self.assertEqual(workflow.creator, 's_y_s_t_e_m')
+ self.assertEqual(model_job.job_name, 'test-uuid-train-job')
+ self.assertEqual(model_job.job_id, workflow.owned_jobs[0].id)
+ self.assertEqual(model_job.workflow.template.name, 'sys-preset-tree-model')
+ self.assertEqual(
+ model_job.workflow.get_config(),
+ WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[
+ Variable(name='mode', value='train'),
+ Variable(name='data_source',
+ value='',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='')),
+ Variable(name='data_path',
+ value='/data/dataset/haha/batch',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='/data/dataset/haha/batch')),
+ Variable(name='file_wildcard',
+ value='**/part*',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='**/part*')),
+ ],
+ yaml_template='{}')
+ ]))
+
+ def test_config_model_job_not_create_workflow(self):
+ config = self._get_workflow_config()
+ with db.session_scope() as session:
+ workflow = Workflow(name='test-workflow', uuid='test-uuid', state=WorkflowState.NEW, project_id=1)
+ session.add(workflow)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ ModelJobService(session).config_model_job(model_job,
+ config=config,
+ create_workflow=False,
+ workflow_uuid='test-uuid')
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ workflow = session.query(Workflow).filter_by(name='test-workflow').first()
+ self.assertEqual(workflow.creator, 's_y_s_t_e_m')
+ self.assertEqual(model_job.job_name, 'test-uuid-train-job')
+ self.assertEqual(model_job.job_id, workflow.owned_jobs[0].id)
+ self.assertEqual(model_job.workflow.template.name, 'sys-preset-tree-model')
+ self.assertEqual(
+ model_job.workflow.get_config(),
+ WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.TREE_MODEL_TRAINING,
+ variables=[
+ Variable(name='mode', value='train'),
+ Variable(name='data_source',
+ value='',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='')),
+ Variable(name='data_path',
+ value='/data/dataset/haha/batch',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='/data/dataset/haha/batch')),
+ Variable(name='file_wildcard',
+ value='**/part*',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='**/part*')),
+ ],
+ yaml_template='{}')
+ ]))
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_model_job')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_create_model_job(self, mock_get_system_info, mock_create_model_job):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='test')
+ # fail due to dataset uuid is None
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig()
+ with self.assertRaises(AssertionError, msg='dataset uuid must not be None'):
+ service.create_model_job(name='name',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config)
+ # fail due to dataset is not found
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid1')
+ with self.assertRaises(AssertionError, msg='dataset with uuid uuid1 is not found'):
+ service.create_model_job(name='name',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config)
+ # fail due to domain name in model_job_config is None
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid')
+ with self.assertRaises(AssertionError, msg='model_job_config of self domain name test must not be None'):
+ service.create_model_job(name='name',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config)
+ # create model job when role is participant and model_job_type is TRAINING
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={'test': ModelJobConfig(algorithm_uuid='uuid')})
+ service.create_model_job(name='model_job_1',
+ uuid='uuid-1',
+ group_id=2,
+ project_id=3,
+ coordinator_id=1,
+ role=ModelJobRole.PARTICIPANT,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ version=3)
+ session.commit()
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='model_job_1').first()
+ self.assertEqual(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(model_job.dataset_id, 2)
+ self.assertEqual(model_job.get_global_config(), global_config)
+ self.assertEqual(model_job.algorithm_id, 1)
+ self.assertEqual(model_job.version, 3)
+ self.assertEqual(model_job.group_id, 2)
+ self.assertEqual(model_job.project_id, 3)
+ self.assertEqual(model_job.coordinator_id, 1)
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(participants_map={'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)}))
+ self.assertEqual(model_job.status, ModelJobStatus.PENDING)
+ self.assertEqual(model_job.auto_update, False)
+ # create model job when role is coordinator and model_job_type is EVALUATION
+ with db.session_scope() as session:
+ service = ModelJobService(session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={
+ 'test': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo1': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo2': ModelJobConfig(algorithm_uuid='uuid')
+ })
+ mock_create_model_job.side_effect = [Empty(), Empty()]
+ service.create_model_job(name='model_job_2',
+ uuid='uuid-2',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.EVALUATION,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ version=3)
+ session.commit()
+ mock_create_model_job.assert_called()
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='model_job_2').first()
+ self.assertEqual(model_job.model_job_type, ModelJobType.EVALUATION)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ }))
+ self.assertEqual(model_job.status, ModelJobStatus.PENDING)
+ self.assertEqual(model_job.coordinator_id, 0)
+ self.assertEqual(model_job.auto_update, False)
+ # create model job when role is participant and model_job_type is EVALUATION
+ mock_create_model_job.reset_mock()
+ with db.session_scope() as session:
+ service = ModelJobService(session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={
+ 'test': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo1': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo2': ModelJobConfig(algorithm_uuid='uuid')
+ })
+ mock_create_model_job.side_effect = [Empty(), Empty()]
+ service.create_model_job(name='model_job_5',
+ uuid='uuid-5',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.PARTICIPANT,
+ model_job_type=ModelJobType.EVALUATION,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ version=3)
+ session.commit()
+ mock_create_model_job.assert_not_called()
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='model_job_5').first()
+ self.assertEqual(model_job.auth_status, AuthStatus.PENDING)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ }))
+ # create eval horizontal model job when role is coordinator
+ mock_create_model_job.reset_mock()
+ with db.session_scope() as session:
+ service = ModelJobService(session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={
+ 'test': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo1': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo2': ModelJobConfig(algorithm_uuid='uuid')
+ })
+ service.create_model_job(name='model_job_3',
+ uuid='uuid-3',
+ project_id=1,
+ group_id=None,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.EVALUATION,
+ algorithm_type=AlgorithmType.NN_HORIZONTAL,
+ global_config=global_config,
+ version=4)
+ session.commit()
+ mock_create_model_job.assert_not_called()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='model_job_3').first()
+ self.assertEqual(model_job.model_job_type, ModelJobType.EVALUATION)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_HORIZONTAL)
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ # create predict horizontal model job when role is coordinator
+ mock_create_model_job.reset_mock()
+ with db.session_scope() as session:
+ service = ModelJobService(session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={
+ 'test': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo1': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo2': ModelJobConfig(algorithm_uuid='uuid')
+ })
+ service.create_model_job(name='model_job_4',
+ uuid='uuid-4',
+ project_id=1,
+ group_id=None,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.PREDICTION,
+ algorithm_type=AlgorithmType.NN_HORIZONTAL,
+ global_config=global_config,
+ version=4)
+ session.commit()
+ mock_create_model_job.assert_not_called()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='model_job_4').first()
+ self.assertEqual(model_job.model_job_type, ModelJobType.PREDICTION)
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_HORIZONTAL)
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ # fail due to grpc error
+ with db.session_scope() as session:
+ service = ModelJobService(session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={
+ 'test': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo1': ModelJobConfig(algorithm_uuid='uuid'),
+ 'demo2': ModelJobConfig(algorithm_uuid='uuid')
+ })
+ mock_create_model_job.side_effect = [
+ Empty(), FakeRpcError(grpc.StatusCode.UNIMPLEMENTED, 'rpc not implemented')
+ ]
+ with self.assertRaises(Exception):
+ service.create_model_job(name='model_job_2',
+ uuid='uuid-2',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ version=3)
+
+ def test_update_model_job_status(self):
+ with db.session_scope() as session:
+ workflow = Workflow(id=1, uuid='test-uuid', state=WorkflowState.NEW)
+ session.add(workflow)
+ session.commit()
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ ModelJobService(session).update_model_job_status(model_job)
+ self.assertEqual(model_job.status, ModelJobStatus.PENDING)
+ workflow = session.query(Workflow).filter_by(uuid='test-uuid').first()
+ workflow.state = WorkflowState.RUNNING
+ ModelJobService(session).update_model_job_status(model_job)
+ self.assertEqual(model_job.status, ModelJobStatus.RUNNING)
+ workflow.state = WorkflowState.STOPPED
+ ModelJobService(session).update_model_job_status(model_job)
+ self.assertEqual(model_job.status, ModelJobStatus.STOPPED)
+ workflow.state = WorkflowState.COMPLETED
+ ModelJobService(session).update_model_job_status(model_job)
+ self.assertEqual(model_job.status, ModelJobStatus.SUCCEEDED)
+ workflow.state = WorkflowState.FAILED
+ ModelJobService(session).update_model_job_status(model_job)
+ self.assertEqual(model_job.status, ModelJobStatus.FAILED)
+
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ def test_initialize_auth_status(self, mock_system_info):
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test', name='name')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ ModelJobService(session).initialize_auth_status(model_job)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ }))
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test', name='name')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ model_job.algorithm_type = AlgorithmType.NN_HORIZONTAL
+ ModelJobService(session).initialize_auth_status(model_job)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ }))
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test', name='name')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ model_job.algorithm_type = AlgorithmType.TREE_VERTICAL
+ model_job.model_job_type = ModelJobType.EVALUATION
+ ModelJobService(session).initialize_auth_status(model_job)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ }))
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test', name='name')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ model_job.role = ModelJobRole.COORDINATOR
+ ModelJobService(session).initialize_auth_status(model_job)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'demo1': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo2': ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ }))
+
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ def test_update_model_job_auth_status(self, mock_get_system_info):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='test')
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ ModelJobService.update_model_job_auth_status(model_job, AuthStatus.AUTHORIZED)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ self.assertEqual(model_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(participants_map={'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)}))
+ ModelJobService.update_model_job_auth_status(model_job, AuthStatus.PENDING)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).filter_by(name='test-model-job').first()
+ self.assertEqual(model_job.auth_status, AuthStatus.PENDING)
+ self.assertEqual(
+ model_job.get_participants_info(),
+ ParticipantsInfo(participants_map={'test': ParticipantInfo(auth_status=AuthStatus.PENDING.name)}))
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_model_job')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_create_auto_update_model_job(self, mock_get_system_info, mock_create_model_job):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='test')
+ # fail due to algorithm type not supported
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid')
+ with self.assertRaises(AssertionError, msg='auto update is only supported for nn vertical train'):
+ service.create_model_job(name='name',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_HORIZONTAL,
+ global_config=global_config,
+ data_batch_id=1)
+ # fail due to dataset job type not supported
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid_rsa')
+ with self.assertRaises(AssertionError, msg='auto update is not supported for RSA-PSI dataset'):
+ service.create_model_job(name='name',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ data_batch_id=1)
+ # fail due to data batch is not found
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid')
+ with self.assertRaises(AssertionError, msg='data batch 2 is not found'):
+ service.create_model_job(name='name',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ data_batch_id=2)
+ # create success
+ with db.session_scope() as session:
+ service = ModelJobService(session=session)
+ global_config = ModelJobGlobalConfig(dataset_uuid='uuid',
+ global_config={'test': ModelJobConfig(algorithm_uuid='uuid')})
+ service.create_model_job(name='model_job_1',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=global_config,
+ data_batch_id=1)
+ session.commit()
+ mock_create_model_job.assert_called()
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='model_job_1').first()
+ self.assertEqual(model_job.data_batch_id, 1)
+ self.assertEqual(model_job.auto_update, True)
+ global_config.dataset_job_stage_uuid = 'uuid'
+ self.assertEqual(model_job.get_global_config(), global_config)
+
+
+class ModelServiceTest(NoWebServerTestCase):
+
+ _MODEL_NAME = 'test-model'
+ _PROJECT_ID = 123
+ _GROUP_ID = 123
+ _MODEL_JOB_ID = 123
+ _JOB_ID = 123
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=self._PROJECT_ID, name='test-project')
+ session.add(project)
+ session.flush()
+ workflow = Workflow(name='test-workflow', project_id=project.id)
+ session.add(workflow)
+ session.flush()
+ workflow = Workflow(id=1, name='workflow', uuid='uuid', project_id=project.id)
+ job = Job(id=self._JOB_ID,
+ name='uuid-nn-model',
+ project_id=project.id,
+ job_type=JobType.NN_MODEL_TRANINING,
+ state=JobState.COMPLETED,
+ workflow_id=workflow.id)
+ job.set_config(JobDefinition(name='nn-model'))
+ session.add(job)
+ group = ModelJobGroup(id=self._GROUP_ID, name='test-group', project_id=project.id)
+ session.add(group)
+ session.flush()
+ model_job = ModelJob(id=self._MODEL_JOB_ID,
+ name='test-model-job',
+ uuid='test-uuid',
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ model_job_type=ModelJobType.NN_TRAINING,
+ group_id=group.id,
+ project_id=project.id,
+ job_name=job.name,
+ job_id=job.id,
+ version=2)
+ session.add(model_job)
+ session.commit()
+
+ @patch('fedlearner_webconsole.project.models.Project.get_storage_root_path')
+ def test_create_model_from_model_job(self, mock_get_storage_root_path):
+ mock_get_storage_root_path.return_value = '/data'
+ with db.session_scope() as session:
+ service = ModelService(session)
+ job = session.query(Job).get(self._JOB_ID)
+ model_job = session.query(ModelJob).get(self._MODEL_JOB_ID)
+ service.create_model_from_model_job(model_job=model_job)
+ session.commit()
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(self._MODEL_JOB_ID)
+ model: Model = session.query(Model).filter_by(uuid=model_job.uuid).first()
+ self.assertEqual(model.name, 'test-group-v2')
+ self.assertEqual(model.job_id, job.id)
+ self.assertEqual(model.project_id, self._PROJECT_ID)
+ self.assertEqual(model.model_path, '/data/job_output/uuid-nn-model')
+ self.assertEqual(model.model_job_id, model_job.id)
+ self.assertEqual(model.group_id, model_job.group_id)
+
+ mock_get_storage_root_path.return_value = None
+ with self.assertRaises(RuntimeError, msg='storage root of project test-project is None') as cm:
+ with db.session_scope() as session:
+ service = ModelService(session)
+ model_job = session.query(ModelJob).get(self._MODEL_JOB_ID)
+ service.create_model_from_model_job(model_job=model_job)
+
+
+class ModelJobGroupServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ dataset = Dataset(id=1, name='name', uuid='uuid')
+ project = Project(id=1, name='project')
+ group = ModelJobGroup(id=1, name='group', project_id=1)
+ algorithm_project = AlgorithmProject(id=1, name='name', uuid='algo-uuid')
+ session.add_all([dataset, project, group, algorithm_project])
+ session.commit()
+
+ @patch('fedlearner_webconsole.composer.composer_service.CronJobService.start_cronjob')
+ @patch('fedlearner_webconsole.composer.composer_service.CronJobService.stop_cronjob')
+ def test_update_cronjob_config(self, mock_stop_cronjob: Mock, mock_start_cronjob: Mock):
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).get(1)
+ ModelJobGroupService(session).update_cronjob_config(group, '')
+ self.assertEqual(group.cron_config, '')
+ mock_start_cronjob.assert_not_called()
+ mock_stop_cronjob.assert_called_once_with(item_name='model_training_cron_job_1')
+ mock_stop_cronjob.reset_mock()
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ ModelJobGroupService(session).update_cronjob_config(group, '*/10 * * * *')
+ self.assertEqual(group.cron_config, '*/10 * * * *')
+ mock_start_cronjob.assert_called_once_with(
+ item_name='model_training_cron_job_1',
+ items=[(ItemType.MODEL_TRAINING_CRON_JOB,
+ RunnerInput(model_training_cron_job_input=ModelTrainingCronJobInput(group_id=1)))],
+ cron_config='*/10 * * * *')
+ mock_stop_cronjob.assert_not_called()
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda _: SystemInfo(pure_domain_name='test'))
+ def test_create_group(self):
+ with db.session_scope() as session:
+ service = ModelJobGroupService(session)
+ with self.assertRaises(AssertionError, msg='dataset with id 2 is not found'):
+ service.create_group(name='name',
+ uuid='uuid',
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ dataset_id=2,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_project_list=AlgorithmProjectList(),
+ coordinator_id=1)
+ with db.session_scope() as session:
+ service = ModelJobGroupService(session)
+ with self.assertRaises(Exception, msg='algorithm project must be given if algorithm type is NN_VERTICAL'):
+ service.create_group(name='name',
+ uuid='uuid',
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ dataset_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_project_list=AlgorithmProjectList(),
+ coordinator_id=1)
+ with db.session_scope() as session:
+ service = ModelJobGroupService(session)
+ service.create_group(name='name',
+ uuid='uuid',
+ project_id=1,
+ role=ModelJobRole.COORDINATOR,
+ dataset_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_project_list=AlgorithmProjectList(algorithm_projects={'test': 'algo-uuid'}),
+ coordinator_id=1)
+ session.commit()
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(name='name').first()
+ self.assertEqual(group.name, 'name')
+ self.assertEqual(group.project_id, 1)
+ self.assertEqual(group.role, ModelJobRole.COORDINATOR)
+ self.assertEqual(group.dataset_id, 1)
+ self.assertEqual(group.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(group.algorithm_project_id, 1)
+ self.assertEqual(group.get_algorithm_project_uuid_list(),
+ AlgorithmProjectList(algorithm_projects={'test': 'algo-uuid'}))
+ self.assertEqual(group.coordinator_id, 1)
+
+ def test_get_latest_model_from_model_group(self):
+ with db.session_scope() as session:
+ model_1 = Model()
+ model_1.name = 'test_model_name_1'
+ model_1.project_id = 1
+ model_1.version = 1
+ model_1.group_id = 1
+ model_2 = Model()
+ model_2.name = 'test_model_name_2'
+ model_2.project_id = 1
+ model_2.version = 2
+ model_2.group_id = 1
+ session.add_all([model_1, model_2])
+ session.commit()
+ with db.session_scope() as session:
+ service = ModelJobGroupService(session)
+ model = service.get_latest_model_from_model_group(1)
+ self.assertEqual('test_model_name_2', model.name)
+
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ def test_initialize_auth_status(self, mock_system_info):
+ mock_system_info.return_value = SystemInfo(pure_domain_name='test', name='name')
+ with db.session_scope() as session:
+ participant = Participant(id=1, name='party', domain_name='fl-peer.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ session.add_all([participant, relationship])
+ session.commit()
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ ModelJobGroupService(session).initialize_auth_status(group)
+ session.commit()
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ self.assertEqual(
+ group.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'peer': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ }))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/utils.py b/web_console_v2/api/fedlearner_webconsole/mmgr/utils.py
new file mode 100644
index 000000000..7bc41fca3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/utils.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from uuid import uuid4
+
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.job.models import JobType
+
+
+def get_job_path(storage_root_path: str, job_name: str) -> str:
+ return os.path.join(storage_root_path, 'job_output', job_name)
+
+
+def get_exported_model_path(job_path: str) -> str:
+ return os.path.join(job_path, 'exported_models')
+
+
+def get_checkpoint_path(job_path: str) -> str:
+ return os.path.join(job_path, 'checkpoints')
+
+
+def get_output_path(job_path: str) -> str:
+ return os.path.join(job_path, 'outputs')
+
+
+def exported_model_version_path(exported_models_path: str, version: int) -> str:
+ return os.path.join(exported_models_path, str(version))
+
+
+def get_model_path(storage_root_path: str, uuid: str) -> str:
+ return os.path.join(storage_root_path, 'model_output', uuid)
+
+
+def build_workflow_name(model_job_type: str, algorithm_type: str, model_job_name: str) -> str:
+ prefix = f'{model_job_type.lower()}-{algorithm_type.lower()}-{model_job_name}'
+    # the workflow name is limited to 255 characters, so the prefix is truncated to at most 249 characters
+    # before the 6-character random suffix is appended
+ return f'{prefix[:249]}-{uuid4().hex[:5]}'
+
+
+def is_model_job(job_type: JobType) -> bool:
+ return job_type in [
+ JobType.NN_MODEL_TRANINING, JobType.NN_MODEL_EVALUATION, JobType.TREE_MODEL_TRAINING,
+ JobType.TREE_MODEL_EVALUATION
+ ]
+
+
+def deleted_name(name: str):
+ """Rename the deleted model job, model, group due to unique constraint on name"""
+ timestamp = now().strftime('%Y%m%d_%H%M%S')
+ return f'deleted_at_{timestamp}_{name}'
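build_workflow_name is the only helper above that the test file below does not exercise; the sketch that follows shows the shape of the name it produces. The trailing 5 hex characters come from uuid4, so only the prefix and the total length can be asserted.

```python
from fedlearner_webconsole.mmgr.utils import build_workflow_name

# e.g. 'training-nn_vertical-my-model-job-1a2b3'; the 5-character suffix varies per call
name = build_workflow_name(model_job_type='TRAINING',
                           algorithm_type='NN_VERTICAL',
                           model_job_name='my-model-job')
assert name.startswith('training-nn_vertical-my-model-job-')
assert len(name) <= 255
```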
diff --git a/web_console_v2/api/fedlearner_webconsole/mmgr/utils_test.py b/web_console_v2/api/fedlearner_webconsole/mmgr/utils_test.py
new file mode 100644
index 000000000..e17b9d861
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/mmgr/utils_test.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from fedlearner_webconsole.job.models import JobType
+from fedlearner_webconsole.mmgr.utils import get_exported_model_path, get_job_path, get_checkpoint_path, \
+ exported_model_version_path, is_model_job
+
+
+class UtilsTest(unittest.TestCase):
+
+ def test_is_model_job(self):
+ self.assertFalse(is_model_job(job_type=JobType.TRANSFORMER))
+ self.assertFalse(is_model_job(job_type=JobType.RAW_DATA))
+ self.assertFalse(is_model_job(job_type=JobType.DATA_JOIN))
+ self.assertFalse(is_model_job(job_type=JobType.PSI_DATA_JOIN))
+ self.assertTrue(is_model_job(job_type=JobType.NN_MODEL_TRANINING))
+ self.assertTrue(is_model_job(job_type=JobType.NN_MODEL_EVALUATION))
+ self.assertTrue(is_model_job(job_type=JobType.TREE_MODEL_TRAINING))
+ self.assertTrue(is_model_job(job_type=JobType.TREE_MODEL_EVALUATION))
+
+ def test_get_job_path(self):
+ storage_root_path = '/data'
+ job_name = 'train_job'
+ job_path = get_job_path(storage_root_path, job_name)
+ exported_path = f'{storage_root_path}/job_output/{job_name}'
+ self.assertEqual(job_path, exported_path)
+
+ def test_get_exported_model_path(self):
+ job_path = '/data/job_output/train_job'
+ exported_model_path = get_exported_model_path(job_path)
+ expected_path = f'{job_path}/exported_models'
+ self.assertEqual(exported_model_path, expected_path)
+
+ def test_get_checkpoint_path(self):
+ job_path = '/data/job_output/train_job'
+ checkpoint_path = get_checkpoint_path(job_path)
+ expected_path = f'{job_path}/checkpoints'
+ self.assertEqual(checkpoint_path, expected_path)
+
+ def test_exported_model_version_path(self):
+ exported_model_path = '/data/model_output/uuid'
+ exported_model_path_v1 = exported_model_version_path(exported_model_path, 1)
+ expected_path = f'{exported_model_path}/1'
+ self.assertEqual(exported_model_path_v1, expected_path)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/notification/BUILD.bazel
new file mode 100644
index 000000000..e4ae7249a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/notification/BUILD.bazel
@@ -0,0 +1,58 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "notification_lib",
+ srcs = [
+ "email.py",
+ "sender.py",
+ "template.py",
+ ],
+ imports = ["../.."],
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "email_test",
+ srcs = [
+ "email_test.py",
+ ],
+ imports = ["../.."],
+ main = "email_test.py",
+ deps = [
+ ":notification_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "sender_test",
+ srcs = [
+ "sender_test.py",
+ ],
+ imports = ["../.."],
+ main = "sender_test.py",
+ deps = [
+ ":notification_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "template_test",
+ srcs = [
+ "template_test.py",
+ ],
+ imports = ["../.."],
+ main = "template_test.py",
+ deps = [
+ ":notification_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/__init__.py b/web_console_v2/api/fedlearner_webconsole/notification/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/email.py b/web_console_v2/api/fedlearner_webconsole/notification/email.py
new file mode 100644
index 000000000..09d042121
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/notification/email.py
@@ -0,0 +1,25 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from fedlearner_webconsole.notification.sender import send
+from fedlearner_webconsole.notification.template import NotificationTemplateName, render
+
+
+def send_email(address: str, template_name: NotificationTemplateName, **kwargs):
+ notification = render(template_name, **kwargs)
+ # TODO(linfan.fine): validate the email address
+ if address:
+ notification.receivers.append(address)
+ send(notification)
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/email_test.py b/web_console_v2/api/fedlearner_webconsole/notification/email_test.py
new file mode 100644
index 000000000..b35f6f334
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/notification/email_test.py
@@ -0,0 +1,39 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, Mock
+
+from fedlearner_webconsole.notification.email import send_email
+from fedlearner_webconsole.notification.template import NotificationTemplateName
+from fedlearner_webconsole.proto.notification_pb2 import Notification
+
+
+class EmailTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.notification.email.render')
+ @patch('fedlearner_webconsole.notification.email.send')
+ def test_send_email(self, mock_send: Mock, mock_render: Mock):
+ subject = 'test_subject'
+ content = 'test_content'
+ address = 'a@b.com'
+ mock_render.return_value = Notification(subject=subject, content=content)
+ send_email(address, NotificationTemplateName.WORKFLOW_COMPLETE, var1='aaa', var2='bbb')
+ mock_send.assert_called_once_with(Notification(subject=subject, content=content, receivers=[address]))
+ mock_render.assert_called_once_with(NotificationTemplateName.WORKFLOW_COMPLETE, var1='aaa', var2='bbb')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/sender.py b/web_console_v2/api/fedlearner_webconsole/notification/sender.py
new file mode 100644
index 000000000..ceb1440ce
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/notification/sender.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import logging
+
+from fedlearner_webconsole.proto.notification_pb2 import Notification
+
+
+class Sender(metaclass=abc.ABCMeta):
+
+ @abc.abstractmethod
+ def send(self, notification: Notification):
+ """Sends notification by third-party services."""
+
+
+senders = {}
+
+
+def register_sender(name: str, sender: Sender):
+ senders[name] = sender
+
+
+def send(notification: Notification):
+ """Sends a notification.
+
+    Senders are invoked directly since there are no performance concerns as of now.
+    In the future, notifications should be pushed to a queue and sent asynchronously
+    with a producer-consumer pattern."""
+ if not senders:
+ logging.info('[Notification] no sender for %s', notification.subject)
+ return
+ for name, sender in senders.items():
+ try:
+ sender.send(notification)
+ logging.info('[Notification] %s sent by %s', notification.subject, name)
+ except Exception: # pylint: disable=broad-except
+ logging.exception('[Notification] sender %s failed to send %s', name, notification.subject)
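Concrete senders are expected to subclass Sender and register themselves; this patch ships no real implementation, so the snippet below is only a hypothetical example of the registration pattern. The LoggingSender class and its registered name are made up for illustration.

```python
import logging

from fedlearner_webconsole.notification.sender import Sender, register_sender, send
from fedlearner_webconsole.proto.notification_pb2 import Notification


class LoggingSender(Sender):
    """Hypothetical sender that only logs notifications; a real one would call an email or IM API."""

    def send(self, notification: Notification):
        logging.info('[LoggingSender] %s -> %s', notification.subject, list(notification.receivers))


register_sender('logging_sender', LoggingSender())
send(Notification(subject='hello', content='hello from fedlearner', receivers=['a@b.com']))
```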
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/sender_test.py b/web_console_v2/api/fedlearner_webconsole/notification/sender_test.py
new file mode 100644
index 000000000..d58b2b167
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/notification/sender_test.py
@@ -0,0 +1,41 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock
+
+from fedlearner_webconsole.notification.sender import register_sender, send
+from fedlearner_webconsole.proto.notification_pb2 import Notification
+
+
+class SenderTest(unittest.TestCase):
+
+ def test_send(self):
+ mock_sender = MagicMock()
+ mock_sender.send = MagicMock()
+ register_sender('mock_sender', mock_sender)
+
+ notification = Notification(subject='test subject', content='test content', receivers=[])
+ send(notification)
+ mock_sender.send.assert_called_once_with(notification)
+
+ def test_send_with_no_sender(self):
+ notification = Notification(subject='test subject', content='test content', receivers=[])
+ # No exception is expected
+ send(notification)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/template.py b/web_console_v2/api/fedlearner_webconsole/notification/template.py
new file mode 100644
index 000000000..3f6068d07
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/notification/template.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+from string import Template
+from typing import NamedTuple
+
+from fedlearner_webconsole.proto.notification_pb2 import Notification
+
+
+class NotificationTemplate(NamedTuple):
+ subject: Template
+ content: Template
+
+
+class NotificationTemplateName(enum.Enum):
+ WORKFLOW_COMPLETE = 'WORKFLOW_COMPLETE'
+
+
+_UNKNOWN_TEMPLATE = NotificationTemplate(
+ subject=Template('Unknown email'),
+ content=Template(''),
+)
+
+_WORKFLOW_COMPLETE_TEMPLATE = NotificationTemplate(
+ subject=Template('【隐私计算平台】工作流「${name}」- 运行结束 - ${state}'),
+ content=Template('「工作流中心」:工作流「${name}」- 运行结束 - ${state},详情请见:${link}'),
+)
+
+TEMPLATES = {NotificationTemplateName.WORKFLOW_COMPLETE: _WORKFLOW_COMPLETE_TEMPLATE}
+
+
+def render(template_name: NotificationTemplateName, **kwargs) -> Notification:
+ template = TEMPLATES.get(template_name, _UNKNOWN_TEMPLATE)
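+    # safe_substitute keeps unresolved ${placeholders} as-is instead of raising KeyError.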
+ return Notification(
+ subject=template.subject.safe_substitute(kwargs),
+ content=template.content.safe_substitute(kwargs),
+ )
diff --git a/web_console_v2/api/fedlearner_webconsole/notification/template_test.py b/web_console_v2/api/fedlearner_webconsole/notification/template_test.py
new file mode 100644
index 000000000..7c79788d8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/notification/template_test.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.notification.template import render, NotificationTemplateName
+from fedlearner_webconsole.proto.notification_pb2 import Notification
+
+
+class TemplateTest(unittest.TestCase):
+
+ def test_render(self):
+ email = render(NotificationTemplateName.WORKFLOW_COMPLETE,
+ name='test workflow',
+ state='FAILED',
+ link='www.a.com')
+ self.assertEqual(
+ email,
+ Notification(subject='【隐私计算平台】工作流「test workflow」- 运行结束 - FAILED',
+ content='「工作流中心」:工作流「test workflow」- 运行结束 - FAILED,详情请见:www.a.com'))
+ # some variables are not passed
+ email = render(NotificationTemplateName.WORKFLOW_COMPLETE, name='test workflow', unknown_var='123')
+ self.assertEqual(
+ email,
+ Notification(subject='【隐私计算平台】工作流「test workflow」- 运行结束 - ${state}',
+ content='「工作流中心」:工作流「test workflow」- 运行结束 - ${state},详情请见:${link}'))
+
+ def test_render_unknown(self):
+ self.assertEqual(render('unknown template', hello=123), Notification(subject='Unknown email',))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/participant/BUILD.bazel
new file mode 100644
index 000000000..1daa06058
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/BUILD.bazel
@@ -0,0 +1,128 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "k8s_utils_lib",
+ srcs = ["k8s_utils.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ ],
+)
+
+py_test(
+ name = "k8s_utils_lib_test",
+ srcs = [
+ "k8s_utils_test.py",
+ ],
+ imports = ["../.."],
+ main = "k8s_utils_test.py",
+ deps = [
+ ":k8s_utils_lib",
+ "//web_console_v2/api/testing:helpers_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = ["services.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ ],
+)
+
+py_test(
+ name = "services_lib_test",
+ srcs = [
+ "services_test.py",
+ ],
+ imports = ["../.."],
+ main = "services_test.py",
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_time_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":k8s_utils_lib",
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_time_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/__init__.py b/web_console_v2/api/fedlearner_webconsole/participant/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/apis.py b/web_console_v2/api/fedlearner_webconsole/participant/apis.py
new file mode 100644
index 000000000..c7ff12619
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/apis.py
@@ -0,0 +1,426 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from http import HTTPStatus
+from typing import Optional
+from flask_restful import Api, Resource
+from webargs import fields, validate
+from google.protobuf.json_format import MessageToDict
+
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import InvalidArgumentException, ResourceConflictException, \
+ NotFoundException, MethodNotAllowedException
+from fedlearner_webconsole.participant.k8s_utils import get_host_and_port, get_valid_candidates, \
+ create_or_update_participant_in_k8s
+from fedlearner_webconsole.participant.models import Participant, ParticipantType
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.proto.common_pb2 import StatusCode
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.utils.decorators.pp_flask import use_kwargs, input_validator, admin_required
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+
+
+def _get_empty_message_hint(field: str) -> str:
+ return f'{field} should not be empty'
+
+
+def _create_participant_in_k8s(domain_name: str, host: str, port: int, namespace: str):
+    # Manual configuration must provide all the arguments.
+    if host is None or port is None:
+        raise InvalidArgumentException('Host and port are required for manual configuration.')
+ # create ingress and service
+ # TODO(taoyanting):validate url
+ create_or_update_participant_in_k8s(domain_name=domain_name, host=host, port=port, namespace=namespace)
+
+
+class ParticipantsApi(Resource):
+
+ @credentials_required
+ def get(self):
+        """Get all participant information ordered by `created_at` in descending order.
+ ---
+ tags:
+ - participant
+        description: Get all participant information ordered by `created_at` in descending order.
+ responses:
+ 200:
+            description: list of participants
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Participant'
+ """
+ with db.session_scope() as session:
+ participants = session.query(Participant). \
+ order_by(Participant.created_at.desc()).all()
+ participant_service = ParticipantService(session)
+ protos = []
+ for participant in participants:
+ # A trade-off to join project counts with participants
+ proto = participant.to_proto()
+ proto.num_project = participant_service.get_number_of_projects(participant.id)
+ protos.append(proto)
+ return make_flask_response(data=protos)
+
+ # TODO(taoyanting): refactor this api
+ @input_validator
+ @credentials_required
+ @emits_event(audit_fields=['is_manual_configured', 'type'])
+ @use_kwargs({
+ 'name':
+ fields.Str(required=True),
+ 'domain_name':
+ fields.Str(required=True),
+ 'is_manual_configured':
+ fields.Bool(required=False, load_default=False),
+ 'type':
+ fields.Str(required=False,
+ load_default=ParticipantType.PLATFORM.name,
+ validate=validate.OneOf([t.name for t in ParticipantType])),
+ 'host':
+ fields.Str(required=False, load_default=None),
+ 'port':
+ fields.Integer(required=False, load_default=None),
+ 'comment':
+ fields.Str(required=False, load_default=None),
+ })
+ def post(
+ self,
+ name: str,
+ domain_name: str,
+ is_manual_configured: bool,
+ type: Optional[str], # pylint: disable=redefined-builtin
+ host: Optional[str],
+ port: Optional[int],
+ comment: Optional[str]):
+        """Create a new participant
+ ---
+ tags:
+ - participant
+        description: Create a new participant
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ name:
+ type: string
+ domain_name:
+ type: string
+ is_manual_configured:
+ type: boolean
+ type:
+ type: string
+ host:
+ type: string
+ port:
+ type: integer
+ comment:
+ type: string
+ responses:
+ 200:
+ description: Participant that you created.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Participant'
+ """
+
+        extra = {'is_manual_configured': is_manual_configured}
+ participant_type = ParticipantType[type]
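+        # LIGHT_CLIENT participants only need a DB record; PLATFORM participants also need
+        # a k8s service/ingress (manual config) or host/port resolved from the domain name.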
+
+ with db.session_scope() as session:
+ if session.query(Participant). \
+ filter_by(domain_name=domain_name).first() is not None:
+ raise ResourceConflictException(message='Participant domain name has been used')
+ service = ParticipantService(session)
+ if participant_type == ParticipantType.LIGHT_CLIENT:
+ participant = service.create_light_client_participant(name, domain_name, comment)
+ else:
+ if is_manual_configured:
+ namespace = SettingService(session).get_namespace()
+ _create_participant_in_k8s(domain_name, host, port, namespace)
+ else:
+ host, port = get_host_and_port(domain_name)
+ participant = service.create_platform_participant(name, domain_name, host, port, extra, comment)
+ try:
+ session.commit()
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ return make_flask_response(data=participant.to_proto(), status=HTTPStatus.CREATED)
+
+
+class ParticipantApi(Resource):
+
+ @credentials_required
+ def get(self, participant_id: int):
+        """Get details of a participant
+ ---
+ tags:
+ - participant
+        description: Get details of a participant
+ parameters:
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: the specified participant
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Participant'
+ """
+ with db.session_scope() as session:
+ participant = session.query(Participant).filter_by(id=participant_id).first()
+ if participant is None:
+ raise NotFoundException(f'Failed to find participant: {participant_id}')
+ return make_flask_response(data=participant.to_proto())
+
+ @credentials_required
+ @input_validator
+ @emits_event()
+ @use_kwargs({
+ 'name': fields.Str(required=False, load_default=None),
+ 'domain_name': fields.Str(required=False, load_default=None),
+ 'host': fields.Str(required=False, load_default=None),
+ 'port': fields.Integer(required=False, load_default=None),
+ 'comment': fields.Str(required=False, load_default=None),
+ })
+ def patch(self, participant_id: int, name: Optional[str], domain_name: Optional[str], host: Optional[str],
+ port: Optional[int], comment: Optional[str]):
+        """Partially update the given participant
+ ---
+ tags:
+ - participant
+        description: Partially update the given participant
+ parameters:
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ name:
+ type: string
+ domain_name:
+ type: string
+ host:
+ type: string
+ port:
+ type: integer
+ comment:
+ type: string
+ responses:
+ 200:
+ description: the updated participant
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Participant'
+ """
+ with db.session_scope() as session:
+ participant: Participant = session.query(Participant).filter_by(id=participant_id).first()
+ if participant is None:
+ raise NotFoundException(f'Failed to find participant: {participant_id}')
+
+ participant.name = name or participant.name
+ participant.comment = comment or participant.comment
+
+ if domain_name is not None and domain_name != participant.domain_name:
+ if session.query(Participant).filter_by(domain_name=domain_name).first() is not None:
+ raise ResourceConflictException(message='Participant domain name has been used')
+ participant.domain_name = domain_name
+ if participant.type == ParticipantType.PLATFORM:
+ extra = participant.get_extra_info()
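+                # Manually configured participants keep their k8s service/ingress in sync here;
+                # auto-configured ones re-resolve host and port from the new domain name below.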
+ if extra['is_manual_configured']:
+ if domain_name or host or port:
+ participant.host = host or participant.host
+ participant.port = port or participant.port
+
+ # TODO(taoyanting):validate url
+ try:
+ namespace = SettingService(session).get_namespace()
+ create_or_update_participant_in_k8s(domain_name=participant.domain_name,
+ host=participant.host,
+ port=participant.port,
+ namespace=namespace)
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ elif domain_name is not None:
+ host, port = get_host_and_port(participant.domain_name)
+ participant.host = host
+ participant.port = port
+ participant.set_extra_info(extra)
+ try:
+ session.commit()
+ return make_flask_response(data=participant.to_proto())
+ except Exception as e:
+                raise InvalidArgumentException(details=str(e)) from e
+
+ @credentials_required
+ @admin_required
+ @emits_event()
+ def delete(self, participant_id):
+ """Delete a participant
+ ---
+ tags:
+ - participant
+ description: Delete a participant
+ parameters:
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ responses:
+ 204:
+ description: Deleted successfully
+ """
+ with db.session_scope() as session:
+ participant = session.query(Participant).filter_by(id=participant_id).first()
+ if participant is None:
+ raise NotFoundException(f'Failed to find participant: {participant_id}')
+
+ service = ParticipantService(session)
+ num_project = service.get_number_of_projects(participant_id)
+ if num_project != 0:
+ raise MethodNotAllowedException(f'Failed to delete participant: {participant_id}, '
+ f'because it has related projects')
+ session.delete(participant)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class ParticipantConnectionChecksApi(Resource):
+
+ @credentials_required
+ def get(self, participant_id: int):
+ """Check participant connection status
+ ---
+ tags:
+ - participant
+ description: Check participant connection status
+ parameters:
+ - in: path
+ name: participant_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: connection status
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ success:
+ type: boolean
+ message:
+ type: string
+ application_version:
+ $ref: '#/definitions/proto.ApplicationVersion'
+ """
+ with db.session_scope() as session:
+ participant = session.query(Participant).filter_by(id=participant_id).first()
+ if participant is None:
+ raise NotFoundException(f'Failed to find participant: {participant_id}')
+
+ client = RpcClient.from_participant(participant.domain_name)
+ result = client.check_peer_connection()
+ version = {}
+ if result.status.code == StatusCode.STATUS_SUCCESS:
+ version = MessageToDict(result.application_version, preserving_proto_field_name=True)
+ return make_flask_response({
+ 'success': result.status.code == StatusCode.STATUS_SUCCESS,
+ 'message': result.status.msg,
+ 'application_version': version
+ })
+
+
+class ParticipantCandidatesApi(Resource):
+
+ @credentials_required
+ @admin_required
+ def get(self):
+        """Get candidate participants according to Kubernetes resources.
+ ---
+ tags:
+ - participant
+        description: Get candidate participants according to Kubernetes resources.
+ responses:
+ 200:
+            description: list of candidate participants
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: object
+ properties:
+ domain_name:
+ type: string
+ """
+ return make_flask_response(get_valid_candidates())
+
+
+class ParticipantFlagsApi(Resource):
+
+ def get(self, participant_id: int):
+ """Get flags from participant
+ ---
+ tags:
+ - flag
+ responses:
+ 200:
+ description: Participant's flags are returned
+ content:
+ application/json:
+ schema:
+ type: object
+ additionalProperties: true
+ example:
+ FLAG_1: string_value
+ FLAG_2: true
+ FLAG_3: 1
+ """
+ with db.session_scope() as session:
+ participant: Participant = session.query(Participant).get(participant_id)
+ if participant is None:
+ raise NotFoundException(f'Failed to find participant: {participant_id}')
+ client = SystemServiceClient.from_participant(domain_name=participant.domain_name)
+ return make_flask_response(data=client.list_flags())
+
+
+def initialize_participant_apis(api: Api):
+ api.add_resource(ParticipantsApi, '/participants')
+    api.add_resource(ParticipantApi, '/participants/<int:participant_id>')
+    api.add_resource(ParticipantConnectionChecksApi, '/participants/<int:participant_id>/connection_checks')
+    api.add_resource(ParticipantCandidatesApi, '/participant_candidates')
+    api.add_resource(ParticipantFlagsApi, '/participants/<int:participant_id>/flags')
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/apis_test.py b/web_console_v2/api/fedlearner_webconsole/participant/apis_test.py
new file mode 100644
index 000000000..d06a13dcf
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/apis_test.py
@@ -0,0 +1,372 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+
+from http import HTTPStatus
+
+from unittest.mock import patch, MagicMock
+
+from fedlearner_webconsole.utils.pp_time import sleep
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant, ParticipantType
+from testing.common import BaseTestCase
+
+
+class ParticipantsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.fake_certs = {'test/test.certs': 'key'}
+ self.default_participant = Participant(name='test-particitant-name',
+ domain_name='fl-test.com',
+ host='1.1.1.1',
+ port=32443,
+ comment='test comment')
+ self.default_participant.set_extra_info({'is_manual_configured': False})
+
+ self.participant_manually = Participant(name='test-manual-participant',
+ domain_name='fl-test-manual.com',
+ host='1.1.1.2',
+ port=443)
+ self.participant_manually.set_extra_info({
+ 'is_manual_configured': True,
+ })
+
+ with db.session_scope() as session:
+ session.add(self.default_participant)
+ session.flush()
+ sleep(1)
+ session.add(self.participant_manually)
+ session.commit()
+
+ @patch('fedlearner_webconsole.participant.apis.get_host_and_port')
+ def test_post_participant_without_certificate(self, mock_get_host_and_port):
+ name = 'test-post-participant'
+ domain_name = 'fl-post-test.com'
+ comment = 'test post participant'
+ host = '120.0.0.20'
+ port = 20
+ mock_get_host_and_port.return_value = (host, port)
+
+ create_response = self.post_helper('/api/v2/participants',
+ data={
+ 'name': name,
+ 'domain_name': domain_name,
+ 'is_manual_configured': False,
+ 'comment': comment
+ })
+ self.assertEqual(HTTPStatus.CREATED, create_response.status_code)
+ participant = self.get_response_data(create_response)
+ # yapf: disable
+ self.assertPartiallyEqual(participant, {
+ 'id': 3,
+ 'comment': comment,
+ 'domain_name': 'fl-post-test.com',
+ 'pure_domain_name': 'post-test',
+ 'host': host,
+ 'name': name,
+ 'port': port,
+ 'extra': {
+ 'is_manual_configured': False,
+ },
+ 'type': 'PLATFORM',
+ 'last_connected_at': 0,
+ 'num_project': 0,
+ }, ignore_fields=['created_at', 'updated_at'])
+ # yapf: enable
+
+ @patch('fedlearner_webconsole.participant.apis.create_or_update_participant_in_k8s')
+ def test_post_participant_manually(self, mock_create_or_update_participant_in_k8s):
+ name = 'test-post-participant'
+ domain_name = 'fl-post-test.com'
+ comment = 'test post participant'
+ host = '120.0.0.20'
+ port = 20
+
+ create_response = self.post_helper('/api/v2/participants',
+ data={
+ 'name': name,
+ 'domain_name': domain_name,
+ 'comment': comment,
+ 'is_manual_configured': True,
+ 'host': host,
+ 'port': port,
+ })
+
+ self.assertEqual(HTTPStatus.CREATED, create_response.status_code)
+ participant = self.get_response_data(create_response)
+ # yapf: disable
+ self.assertPartiallyEqual(participant, {
+ 'id': 3,
+ 'comment': comment,
+ 'domain_name': 'fl-post-test.com',
+ 'pure_domain_name': 'post-test',
+ 'host': host,
+ 'name': name,
+ 'port': port,
+ 'extra': {
+ 'is_manual_configured': True,
+ },
+ 'type': 'PLATFORM',
+ 'last_connected_at': 0,
+ 'num_project': 0,
+ }, ignore_fields=['created_at', 'updated_at'])
+ # yapf: enable
+ mock_create_or_update_participant_in_k8s.assert_called_once_with(domain_name='fl-post-test.com',
+ host='120.0.0.20',
+ namespace='default',
+ port=20)
+
+ def test_post_light_client_participant(self):
+ resp = self.post_helper('/api/v2/participants',
+ data={
+ 'name': 'light-client',
+ 'domain_name': 'fl-light-client.com',
+ 'type': 'LIGHT_CLIENT',
+ 'is_manual_configured': False,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ participant = session.query(Participant).filter_by(name='light-client').first()
+ self.assertEqual(participant.domain_name, 'fl-light-client.com')
+ self.assertEqual(participant.type, ParticipantType.LIGHT_CLIENT)
+ self.assertResponseDataEqual(resp, {
+ 'name': 'light-client',
+ 'domain_name': 'fl-light-client.com',
+ 'pure_domain_name': 'light-client',
+ 'host': '',
+ 'port': 0,
+ 'type': 'LIGHT_CLIENT',
+ 'comment': '',
+ 'extra': {
+ 'is_manual_configured': False
+ },
+ 'last_connected_at': 0,
+ 'num_project': 0,
+ },
+ ignore_fields=['id', 'created_at', 'updated_at'])
+
+ @patch('fedlearner_webconsole.participant.apis.get_host_and_port')
+ def test_post_conflict_domain_name_participant(self, mock_get_host_and_port):
+ mock_get_host_and_port.return_value = ('1.1.1.1', 1)
+ create_response = self.post_helper('/api/v2/participants',
+ data={
+ 'name': 'test-post-conflict-participant',
+ 'domain_name': 'fl-test.com',
+ 'is_manual_configured': False,
+ })
+
+ self.assertEqual(HTTPStatus.CONFLICT, create_response.status_code)
+
+ def test_list_participant(self):
+ list_response = self.get_helper('/api/v2/participants')
+ participants = self.get_response_data(list_response)
+ self.assertPartiallyEqual(
+ participants,
+ [{
+ 'comment': '',
+ 'domain_name': 'fl-test-manual.com',
+ 'pure_domain_name': 'test-manual',
+ 'host': '1.1.1.2',
+ 'id': 2,
+ 'name': 'test-manual-participant',
+ 'port': 443,
+ 'extra': {
+ 'is_manual_configured': True,
+ },
+ 'last_connected_at': 0,
+ 'num_project': 0,
+ 'type': 'PLATFORM',
+ }, {
+ 'comment': 'test comment',
+ 'domain_name': 'fl-test.com',
+ 'pure_domain_name': 'test',
+ 'host': '1.1.1.1',
+ 'id': 1,
+ 'name': 'test-particitant-name',
+ 'port': 32443,
+ 'extra': {
+ 'is_manual_configured': False,
+ },
+ 'last_connected_at': 0,
+ 'num_project': 0,
+ 'type': 'PLATFORM',
+ }],
+ ignore_fields=['created_at', 'updated_at'],
+ )
+
+ @patch('fedlearner_webconsole.participant.apis.get_host_and_port')
+ def test_update_participant(self, mock_get_host_and_port):
+ name = 'test-update-participant'
+ domain_name = 'fl-update-test.com'
+ comment = 'test update participant'
+ ip = '120.0.0.30'
+ port = 30
+ mock_get_host_and_port.return_value = (ip, port)
+
+ update_response = self.patch_helper('/api/v2/participants/1',
+ data={
+ 'name': name,
+ 'domain_name': domain_name,
+ 'comment': comment,
+ })
+ participant = self.get_response_data(update_response)
+
+ self.assertEqual(update_response.status_code, HTTPStatus.OK)
+ # yapf: disable
+ self.assertPartiallyEqual(participant, {
+ 'comment': comment,
+ 'domain_name': 'fl-update-test.com',
+ 'pure_domain_name': 'update-test',
+ 'host': ip,
+ 'id': 1,
+ 'name': name,
+ 'port': port,
+ 'extra': {
+ 'is_manual_configured': False,
+ },
+ 'last_connected_at': 0,
+ 'num_project': 0,
+ 'type': 'PLATFORM'
+ }, ignore_fields=['created_at', 'updated_at'])
+ # yapf: enable
+
+ def test_update_participant_conflict_domain_name(self):
+ update_response = self.patch_helper('/api/v2/participants/1', data={
+ 'domain_name': 'fl-test-manual.com',
+ })
+ self.assertEqual(update_response.status_code, HTTPStatus.CONFLICT)
+
+ @patch('fedlearner_webconsole.participant.apis.create_or_update_participant_in_k8s')
+ def test_update_host_and_port(self, mock_create_or_update_participant_in_k8s):
+ update_response = self.patch_helper('/api/v2/participants/2', data={
+ 'host': '1.112.212.20',
+ 'port': 9999,
+ })
+
+ self.assertEqual(update_response.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(update_response)['port'], 9999)
+
+ mock_create_or_update_participant_in_k8s.assert_called_once()
+
+ @patch('fedlearner_webconsole.participant.apis.create_or_update_participant_in_k8s')
+ def test_update_only_name(self, mock_create_or_update_participant_in_k8s):
+ update_response = self.patch_helper('/api/v2/participants/2', data={
+ 'name': 'fl-test-only-name',
+ })
+
+ self.assertEqual(update_response.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(update_response)['name'], 'fl-test-only-name')
+
+ mock_create_or_update_participant_in_k8s.assert_not_called()
+
+ def test_update_light_client(self):
+ with db.session_scope() as session:
+ party = Participant(name='test-party', domain_name='fl-light-client.com', type=ParticipantType.LIGHT_CLIENT)
+ session.add(party)
+ session.commit()
+ resp = self.patch_helper(f'/api/v2/participants/{party.id}',
+ data={
+ 'name': 'test-name',
+ 'domain_name': 'fl-1.com',
+ 'comment': 'comment'
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ party = session.query(Participant).get(party.id)
+ self.assertEqual(party.name, 'test-name')
+ self.assertEqual(party.domain_name, 'fl-1.com')
+ self.assertEqual(party.comment, 'comment')
+
+ def test_get_participant(self):
+ with db.session_scope() as session:
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ session.add(relationship)
+ session.commit()
+ get_response = self.get_helper('/api/v2/participants/1')
+ participant = self.get_response_data(get_response)
+ # yapf: disable
+ self.assertPartiallyEqual(participant, {
+ 'comment': 'test comment',
+ 'domain_name': 'fl-test.com',
+ 'pure_domain_name': 'test',
+ 'host': '1.1.1.1',
+ 'id': 1,
+ 'name': 'test-particitant-name',
+ 'port': 32443,
+ 'extra': {
+ 'is_manual_configured': False,
+ },
+ 'type': 'PLATFORM',
+ 'last_connected_at': 0,
+ 'num_project': 0,
+ }, ignore_fields=['created_at', 'updated_at'])
+ # yapf: enable
+
+ def test_delete_participant(self):
+ self.signin_as_admin()
+ with db.session_scope() as session:
+ relationship = ProjectParticipant(project_id=1, participant_id=2)
+ session.add(relationship)
+ session.commit()
+ # test delete participant which does not exist
+ delete_response = self.delete_helper('/api/v2/participants/3')
+ self.assertEqual(delete_response.status_code, HTTPStatus.NOT_FOUND)
+
+ # test delete participant which has related projects
+ delete_response = self.delete_helper('/api/v2/participants/2')
+ self.assertEqual(delete_response.status_code, HTTPStatus.METHOD_NOT_ALLOWED)
+
+ # test delete participant successfully
+ delete_response = self.delete_helper('/api/v2/participants/1')
+ self.assertEqual(delete_response.status_code, HTTPStatus.NO_CONTENT)
+
+
+class ParticipantCandidatesApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.signin_as_admin()
+
+ def test_get_valid_candidates(self):
+ get_response = self.get_helper('/api/v2/participant_candidates')
+ data = self.get_response_data(get_response)
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ self.assertEqual(data, [{'domain_name': 'fl-aaa.com'}, {'domain_name': 'fl-ccc.com'}])
+
+
+class ParticipantFlagsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ session.add(Participant(id=1, name='party', domain_name='fl-test.com'))
+ session.commit()
+
+ @patch('fedlearner_webconsole.participant.apis.SystemServiceClient')
+ def test_get_peer_flags(self, mock_client: MagicMock):
+ instance = mock_client.from_participant.return_value
+ instance.list_flags.return_value = {'key': 'value'}
+ # fail due to participant not found
+ resp = self.get_helper('/api/v2/participants/2/flags')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ resp = self.get_helper('/api/v2/participants/1/flags')
+ mock_client.from_participant.assert_called_with(domain_name='fl-test.com')
+ self.assertResponseDataEqual(resp, {'key': 'value'})
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/k8s_utils.py b/web_console_v2/api/fedlearner_webconsole/participant/k8s_utils.py
new file mode 100644
index 000000000..90dcbf4c9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/k8s_utils.py
@@ -0,0 +1,134 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import re
+from typing import Tuple, List, Dict
+
+from envs import Envs
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.k8s.k8s_client import k8s_client
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+
+_RE_INGRESS_NAME = re.compile(r'^(fl-).+(-client-auth)$')
+_RE_SERVICE_NAME = re.compile(r'^(fl-).+$')
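+# A domain is only a valid candidate when both a 'fl-xxx-client-auth' ingress and a matching
+# 'fl-xxx' service exist in the namespace.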
+
+
+def get_valid_candidates() -> Dict[str, List[str]]:
+ with db.session_scope() as session:
+ namespace = SettingService(session).get_namespace()
+ ingresses = k8s_client.list_ingress(namespace)
+ ingress_names = []
+ for ingress in ingresses.items:
+ if hasattr(ingress, 'metadata') and hasattr(ingress.metadata, 'name'):
+ name = ingress.metadata.name
+ if _RE_INGRESS_NAME.fullmatch(name):
+ ingress_names.append(name)
+
+ services = k8s_client.list_service(namespace)
+ service_names = []
+ for service in services.items:
+ if hasattr(service, 'metadata') and hasattr(service.metadata, 'name'):
+ name = service.metadata.name
+ if _RE_SERVICE_NAME.fullmatch(name):
+ service_names.append(name)
+
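+    # '-client-auth' is 12 characters, so name[:-12] recovers the backing service name
+    # that is used to build the candidate domain name.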
+ candidates = [{'domain_name': f'{name[:-12]}.com'} for name in ingress_names if name[:-12] in service_names]
+ return candidates
+
+
+def get_host_and_port(domain_name: str) -> Tuple[str, int]:
+ with db.session_scope() as session:
+ namespace = SettingService(session).get_namespace()
+ service_name = domain_name.rpartition('.')[0]
+ ingress_name = f'{service_name}-client-auth'
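+    # e.g. 'fl-test.com' maps to service 'fl-test' and ingress 'fl-test-client-auth'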
+
+ try:
+ service = k8s_client.get_service(name=service_name, namespace=namespace)
+ ingress = k8s_client.get_ingress(name=ingress_name, namespace=namespace)
+ host = service.spec.external_name
+ port = ingress.spec.rules[0].http.paths[0].backend.service_port
+ except Exception as e:
+        raise InvalidArgumentException(details=f'cannot find host or port in ingress, {e}') from e
+
+ return host, port
+
+
+def _create_or_update_participant_ingress(name: str, service_port: int, namespace: str):
+ client_auth_ingress_name = f'{name}-client-auth'
+ pure_domain_name = get_pure_domain_name(name)
+ host = f'{pure_domain_name}.fedlearner.net'
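+    # The nginx snippets below rewrite the gRPC Host/Authority headers to the participant's
+    # fedlearner.net domain and enable TLS with client certificates.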
+ configuration_snippet = f"""
+ grpc_next_upstream_tries 5;
+ grpc_set_header Host {host};
+ grpc_set_header Authority {host};"""
+ # TODO(wangsen.0914): removes this hack after we align the controller
+ is_tce = False # TODO(lixiaoguang.01) hardcode
+ secret_path = 'ingress-nginx/client' if not is_tce else 'tce_static/bdcert'
+ grpc_ssl_trusted_certificate = 'all.pem' if not is_tce else 'intermediate.pem'
+ server_snippet = f"""
+ grpc_ssl_verify on;
+ grpc_ssl_server_name on;
+ grpc_ssl_name {host};
+ grpc_ssl_trusted_certificate /etc/{secret_path}/{grpc_ssl_trusted_certificate};
+ grpc_ssl_certificate /etc/{secret_path}/client.pem;
+ grpc_ssl_certificate_key /etc/{secret_path}/client.key;"""
+ # yapf: disable
+ k8s_client.create_or_update_ingress(metadata={
+ 'name': client_auth_ingress_name,
+ 'namespace': namespace,
+ 'annotations': {
+ 'nginx.ingress.kubernetes.io/backend-protocol': 'GRPCS',
+ 'nginx.ingress.kubernetes.io/http2-insecure-port': 'true',
+ 'nginx.ingress.kubernetes.io/configuration-snippet': configuration_snippet,
+ 'nginx.ingress.kubernetes.io/server-snippet': server_snippet
+ }
+ },
+ spec={
+ 'rules': [{
+ 'host': f'{client_auth_ingress_name}.com',
+ 'http': {
+ 'paths': [{
+ 'pathType': 'ImplementationSpecific',
+ 'backend': {
+ 'serviceName': name,
+ 'servicePort': service_port
+ }
+ }]
+ }
+ }],
+ 'ingressClassName': None
+ },
+ name=client_auth_ingress_name,
+ namespace=namespace)
+ # yapf: enable
+
+
+def create_or_update_participant_in_k8s(domain_name: str, host: str, port: int, namespace: str):
+ name = domain_name.rpartition('.')[0]
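+    # The ExternalName service maps 'fl-xxx' to the participant host; the companion ingress
+    # created below adds gRPC TLS client auth.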
+ k8s_client.create_or_update_service(
+ metadata={
+ 'name': name,
+ 'namespace': namespace,
+ },
+ spec={
+ 'externalName': host,
+ 'type': 'ExternalName',
+ },
+ name=name,
+ namespace=namespace,
+ )
+ _create_or_update_participant_ingress(name=name, service_port=port, namespace=namespace)
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/k8s_utils_test.py b/web_console_v2/api/fedlearner_webconsole/participant/k8s_utils_test.py
new file mode 100644
index 000000000..89b03b47c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/k8s_utils_test.py
@@ -0,0 +1,130 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+
+from fedlearner_webconsole.participant.k8s_utils import get_host_and_port, _create_or_update_participant_ingress, \
+ create_or_update_participant_in_k8s
+from testing.helpers import to_simple_namespace
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ParticipantsK8sUtilsTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.participant.k8s_utils.k8s_client')
+ @patch('fedlearner_webconsole.participant.k8s_utils.SettingService.get_namespace')
+ def test_get_host_and_port(self, mock_get_namespace, mock_k8s_client):
+ fake_service = to_simple_namespace({'spec': {'external_name': '127.0.0.10'}})
+ fake_ingress = to_simple_namespace({
+ 'spec': {
+ 'rules': [{
+ 'http': {
+ 'paths': [{
+ 'backend': {
+ 'service_name': 'fakeservice',
+ 'service_port': 32443
+ }
+ }]
+ }
+ }]
+ }
+ })
+ domain_name = 'test_domain_name.com'
+ mock_get_namespace.return_value = 'default'
+ mock_k8s_client.get_service = MagicMock(return_value=fake_service)
+ mock_k8s_client.get_ingress = MagicMock(return_value=fake_ingress)
+
+ host, port = get_host_and_port(domain_name)
+ self.assertEqual(host, '127.0.0.10')
+ self.assertEqual(port, 32443)
+ mock_k8s_client.get_service.assert_called_once_with(name='test_domain_name', namespace='default')
+ mock_k8s_client.get_ingress.assert_called_once_with(name='test_domain_name-client-auth', namespace='default')
+
+ @patch('fedlearner_webconsole.participant.k8s_utils.k8s_client')
+ def test_create_or_update_participant_ingress(self, mock_k8s_client: MagicMock):
+ mock_k8s_client.create_or_update_ingress = MagicMock()
+ _create_or_update_participant_ingress('fl-test', service_port=32443, namespace='fedlearner')
+ mock_k8s_client.create_or_update_ingress.assert_called_once_with(
+ name='fl-test-client-auth',
+ namespace='fedlearner',
+ metadata={
+ 'name': 'fl-test-client-auth',
+ 'namespace': 'fedlearner',
+ 'annotations': {
+ 'nginx.ingress.kubernetes.io/backend-protocol': 'GRPCS',
+ 'nginx.ingress.kubernetes.io/http2-insecure-port': 'true',
+ 'nginx.ingress.kubernetes.io/configuration-snippet':
+ '\n'
+ ' grpc_next_upstream_tries 5;\n'
+ ' grpc_set_header Host test.fedlearner.net;\n'
+ ' grpc_set_header Authority test.fedlearner.net;',
+ 'nginx.ingress.kubernetes.io/server-snippet':
+ '\n'
+ ' grpc_ssl_verify on;\n'
+ ' grpc_ssl_server_name on;\n'
+ ' grpc_ssl_name test.fedlearner.net;\n'
+ ' grpc_ssl_trusted_certificate /etc/ingress-nginx/client/all.pem;\n'
+ ' grpc_ssl_certificate /etc/ingress-nginx/client/client.pem;\n'
+ ' grpc_ssl_certificate_key /etc/ingress-nginx/client/client.key;'
+ }
+ },
+ spec={
+ 'rules': [{
+ 'host': 'fl-test-client-auth.com',
+ 'http': {
+ 'paths': [{
+ 'pathType': 'ImplementationSpecific',
+ 'backend': {
+ 'serviceName': 'fl-test',
+ 'servicePort': 32443
+ }
+ }]
+ }
+ }],
+ 'ingressClassName': None
+ },
+ )
+
+ @patch('fedlearner_webconsole.participant.k8s_utils._create_or_update_participant_ingress')
+ @patch('fedlearner_webconsole.participant.k8s_utils.k8s_client')
+ def test_create_or_update_participant_in_k8s(self, mock_k8s_client: MagicMock,
+ mock_create_or_update_participant_ingress: MagicMock):
+ mock_k8s_client.create_or_update_service = MagicMock()
+ create_or_update_participant_in_k8s(domain_name='fl-a-test.com',
+ host='1.2.3.4',
+ port=32443,
+ namespace='fedlearner')
+ mock_k8s_client.create_or_update_service.assert_called_once_with(
+ name='fl-a-test',
+ namespace='fedlearner',
+ metadata={
+ 'name': 'fl-a-test',
+ 'namespace': 'fedlearner',
+ },
+ spec={
+ 'externalName': '1.2.3.4',
+ 'type': 'ExternalName',
+ },
+ )
+ mock_create_or_update_participant_ingress.assert_called_once_with(
+ name='fl-a-test',
+ service_port=32443,
+ namespace='fedlearner',
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/models.py b/web_console_v2/api/fedlearner_webconsole/participant/models.py
new file mode 100644
index 000000000..1f719bd31
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/models.py
@@ -0,0 +1,96 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+import json
+from enum import Enum
+from typing import Dict
+from sqlalchemy import UniqueConstraint, Index
+from sqlalchemy.sql import func
+
+from fedlearner_webconsole.proto import participant_pb2
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+from fedlearner_webconsole.utils.base_model.review_ticket_model import ReviewTicketModel
+
+
+class ParticipantType(Enum):
+ PLATFORM = 0
+ LIGHT_CLIENT = 1
+
+
+class Participant(db.Model, ReviewTicketModel):
+ __tablename__ = 'participants_v2'
+ __table_args__ = (UniqueConstraint('domain_name', name='uniq_domain_name'),
+ default_table_args('This is webconsole participant table.'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='participant id')
+ name = db.Column(db.String(255), nullable=False, comment='participant name')
+ domain_name = db.Column(db.String(255), unique=True, nullable=False, comment='participant domain_name')
+ host = db.Column(db.String(255), comment='participant host')
+ port = db.Column(db.Integer, comment='host port')
+ type = db.Column('participant_type',
+ db.Enum(ParticipantType, native_enum=False, length=64, create_constraint=False),
+ default=ParticipantType.PLATFORM,
+ key='type',
+ comment='participant type')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
+ extra = db.Column(db.Text(), comment='extra_info')
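+    # Serialized JSON dict, read and written through get_extra_info()/set_extra_info() below.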
+ last_connected_at = db.Column(db.DateTime(timezone=True), comment='last connected at')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created at')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ onupdate=func.now(),
+ server_default=func.now(),
+ comment='updated at')
+
+ def set_extra_info(self, extra_info: Dict):
+ self.extra = json.dumps(extra_info)
+
+ def get_extra_info(self) -> Dict:
+ if self.extra is not None:
+ return json.loads(self.extra)
+ return {}
+
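+    # Legacy rows may have a NULL participant_type; they are treated as PLATFORM.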
+ def get_type(self) -> ParticipantType:
+ return self.type if self.type else ParticipantType.PLATFORM
+
+ def pure_domain_name(self):
+ return get_pure_domain_name(self.domain_name)
+
+ def to_proto(self) -> participant_pb2.Participant:
+ extra_info = self.get_extra_info()
+ proto = participant_pb2.Participant(
+ id=self.id,
+ name=self.name,
+ domain_name=self.domain_name,
+ pure_domain_name=self.pure_domain_name(),
+ host=self.host,
+ port=self.port,
+ type=self.get_type().name,
+ comment=self.comment,
+ last_connected_at=to_timestamp(self.last_connected_at) if self.last_connected_at else 0,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at) if self.updated_at else 0,
+ extra=participant_pb2.ParticipantExtra(is_manual_configured=extra_info.get('is_manual_configured', False)))
+ return proto
+
+
+class ProjectParticipant(db.Model):
+ __tablename__ = 'projects_participants_v2'
+ __table_args__ = (Index('idx_project_id', 'project_id'), Index('idx_participant_id', 'participant_id'),
+                      default_table_args('This is webconsole projects and participants relationship table.'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='relationship id')
+    project_id = db.Column(db.Integer, nullable=False, comment='project id')
+    participant_id = db.Column(db.Integer, nullable=False, comment='participant id')
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/models_test.py b/web_console_v2/api/fedlearner_webconsole/participant/models_test.py
new file mode 100644
index 000000000..ac648c270
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/models_test.py
@@ -0,0 +1,54 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+
+from fedlearner_webconsole.participant.models import Participant, ParticipantType
+from fedlearner_webconsole.proto import participant_pb2
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ParticipantTest(NoWebServerTestCase):
+
+ def test_to_proto(self):
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ participant = Participant(id=123,
+ name='testp',
+ domain_name='fl-test.com',
+ host='test.fl.com',
+ port=32443,
+ type=ParticipantType.PLATFORM,
+ comment='c',
+ created_at=created_at,
+ updated_at=created_at,
+ extra='{"is_manual_configured":true}')
+ self.assertEqual(
+ participant.to_proto(),
+ participant_pb2.Participant(id=123,
+ name='testp',
+ domain_name='fl-test.com',
+ pure_domain_name='test',
+ host='test.fl.com',
+ port=32443,
+ type='PLATFORM',
+ comment='c',
+ created_at=int(created_at.timestamp()),
+ updated_at=int(created_at.timestamp()),
+ extra=participant_pb2.ParticipantExtra(is_manual_configured=True)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/services.py b/web_console_v2/api/fedlearner_webconsole/participant/services.py
new file mode 100644
index 000000000..c98738dbc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/services.py
@@ -0,0 +1,85 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import List, Optional
+from fedlearner_webconsole.auth.models import Session
+from fedlearner_webconsole.participant.models import ProjectParticipant, Participant, ParticipantType
+
+
+class ParticipantService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def get_participant_by_pure_domain_name(self, pure_domain_name: str) -> Optional[Participant]:
+ """Finds a specific participant by pure domain name, e.g. aliyun-test.
+
+        For compatibility reasons, we have two kinds of domain names, fl-xxx.com and xxx.fedlearner.net,
+ so we need to identify a participant by pure domain name globally."""
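+        # LIKE is only a coarse prefilter; the exact match below compares parsed pure domain names.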
+ participants = self._session.query(Participant).filter(
+ Participant.domain_name.like(f'%{pure_domain_name}%')).all()
+ for p in participants:
+ if p.pure_domain_name() == pure_domain_name:
+ return p
+ return None
+
+ def get_participants_by_project(self, project_id: int) -> List:
+ # the precision of datetime cannot suffice, use id instead
+ participants = self._session.query(Participant).join(
+ ProjectParticipant, ProjectParticipant.participant_id == Participant.id).filter(
+ ProjectParticipant.project_id == project_id). \
+ order_by(Participant.created_at.desc()).all()
+ return participants
+
+    # Gets only platform participants, ignoring light-client participants.
+ def get_platform_participants_by_project(self, project_id: int) -> List:
+ participants = self.get_participants_by_project(project_id)
+ platform_participants = []
+ for participant in participants:
+            # a hack for legacy rows whose participant_type is null
+ if participant.get_type() == ParticipantType.PLATFORM:
+ platform_participants.append(participant)
+ return platform_participants
+
+ def get_number_of_projects(self, participant_id: int) -> int:
+ return self._session.query(ProjectParticipant).filter_by(participant_id=participant_id).count()
+
+ def create_light_client_participant(self,
+ name: str,
+ domain_name: str,
+ comment: Optional[str] = None) -> Participant:
+ participant = Participant(name=name,
+ domain_name=domain_name,
+ type=ParticipantType.LIGHT_CLIENT,
+ comment=comment)
+ self._session.add(participant)
+ return participant
+
+ def create_platform_participant(self,
+ name: str,
+ domain_name: str,
+ host: str,
+ port: int,
+ extra: dict,
+ comment: Optional[str] = None) -> Participant:
+ participant = Participant(name=name,
+ domain_name=domain_name,
+ host=host,
+ port=port,
+ type=ParticipantType.PLATFORM,
+ comment=comment)
+ participant.set_extra_info(extra)
+ self._session.add(participant)
+ return participant
diff --git a/web_console_v2/api/fedlearner_webconsole/participant/services_test.py b/web_console_v2/api/fedlearner_webconsole/participant/services_test.py
new file mode 100644
index 000000000..4870857e7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/participant/services_test.py
@@ -0,0 +1,109 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.models import Participant, ParticipantType, ProjectParticipant
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.pp_time import sleep
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ParticipantServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.participant_1 = Participant(name='participant 1', domain_name='fl-participant-1.com')
+ self.participant_2 = Participant(name='participant 2', domain_name='participant-2.fedlearner.net')
+ self.participant_3 = Participant(name='participant 3',
+ domain_name='participant-3.fedlearner.net',
+ type=ParticipantType.LIGHT_CLIENT)
+ self.project_1 = Project(name='project 1')
+ self.project_2 = Project(name='project 2')
+ self.project_3 = Project(name='project 3')
+ self.relationship_11 = ProjectParticipant(project_id=1, participant_id=1)
+ self.relationship_12 = ProjectParticipant(project_id=1, participant_id=2)
+ self.relationship_22 = ProjectParticipant(project_id=2, participant_id=2)
+ self.relationship_33 = ProjectParticipant(project_id=3, participant_id=3)
+
+ with db.session_scope() as session:
+ session.add(self.participant_1)
+ session.flush()
+ sleep(1)
+ session.add(self.participant_2)
+ session.add(self.project_1)
+ session.flush()
+ sleep(1)
+ session.add(self.project_2)
+ session.add(self.project_3)
+ session.add(self.participant_3)
+ session.add(self.relationship_11)
+ session.add(self.relationship_12)
+ session.add(self.relationship_22)
+ session.add(self.relationship_33)
+ session.commit()
+
+ def test_get_participant_by_pure_domain_name(self):
+ with db.session_scope() as session:
+ service = ParticipantService(session)
+
+ p = service.get_participant_by_pure_domain_name('participant-1')
+ self.assertEqual(p.id, self.participant_1.id)
+ p = service.get_participant_by_pure_domain_name('participant-2')
+ self.assertEqual(p.id, self.participant_2.id)
+ self.assertIsNone(service.get_participant_by_pure_domain_name('participant'))
+ self.assertIsNone(service.get_participant_by_pure_domain_name('none'))
+
+ def test_get_participants_by_project_id(self):
+ with db.session_scope() as session:
+ service = ParticipantService(session)
+ participants = service.get_participants_by_project(1)
+ self.assertEqual(len(participants), 2)
+ self.assertEqual(participants[0].name, 'participant 2')
+ self.assertEqual(participants[1].name, 'participant 1')
+
+ participants = service.get_participants_by_project(2)
+ self.assertEqual(len(participants), 1)
+ self.assertEqual(participants[0].name, 'participant 2')
+
+ participants = service.get_participants_by_project(3)
+ self.assertEqual(len(participants), 1)
+ self.assertEqual(participants[0].name, 'participant 3')
+
+ def test_get_platform_participants_by_project(self):
+
+ with db.session_scope() as session:
+ service = ParticipantService(session)
+ participants = service.get_platform_participants_by_project(2)
+ self.assertEqual(len(participants), 1)
+ self.assertEqual(participants[0].name, 'participant 2')
+
+ participants = service.get_platform_participants_by_project(3)
+ self.assertEqual(len(participants), 0)
+
+ def test_get_number_of_projects(self):
+ with db.session_scope() as session:
+ service = ParticipantService(session)
+ num1 = service.get_number_of_projects(1)
+ self.assertEqual(num1, 1)
+
+ num2 = service.get_number_of_projects(2)
+ self.assertEqual(num2, 2)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/project/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/project/BUILD.bazel
new file mode 100644
index 000000000..f16d5a8ce
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/BUILD.bazel
@@ -0,0 +1,188 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:iam_required_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:permission_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "controllers_lib",
+ srcs = [
+ "controllers.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:project_service_client_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "controllers_lib_test",
+ size = "small",
+ srcs = [
+ "controllers_test.py",
+ ],
+ imports = ["../.."],
+ main = "controllers_test.py",
+ deps = [
+ ":controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = ["services.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "services_lib_test",
+ size = "small",
+ srcs = [
+ "services_test.py",
+ ],
+ imports = ["../.."],
+ main = "services_test.py",
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "project_scheduler_lib",
+ srcs = [
+ "project_scheduler.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:project_service_client_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "project_scheduler_lib_test",
+ size = "small",
+ srcs = [
+ "project_scheduler_test.py",
+ ],
+ imports = ["../.."],
+ main = "project_scheduler_test.py",
+ deps = [
+ ":project_scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/project/add_on.py b/web_console_v2/api/fedlearner_webconsole/project/add_on.py
deleted file mode 100644
index c20453528..000000000
--- a/web_console_v2/api/fedlearner_webconsole/project/add_on.py
+++ /dev/null
@@ -1,342 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-
-import tarfile
-import io
-import os
-from base64 import b64encode, b64decode
-from typing import Type, Dict
-from OpenSSL import crypto, SSL
-from fedlearner_webconsole.utils.k8s_client import K8sClient
-
-CA_SECRET_NAME = 'ca-secret'
-OPERATOR_NAME = 'fedlearner-operator'
-SERVER_SECRET_NAME = 'fedlearner-proxy-server'
-INGRESS_NGINX_CONTROLLER_NAME = 'fedlearner-stack-ingress-nginx-controller'
-
-
-def parse_certificates(encoded_gz):
- """
- Parse certificates from base64-encoded string to a dict
- Args:
- encoded_gz: A base64-encoded string from a `.gz` file.
- Returns:
- dict: key is the file name, value is the content
- """
- binary_gz = io.BytesIO(b64decode(encoded_gz))
- with tarfile.open(fileobj=binary_gz) as gz:
- certificates = {}
- for file in gz.getmembers():
- if file.isfile():
- # raw file name is like `fl-test.com/client/client.pem`
- certificates[file.name.split('/', 1)[-1]] = \
- str(b64encode(gz.extractfile(file).read()),
- encoding='utf-8')
- return certificates
-
-
-def verify_certificates(certificates: Dict[str, str]) -> (bool, str):
- """
- Verify certificates from 4 aspects:
- 1. The CN of all public keys are equal.
- 2. All the CN are generic domain names.
- 3. Public key match private key.
- 4. Private key is signed by CA.
- Args:
- certificates:
- Returns:
- """
- try:
- client_public_key = crypto.load_certificate(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('client/client.pem')))
- server_public_key = crypto.load_certificate(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('server/server.pem')))
- client_private_key = crypto.load_privatekey(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('client/client.key')))
- server_private_key = crypto.load_privatekey(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('server/server.key')))
- client_intermediate_ca = crypto.load_certificate(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('client/intermediate.pem')))
- server_intermediate_ca = crypto.load_certificate(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('server/intermediate.pem')))
- client_root_ca = crypto.load_certificate(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('client/root.pem')))
- server_root_ca = crypto.load_certificate(
- crypto.FILETYPE_PEM,
- b64decode(certificates.get('server/root.pem')))
- except crypto.Error as err:
- return False, 'Format of key or CA is invalid: {}'.format(err)
-
- if client_public_key.get_subject().CN != server_public_key.get_subject().CN:
- return False, 'Client and server public key CN mismatch'
- if not client_public_key.get_subject().CN.startswith('*.'):
- return False, 'CN of public key should be a generic domain name'
-
- try:
- client_context = SSL.Context(SSL.TLSv1_METHOD)
- client_context.use_certificate(client_public_key)
- client_context.use_privatekey(client_private_key)
- client_context.check_privatekey()
-
- server_context = SSL.Context(SSL.TLSv1_METHOD)
- server_context.use_certificate(server_public_key)
- server_context.use_privatekey(server_private_key)
- server_context.check_privatekey()
- except SSL.Error as err:
- return False, 'Key pair mismatch: {}'.format(err)
-
- try:
- client_store = crypto.X509Store()
- client_store.add_cert(client_root_ca)
- client_store.add_cert(client_intermediate_ca)
- crypto.X509StoreContext(client_store, client_public_key)\
- .verify_certificate()
- except crypto.X509StoreContextError as err:
- return False, 'Client key and CA mismatch: {}'.format(err)
- try:
- server_store = crypto.X509Store()
- server_store.add_cert(server_root_ca)
- server_store.add_cert(server_intermediate_ca)
- crypto.X509StoreContext(server_store, server_public_key)\
- .verify_certificate()
- except crypto.X509StoreContextError as err:
- return False, 'Server key and CA mismatch: {}'.format(err)
-
- return True, ''
-
-
-def create_add_on(client: Type[K8sClient], domain_name: str, url: str,
- certificates: Dict[str, str], custom_host: str = None):
- """
- Idempotent
- Create add on and upgrade nginx-ingress and operator.
- If add on of domain_name exists, replace it.
-
- Args:
- client: K8s client instance
- domain_name: participant's domain name, used to create Ingress
- url: participant's external ip, used to create ExternalName
- Service
- certificates: used for two-way tls authentication and to create one
- server Secret, one client Secret and one CA
- custom_host: used for case where participant is using an external
- authentication gateway
- """
- # url: xxx.xxx.xxx.xxx:xxxxx
- ip = url.split(':')[0]
- port = int(url.split(':')[1])
- client_all_pem = str(b64encode('{}\n{}'.format(
- str(b64decode(certificates.get('client/intermediate.pem')),
- encoding='utf-8').strip(),
- str(b64decode(certificates.get('client/root.pem')),
- encoding='utf-8').strip()).encode()), encoding='utf-8')
- server_all_pem = str(b64encode('{}\n{}'.format(
- str(b64decode(certificates.get('server/intermediate.pem')),
- encoding='utf-8').strip(),
- str(b64decode(certificates.get('server/root.pem')),
- encoding='utf-8').strip()).encode()), encoding='utf-8')
- name = domain_name.split('.')[0]
- client_secret_name = '{}-client'.format(name)
- client_auth_ingress_name = '-client-auth.'.join(domain_name.split('.'))
-
- # Create server certificate secret
- # If users verify gRpc in external gateway,
- # `AUTHORIZATION_MODE` should be set to `EXTERNAL`.
- if os.environ.get('AUTHORIZATION_MODE') != 'EXTERNAL':
- client.create_or_update_secret(
- data={
- 'ca.crt': certificates.get('server/intermediate.pem'),
- 'tls.crt': certificates.get('server/server.pem'),
- 'tls.key': certificates.get('server/server.key')
- },
- metadata={
- 'name': SERVER_SECRET_NAME,
- 'namespace': 'default'
- },
- secret_type='Opaque',
- name=SERVER_SECRET_NAME
- )
- client.create_or_update_secret(
- data={
- 'ca.crt': server_all_pem
- },
- metadata={
- 'name': CA_SECRET_NAME,
- 'namespace': 'default'
- },
- secret_type='Opaque',
- name=CA_SECRET_NAME
- )
- # TODO: Support multiple participants
- operator = client.get_deployment(OPERATOR_NAME)
- new_args = list(filter(lambda arg: not arg.startswith('--ingress'),
- operator.spec.template.spec.containers[0].args))
- new_args.extend([
- '--ingress-extra-host-suffix=".{}"'.format(domain_name),
- '--ingress-client-auth-secret-name="default/ca-secret"',
- '--ingress-enabled-client-auth=true',
- '--ingress-secret-name={}'.format(SERVER_SECRET_NAME)])
- operator.spec.template.spec.containers[0].args = new_args
- client.create_or_update_deployment(metadata=operator.metadata,
- spec=operator.spec,
- name=OPERATOR_NAME)
-
- # Create client certificate secret
- client.create_or_update_secret(
- data={
- 'client.pem': certificates.get('client/intermediate.pem'),
- 'client.key': certificates.get('client/client.key'),
- 'all.pem': client_all_pem
- },
- metadata={
- 'name': client_secret_name
- },
- secret_type='Opaque',
- name=client_secret_name
- )
-
- # Update ingress-nginx-controller to load client secret
- ingress_nginx_controller = client.get_deployment(
- INGRESS_NGINX_CONTROLLER_NAME
- )
- volumes = ingress_nginx_controller.spec.template.spec.volumes or []
- volumes = list(filter(lambda volume: volume.name != client_secret_name,
- volumes))
- volumes.append({
- 'name': client_secret_name,
- 'secret': {
- 'secretName': client_secret_name
- }
- })
- volume_mounts = ingress_nginx_controller.spec.template\
- .spec.containers[0].volume_mounts or []
- volume_mounts = list(filter(
- lambda mount: mount.name != client_secret_name, volume_mounts))
- volume_mounts.append(
- {
- 'mountPath': '/etc/{}/client/'.format(name),
- 'name': client_secret_name
- })
- ingress_nginx_controller.spec.template.spec.volumes = volumes
- ingress_nginx_controller.spec.template\
- .spec.containers[0].volume_mounts = volume_mounts
- client.create_or_update_deployment(
- metadata=ingress_nginx_controller.metadata,
- spec=ingress_nginx_controller.spec,
- name=INGRESS_NGINX_CONTROLLER_NAME
- )
- # TODO: check ingress-nginx-controller's health
-
- # Create ingress to forward request to peer
- client.create_or_update_service(
- metadata={
- 'name': name,
- 'namespace': 'default'
- },
- spec={
- 'externalName': ip,
- 'type': 'ExternalName'
- },
- name=name
- )
- configuration_snippet_template = 'grpc_next_upstream_tries 5;\n'\
- 'grpc_set_header Host {0};\n'\
- 'grpc_set_header Authority {0};'
- configuration_snippet = \
- configuration_snippet_template.format(custom_host or '$http_x_host')
- client.create_or_update_ingress(
- metadata={
- 'name': domain_name,
- 'namespace': 'default',
- 'annotations': {
- 'kubernetes.io/ingress.class': 'nginx',
- 'nginx.ingress.kubernetes.io/backend-protocol': 'GRPCS',
- 'nginx.ingress.kubernetes.io/http2-insecure-port': 't',
- 'nginx.ingress.kubernetes.io/configuration-snippet':
- configuration_snippet
- }
- },
- spec={
- 'rules': [{
- 'host': domain_name,
- 'http': {
- 'paths': [
- {
- 'path': '/',
- 'backend': {
- 'serviceName': name,
- 'servicePort': port
- }
- }
- ]
- }
- }]
- },
- name=domain_name
- )
- # In most case with external authorization mode,
- # secrets are created by helm charts (deploy/charts/fedlearner-add-on).
- # So use `ingress-nginx` as default.
- # FIXME: change when supporting multi-peer
- secret_path = name if os.environ.get('AUTHORIZATION_MODE') != 'EXTERNAL' \
- else 'ingress-nginx'
- server_snippet_template = \
- 'grpc_ssl_verify on;\n'\
- 'grpc_ssl_server_name on;\n'\
- 'grpc_ssl_name {0};\n'\
- 'grpc_ssl_trusted_certificate /etc/{1}/client/all.pem;\n'\
- 'grpc_ssl_certificate /etc/{1}/client/client.pem;\n'\
- 'grpc_ssl_certificate_key /etc/{1}/client/client.key;'
- server_snippet = server_snippet_template.format(
- custom_host or '$http_x_host', secret_path)
- client.create_or_update_ingress(
- metadata={
- 'name': client_auth_ingress_name,
- 'namespace': 'default',
- 'annotations': {
- 'kubernetes.io/ingress.class': 'nginx',
- 'nginx.ingress.kubernetes.io/backend-protocol': 'GRPCS',
- 'nginx.ingress.kubernetes.io/http2-insecure-port': 't',
- 'nginx.ingress.kubernetes.io/configuration-snippet':
- configuration_snippet,
- 'nginx.ingress.kubernetes.io/server-snippet': server_snippet
- }
- },
- spec={
- 'rules': [{
- 'host': client_auth_ingress_name,
- 'http': {
- 'paths': [
- {
- 'path': '/',
- 'backend': {
- 'serviceName': name,
- 'servicePort': port
- }
- }
- ]
- }
- }]
- },
- name=client_auth_ingress_name
- )
diff --git a/web_console_v2/api/fedlearner_webconsole/project/apis.py b/web_console_v2/api/fedlearner_webconsole/project/apis.py
index c13a6a3ea..97e00c83b 100644
--- a/web_console_v2/api/fedlearner_webconsole/project/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/project/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
@@ -13,282 +13,498 @@
# limitations under the License.
# coding: utf-8
-# pylint: disable=raise-missing-from
-
-import re
from enum import Enum
-from uuid import uuid4
+from functools import partial
+from http import HTTPStatus
+from typing import Optional, Dict, Any, List
-from sqlalchemy.sql import func
-from flask import request
-from flask_restful import Resource, Api, reqparse
from google.protobuf.json_format import ParseDict
+from flask_restful import Resource, Api
+from marshmallow import Schema, fields, validate, post_load
+from marshmallow.validate import Length
+from envs import Envs
+from fedlearner_webconsole.audit.decorators import emits_event
from fedlearner_webconsole.db import db
-from fedlearner_webconsole.project.models import Project
-from fedlearner_webconsole.proto.common_pb2 import Variable, StatusCode
-from fedlearner_webconsole.proto.project_pb2 \
- import Project as ProjectProto, CertificateStorage, \
- Participant as ParticipantProto
-from fedlearner_webconsole.project.add_on \
- import parse_certificates, verify_certificates, create_add_on
+from fedlearner_webconsole.iam.client import create_iams_for_resource
+from fedlearner_webconsole.iam.iam_required import iam_required
+from fedlearner_webconsole.iam.permission import Permission
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.participant.models import ProjectParticipant
+from fedlearner_webconsole.project.controllers import PendingProjectRpcController
+from fedlearner_webconsole.project.models import Project, PendingProjectState, ProjectRole, PendingProject
+from fedlearner_webconsole.project.services import ProjectService, PendingProjectService
+from fedlearner_webconsole.proto.common_pb2 import StatusCode, Variable
from fedlearner_webconsole.exceptions \
- import InvalidArgumentException, NotFoundException
+ import InvalidArgumentException, NotFoundException, ResourceConflictException, InternalException
+from fedlearner_webconsole.proto.project_pb2 import ProjectConfig
+from fedlearner_webconsole.proto.review_pb2 import TicketType, TicketDetails
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
from fedlearner_webconsole.rpc.client import RpcClient
-from fedlearner_webconsole.utils.decorators import jwt_required
-from fedlearner_webconsole.utils.k8s_client import k8s_client
-from fedlearner_webconsole.workflow.models import Workflow
-
-_CERTIFICATE_FILE_NAMES = [
- 'client/client.pem', 'client/client.key', 'client/intermediate.pem',
- 'client/root.pem', 'server/server.pem', 'server/server.key',
- 'server/intermediate.pem', 'server/root.pem'
-]
-
-_URL_REGEX = r'(?:^((?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])(?:\.' \
- r'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])){3})(?::+' \
- r'(\d+))?$)|(?:^\[((?:(?:[0-9a-fA-F:]){1,4}(?:(?::(?:[0-9a-fA-F]' \
- r'){1,4}|:)){2,7})+)\](?::+(\d+))?|((?:(?:[0-9a-fA-F:]){1,4}(?:(' \
- r'?::(?:[0-9a-fA-F]){1,4}|:)){2,7})+)$)'
+from fedlearner_webconsole.rpc.v2.project_service_client import ProjectServiceClient
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.utils.decorators.pp_flask import input_validator, use_args, use_kwargs
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.flask_utils import get_current_user, make_flask_response, FilterExpField
class ErrorMessage(Enum):
PARAM_FORMAT_ERROR = 'Format of parameter {} is wrong: {}'
- NAME_CONFLICT = 'Project name {} has been used.'
+
+
+def _add_variable(config: Optional[Dict], field: str, value: Any) -> Dict:
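+ """Appends {'name': field, 'value': value} to config['variables'] unless a variable with that name already exists; an existing value is never overridden."""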
+ config = config or {}
+ config['variables'] = config.get('variables', [])
+ for item in config['variables']:
+ if item['name'] == field:
+ return config
+ config['variables'].append({'name': field, 'value': value})
+ return config
+
+
+class CreateProjectParameter(Schema):
+ name = fields.String(required=True)
+ config = fields.Dict(load_default={})
+ # The system does not support multiple participants yet
+ participant_ids = fields.List(fields.Integer(), validate=Length(equal=1))
+ comment = fields.String(load_default='')
+
+
+class CreatePendingProjectParameter(Schema):
+ name = fields.String(required=True)
+ config = fields.Dict(load_default={})
+ participant_ids = fields.List(fields.Integer(), validate=Length(min=1))
+ comment = fields.String(load_default='')
+
+ @post_load()
+ def make(self, data, **kwargs):
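+ # Convert the validated config dict into a ProjectConfig proto so handlers receive a typed message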
+ data['config'] = ParseDict(data['config'], ProjectConfig(), ignore_unknown_fields=True)
+ return data
class ProjectsApi(Resource):
- @jwt_required()
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('name',
- required=True,
- type=str,
- help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
- 'name', 'Empty'))
- parser.add_argument('config',
- required=True,
- type=dict,
- help=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
- 'config', 'Empty'))
- parser.add_argument('comment')
- data = parser.parse_args()
+
+ @input_validator
+ @credentials_required
+ @iam_required(Permission.PROJECTS_POST)
+ @emits_event(audit_fields=['participant_ids'])
+ @use_args(CreateProjectParameter())
+ def post(self, data: Dict):
+ """Creates a new project.
+ ---
+ tags:
+ - project
+ description: Creates a new project
+ parameters:
+ - in: body
+ name: body
+ schema:
+ $ref: '#/definitions/CreateProjectParameter'
+ responses:
+ 201:
+ description: Created a project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Project'
+ """
name = data['name']
config = data['config']
comment = data['comment']
-
- if Project.query.filter_by(name=name).first() is not None:
- raise InvalidArgumentException(
- details=ErrorMessage.NAME_CONFLICT.value.format(name))
-
- if config.get('participants') is None:
- raise InvalidArgumentException(
- details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
- 'participants', 'Empty'))
- if len(config.get('participants')) != 1:
- # TODO: remove limit after operator supports multiple participants
- raise InvalidArgumentException(
- details='Currently not support multiple participants.')
-
- # exact configuration from variables
- # TODO: one custom host for one participant
- grpc_ssl_server_host = None
- egress_host = None
- for variable in config.get('variables', []):
- if variable.get('name') == 'GRPC_SSL_SERVER_HOST':
- grpc_ssl_server_host = variable.get('value')
- if variable.get('name') == 'EGRESS_HOST':
- egress_host = variable.get('value')
-
- # parse participant
- certificates = {}
- for participant in config.get('participants'):
- if 'name' not in participant.keys() or \
- 'domain_name' not in participant.keys():
- raise InvalidArgumentException(
- details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
- 'participants', 'Participant must have name and '
- 'domain_name.'))
- domain_name = participant.get('domain_name')
- # Grpc spec
- participant['grpc_spec'] = {
- 'authority':
- egress_host or '{}-client-auth.com'.format(domain_name[:-4])
- }
-
- if participant.get('certificates'):
- # If users use web console to create add-on,
- # peer url must be given
- if 'url' not in participant.keys():
+ participant_ids = data['participant_ids']
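+ # Reject duplicate project names up front, before creating any rows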
+ with db.session_scope() as session:
+ if session.query(Project).filter_by(name=name).first() is not None:
+ raise ResourceConflictException(message=f'Project name {name} has been used.')
+
+ with db.session_scope() as session:
+ try:
+ user = get_current_user()
+ # Defensive programming: if user is None, don't access user.username
+ new_project = Project(name=name, comment=comment, creator=user and user.username)
+ config = _add_variable(config, 'storage_root_path', Envs.STORAGE_ROOT)
+ try:
+ new_project.set_config(ParseDict(config, ProjectConfig()))
+ except Exception as e:
raise InvalidArgumentException(
- details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
- 'participants', 'Participant must have url.'))
- if re.match(_URL_REGEX, participant.get('url')) is None:
- raise InvalidArgumentException('URL pattern is wrong')
-
- current_cert = parse_certificates(
- participant.get('certificates'))
- success, err = verify_certificates(current_cert)
- if not success:
- raise InvalidArgumentException(err)
- certificates[domain_name] = {'certs': current_cert}
- if 'certificates' in participant.keys():
- participant.pop('certificates')
-
- new_project = Project()
- # generate token
- # If users send a token, then use it instead.
- # If `token` is None, generate a new one by uuid.
- config['name'] = name
- token = config.get('token', uuid4().hex)
- config['token'] = token
-
- # check format of config
- try:
- new_project.set_config(ParseDict(config, ProjectProto()))
- except Exception as e:
- raise InvalidArgumentException(
- details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
- 'config', e))
- new_project.set_certificate(
- ParseDict({'domain_name_to_cert': certificates},
- CertificateStorage()))
- new_project.name = name
- new_project.token = token
- new_project.comment = comment
-
- # create add on
- for participant in new_project.get_config().participants:
- if participant.domain_name in\
- new_project.get_certificate().domain_name_to_cert.keys():
- _create_add_on(
- participant,
- new_project.get_certificate().domain_name_to_cert[
- participant.domain_name], grpc_ssl_server_host)
- try:
- new_project = db.session.merge(new_project)
- db.session.commit()
- except Exception as e:
- raise InvalidArgumentException(details=str(e))
-
- return {'data': new_project.to_dict()}
-
- @jwt_required()
+ details=ErrorMessage.PARAM_FORMAT_ERROR.value.format('config', e)) from e
+ session.add(new_project)
+ session.flush()
+
+ for participant_id in participant_ids:
+ # insert a relationship into the table
+ new_relationship = ProjectParticipant(project_id=new_project.id, participant_id=participant_id)
+ session.add(new_relationship)
+
+ create_iams_for_resource(new_project, user)
+ session.commit()
+ except Exception as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ return make_flask_response(data=new_project.to_proto(), status=HTTPStatus.CREATED)
+
+ @credentials_required
def get(self):
- # TODO: Not count soft-deleted workflow
- projects = db.session.query(
- Project, func.count(Workflow.id).label('num_workflow'))\
- .join(Workflow, Workflow.project_id == Project.id, isouter=True)\
- .group_by(Project.id)\
- .all()
- result = []
- for project in projects:
- project_dict = project.Project.to_dict()
- project_dict['num_workflow'] = project.num_workflow
- result.append(project_dict)
- return {'data': result}
+ """Gets all projects.
+ ---
+ tags:
+ - project
+ description: Gets all projects.
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ProjectRef'
+ """
+ with db.session_scope() as session:
+ service = ProjectService(session)
+ return make_flask_response(data=service.get_projects())
class ProjectApi(Resource):
- @jwt_required()
- def get(self, project_id):
- project = Project.query.filter_by(id=project_id).first()
- if project is None:
- raise NotFoundException(
- f'Failed to find project: {project_id}')
- return {'data': project.to_dict()}
-
- @jwt_required()
- def patch(self, project_id):
- project = Project.query.filter_by(id=project_id).first()
- if project is None:
- raise NotFoundException(
- f'Failed to find project: {project_id}')
- config = project.get_config()
- if request.json.get('token') is not None:
- new_token = request.json.get('token')
- config.token = new_token
- project.token = new_token
- if request.json.get('variables') is not None:
- del config.variables[:]
- config.variables.extend([
- ParseDict(variable, Variable())
- for variable in request.json.get('variables')
- ])
-
- # exact configuration from variables
- grpc_ssl_server_host = None
- egress_host = None
- for variable in config.variables:
- if variable.name == 'GRPC_SSL_SERVER_HOST':
- grpc_ssl_server_host = variable.value
- if variable.name == 'EGRESS_HOST':
- egress_host = variable.value
-
- if request.json.get('participant_name'):
- config.participants[0].name = request.json.get('participant_name')
-
- if request.json.get('comment'):
- project.comment = request.json.get('comment')
-
- for participant in config.participants:
- if participant.domain_name in\
- project.get_certificate().domain_name_to_cert.keys():
- _create_add_on(
- participant,
- project.get_certificate().domain_name_to_cert[
- participant.domain_name], grpc_ssl_server_host)
- if egress_host:
- participant.grpc_spec.authority = egress_host
- project.set_config(config)
- try:
- db.session.commit()
- except Exception as e:
- raise InvalidArgumentException(details=e)
- return {'data': project.to_dict()}
+
+ @credentials_required
+ @iam_required(Permission.PROJECT_GET)
+ def get(self, project_id: int):
+ """Gets a project.
+ ---
+ tags:
+ - project
+ description: Gets a project
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Project'
+ """
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(id=project_id).first()
+ if project is None:
+ raise NotFoundException(f'Failed to find project: {project_id}')
+ return make_flask_response(data=project.to_proto(), status=HTTPStatus.OK)
+
+ @input_validator
+ @credentials_required
+ @iam_required(Permission.PROJECT_PATCH)
+ @emits_event(audit_fields=['variables'])
+ @use_kwargs({
+ 'comment': fields.String(load_default=None),
+ 'variables': fields.List(fields.Dict(), load_default=None),
+ 'config': fields.Dict(load_default=None)
+ })
+ def patch(self, project_id: int, comment: Optional[str], variables: Optional[List[Dict]], config: Optional[Dict]):
+ """Patch a project.
+ ---
+ tags:
+ - project
+ description: Update a project.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: body
+ name: body
+ schema:
+ type: object
+ properties:
+ comment:
+ type: string
+ variables:
+ description: A list of variables to override existing ones.
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Variable'
+ config:
+ description: Config of the project, including variables.
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ProjectConfig'
+ responses:
+ 200:
+ description: Updated project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Project'
+ """
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(id=project_id).first()
+ if project is None:
+ raise NotFoundException(f'Failed to find project: {project_id}')
+
+ if comment:
+ project.comment = comment
+
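+ # A non-null config replaces the stored ProjectConfig wholesale; variables (below) only override the variable list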
+ if config is not None:
+ config_proto = ParseDict(config, ProjectConfig(), ignore_unknown_fields=True)
+ project.set_config(config_proto)
+ session.flush()
+ # TODO(xiangyuxuan.prs): remove the variables parameter once pending projects launch
+ if variables is not None:
+ # Overrides all variables
+ variables = [ParseDict(variable, Variable()) for variable in variables]
+ project.set_variables(variables)
+ try:
+ session.commit()
+ except Exception as e:
+ raise InvalidArgumentException(details=e) from e
+
+ return make_flask_response(data=project.to_proto(), status=HTTPStatus.OK)
class CheckConnectionApi(Resource):
- @jwt_required()
- def post(self, project_id):
- project = Project.query.filter_by(id=project_id).first()
- if project is None:
- raise NotFoundException(
- f'Failed to find project: {project_id}')
- success = True
- details = []
- # TODO: Concurrently check
- for participant in project.get_config().participants:
- result = self.check_connection(project.get_config(), participant)
- success = success & (result.code == StatusCode.STATUS_SUCCESS)
+
+ @credentials_required
+ def get(self, project_id: int):
+ """Checks the connection for a project.
+ ---
+ tags:
+ - project
+ description: Checks the connection for a project.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ success:
+ description: If the connection is established or not.
+ type: boolean
+ message:
+ type: string
+ """
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(id=project_id).first()
+ if project is None:
+ raise NotFoundException(f'Failed to find project: {project_id}')
+ service = ParticipantService(session)
+ participants = service.get_platform_participants_by_project(project.id)
+
+ error_messages = []
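+ # Check every platform participant and collect failures instead of stopping at the first error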
+ for participant in participants:
+ client = RpcClient.from_project_and_participant(project.name, project.token, participant.domain_name)
+ result = client.check_connection().status
if result.code != StatusCode.STATUS_SUCCESS:
- details.append(result.msg)
- return {'data': {'success': success, 'details': details}}
+ error_messages.append(
+ f'failed to validate {participant.domain_name}\'s workspace, result: {result.msg}')
- def check_connection(self, project_config: ProjectProto,
- participant_proto: ParticipantProto):
- client = RpcClient(project_config, participant_proto)
- return client.check_connection().status
+ return {
+ 'data': {
+ 'success': len(error_messages) == 0,
+ 'message': '\n'.join(error_messages) if len(error_messages) > 0 else 'Project validated successfully!'
+ }
+ }, HTTPStatus.OK
+
+
+class ProjectParticipantsApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int):
+ """Gets participants of a project.
+ ---
+ tags:
+ - project
+ description: Gets participants of a project.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.Participant'
+ """
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(id=project_id).first()
+ if project is None:
+ raise NotFoundException(f'Failed to find project: {project_id}')
+ service = ParticipantService(session)
+ participants = service.get_participants_by_project(project_id)
+ return make_flask_response(data=[participant.to_proto() for participant in participants])
+
+
+class PendingProjectsApi(Resource):
+
+ @input_validator
+ @credentials_required
+ @iam_required(Permission.PROJECTS_POST)
+ @use_args(CreatePendingProjectParameter())
+ def post(self, data: Dict):
+ """Creates a new pending project.
+ ---
+ tags:
+ - project
+ description: Creates a new pending project
+ parameters:
+ - in: body
+ name: body
+ schema:
+ $ref: '#/definitions/CreatePendingProjectParameter'
+ responses:
+ 201:
+ description: Created a pending project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.PendingProjectPb'
+ """
+ with db.session_scope() as session:
+ # TODO(xiangyuxuan.prs): remove this check once tokens, rather than names, are used to keep projects consistent
+ if PendingProjectService(session).duplicated_name_exists(data['name']):
+ raise ResourceConflictException(f'{data["name"]} already exists')
+ participants_info = PendingProjectService(session).build_participants_info(data['participant_ids'])
+ pending_project = PendingProjectService(session).create_pending_project(data['name'],
+ data['config'],
+ participants_info,
+ data['comment'],
+ get_current_user().username,
+ state=PendingProjectState.ACCEPTED,
+ role=ProjectRole.COORDINATOR)
+ session.flush()
+ ticket_helper = get_ticket_helper(session)
+ ticket_helper.create_ticket(TicketType.CREATE_PROJECT, TicketDetails(uuid=pending_project.uuid))
+ session.commit()
+ return make_flask_response(data=pending_project.to_proto(), status=HTTPStatus.CREATED)
+
+ @credentials_required
+ @use_args(
+ {
+ 'filter': FilterExpField(
+ required=False,
+ load_default=None,
+ ),
+ 'page': fields.Integer(required=False, load_default=1),
+ 'page_size': fields.Integer(required=False, load_default=10)
+ },
+ location='query')
+ def get(self, params: dict):
+ """Gets all pending projects.
+ ---
+ tags:
+ - project
+ description: Gets all pending projects.
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.PendingProjectPb'
+ """
+ with db.session_scope() as session:
+ try:
+ pagination = PendingProjectService(session).list_pending_projects(
+ filter_exp=params['filter'],
+ page=params['page'],
+ page_size=params['page_size'],
+ )
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ data = [t.to_proto() for t in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+
+class PendingProjectApi(Resource):
+
+ @credentials_required
+ @use_kwargs({
+ 'state':
+ fields.String(required=True,
+ validate=validate.OneOf([PendingProjectState.ACCEPTED.name, PendingProjectState.CLOSED.name]))
+ })
+ def patch(self, pending_project_id: int, state: str):
+ """Accept or refuse a pending project.
+ ---
+ tags:
+ - project
+ description: Accept or refuse a pending project.
+ parameters:
+ - in: path
+ name: pending_project_id
+ schema:
+ type: integer
+ - in: body
+ name: body
+ schema:
+ type: object
+ properties:
+ state:
+ type: string
+ responses:
+ 200:
+ description: a pending project
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.PendingProjectPb'
+ """
+ with db.session_scope() as session:
+ pending_project = PendingProjectService(session).update_state_as_participant(pending_project_id, state)
+ resp = PendingProjectRpcController(pending_project).sync_pending_project_state_to_coordinator(
+ uuid=pending_project.uuid, state=PendingProjectState(state))
+ if not resp.succeeded:
+ raise InternalException(f'failed to connect to coordinator: {resp.msg}')
+ session.commit()
+ return make_flask_response(data=pending_project.to_proto())
+
+ def delete(self, pending_project_id: int):
+ """Delete pending project by id.
+ ---
+ tags:
+ - project
+ description: Delete pending project.
+ parameters:
+ - in: path
+ name: pending_project_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the pending project
+ responses:
+ 204:
+ description: No content.
+ """
+ with db.session_scope() as session:
+ pending_project = session.query(PendingProject).get(pending_project_id)
+ if pending_project is None:
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+ result = PendingProjectRpcController(pending_project).send_to_participants(
+ partial(ProjectServiceClient.delete_pending_project, uuid=pending_project.uuid))
+ if not all(resp.succeeded for resp in result.values()):
+ raise InternalException(f'failed to delete the pending project on participants: {result}')
+ with db.session_scope() as session:
+ session.delete(pending_project)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
def initialize_project_apis(api: Api):
api.add_resource(ProjectsApi, '/projects')
api.add_resource(ProjectApi, '/projects/')
- api.add_resource(CheckConnectionApi,
- '/projects//connection_checks')
-
-
-def _create_add_on(participant, certificate, grpc_ssl_server_host=None):
- if certificate is None:
- return
- # check validation
- for file_name in _CERTIFICATE_FILE_NAMES:
- if certificate.certs.get(file_name) is None:
- raise InvalidArgumentException(
- details=ErrorMessage.PARAM_FORMAT_ERROR.value.format(
- 'certificates', '{} not existed'.format(file_name)))
- try:
- create_add_on(k8s_client, participant.domain_name, participant.url,
- certificate.certs, grpc_ssl_server_host)
- except RuntimeError as e:
- raise InvalidArgumentException(details=str(e))
+ api.add_resource(ProjectParticipantsApi, '/projects//participants')
+ api.add_resource(CheckConnectionApi, '/projects//connection_checks')
+
+ api.add_resource(PendingProjectsApi, '/pending_projects')
+ api.add_resource(PendingProjectApi, '/pending_project/')
+
+ schema_manager.append(CreateProjectParameter)
+ schema_manager.append(CreatePendingProjectParameter)
diff --git a/web_console_v2/api/fedlearner_webconsole/project/apis_test.py b/web_console_v2/api/fedlearner_webconsole/project/apis_test.py
new file mode 100644
index 000000000..4e5b7d487
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/apis_test.py
@@ -0,0 +1,392 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import json
+import unittest
+from datetime import datetime, timezone
+
+from http import HTTPStatus
+from unittest.mock import patch, MagicMock
+
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant, ParticipantType
+from fedlearner_webconsole.project.apis import _add_variable
+from fedlearner_webconsole.project.controllers import ParticipantResp
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.utils.pp_time import sleep
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project, PendingProject, PendingProjectState, ProjectRole
+from fedlearner_webconsole.proto.project_pb2 import ProjectConfig, ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.workflow.models import Workflow
+from testing.common import BaseTestCase
+
+
+class ProjectApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.signin_as_admin()
+ self.default_project = Project()
+ self.default_project.name = 'test-default_project'
+ self.default_project.set_config(ParseDict({'variables': [{'name': 'test', 'value': 'test'}]}, ProjectConfig()))
+ self.default_project.comment = 'test comment'
+
+ workflow = Workflow(name='workflow_key_get1', project_id=1)
+ participant = Participant(name='test-participant', domain_name='fl-test.com', host='127.0.0.1', port=32443)
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ with db.session_scope() as session:
+ session.add(self.default_project)
+ session.add(workflow)
+ session.add(participant)
+ session.add(relationship)
+ session.commit()
+
+ def test_add_variable(self):
+ # test none
+ self.assertEqual(_add_variable(None, 'storage_root_path', '/data'),
+ {'variables': [{
+ 'name': 'storage_root_path',
+ 'value': '/data'
+ }]})
+ # test variables is []
+ self.assertEqual(_add_variable({'variables': []}, 'storage_root_path', '/data'),
+ {'variables': [{
+ 'name': 'storage_root_path',
+ 'value': '/data'
+ }]})
+ # test has other variables
+ self.assertEqual(
+ _add_variable({'variables': [{
+ 'name': 'test-post',
+ 'value': 'test'
+ }]}, 'storage_root_path', '/data'),
+ {'variables': [{
+ 'name': 'test-post',
+ 'value': 'test'
+ }, {
+ 'name': 'storage_root_path',
+ 'value': '/data'
+ }]})
+ # test already set storage_root_path
+ self.assertEqual(
+ _add_variable({'variables': [{
+ 'name': 'storage_root_path',
+ 'value': '/fake_data'
+ }]}, 'storage_root_path', '/data'), {'variables': [{
+ 'name': 'storage_root_path',
+ 'value': '/fake_data'
+ }]})
+
+ def test_get_project(self):
+ get_response = self.get_helper('/api/v2/projects/1')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+ queried_project = json.loads(get_response.data).get('data')
+ with db.session_scope() as session:
+ project_in_db = session.query(Project).get(1)
+ self.assertEqual(queried_project, to_dict(project_in_db.to_proto()))
+
+ def test_get_not_found_project(self):
+ get_response = self.get_helper(f'/api/v2/projects/{1000}')
+ self.assertEqual(get_response.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_post_project_by_new_participant(self):
+ name = 'test-post-project'
+ comment = 'test post project(by new participant)'
+ config = {'variables': [{'name': 'test-post', 'value': 'test'}]}
+ create_response = self.post_helper('/api/v2/projects',
+ data={
+ 'name': name,
+ 'comment': comment,
+ 'config': config,
+ 'participant_ids': [2]
+ })
+
+ self.assertEqual(create_response.status_code, HTTPStatus.CREATED)
+ created_project = self.get_response_data(create_response)
+
+ with db.session_scope() as session:
+ relationship = session.query(ProjectParticipant).all()
+ queried_project = session.query(Project).filter_by(name=name).first()
+
+ self.assertEqual((relationship[1].project_id, relationship[1].participant_id), (2, 2))
+ self.assertEqual(created_project, to_dict(queried_project.to_proto()))
+
+ def test_post_conflict_name_project(self):
+ create_response = self.post_helper('/api/v2/projects',
+ data={
+ 'name': self.default_project.name,
+ 'participant_ids': [1],
+ })
+ self.assertEqual(create_response.status_code, HTTPStatus.CONFLICT)
+
+ def test_list_project(self):
+ list_response = self.get_helper('/api/v2/projects')
+ project_list = self.get_response_data(list_response)
+ self.assertEqual(len(project_list), 1)
+ project_id = project_list[0]['id']
+ with db.session_scope() as session:
+ queried_project = session.query(Project).get(project_id)
+ ref = queried_project.to_ref()
+ ref.num_workflow = 1
+ self.assertEqual(project_list[0], to_dict(ref))
+
+ def test_update_project(self):
+ updated_comment = 'updated comment'
+ variables = [{'name': 'test-variables', 'value': 'variables'}]
+ update_response = self.patch_helper('/api/v2/projects/1',
+ data={
+ 'comment': updated_comment,
+ 'variables': variables,
+ })
+ self.assertEqual(update_response.status_code, HTTPStatus.OK)
+ # test response
+ project = self.get_response_data(update_response)
+ self.assertEqual(project['comment'], updated_comment)
+ self.assertEqual(project['variables'], [{
+ 'access_mode': 'UNSPECIFIED',
+ 'name': 'test-variables',
+ 'value': 'variables',
+ 'value_type': 'STRING',
+ 'tag': '',
+ 'widget_schema': ''
+ }])
+ # test database
+ get_response = self.get_helper('/api/v2/projects/1')
+ project = self.get_response_data(get_response)
+ self.assertEqual(project['comment'], updated_comment)
+
+ def test_update_project_config(self):
+ config = {'variables': [{'name': 'test-variables', 'value': 'variables'}]}
+ update_response = self.patch_helper('/api/v2/projects/1', data={
+ 'config': config,
+ })
+ self.assertEqual(update_response.status_code, HTTPStatus.OK)
+ # test database
+ get_response = self.get_helper('/api/v2/projects/1')
+ project = self.get_response_data(get_response)
+ config = {
+ 'abilities': [],
+ 'action_rules': {},
+ 'support_blockchain': False,
+ 'variables': [{
+ 'access_mode': 'UNSPECIFIED',
+ 'name': 'test-variables',
+ 'tag': '',
+ 'value': 'variables',
+ 'value_type': 'STRING',
+ 'widget_schema': ''
+ }]
+ }
+ self.assertEqual(project['config'], config)
+
+ def test_update_not_found_project(self):
+ updated_comment = 'updated comment'
+ update_response = self.patch_helper(f'/api/v2/projects/{1000}', data={'comment': updated_comment})
+ self.assertEqual(update_response.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_post_project_with_multiple_participants(self):
+ create_response = self.post_helper('/api/v2/projects',
+ data={
+ 'name': 'test name',
+ 'comment': 'test comment',
+ 'participant_ids': [1, 2, 3]
+ })
+ self.assertEqual(create_response.status_code, HTTPStatus.BAD_REQUEST)
+
+ def test_post_project_with_light_client(self):
+ with db.session_scope() as session:
+ light_participant = Participant(name='light-client',
+ type=ParticipantType.LIGHT_CLIENT,
+ domain_name='fl-light-client.com',
+ host='127.0.0.1',
+ port=32443)
+ session.add(light_participant)
+ session.commit()
+ resp = self.post_helper('/api/v2/projects',
+ data={
+ 'name': 'test-project',
+ 'comment': 'test comment',
+ 'participant_ids': [light_participant.id]
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(name='test-project').first()
+ self.assertEqual(project.participants[0].name, 'light-client')
+ self.assertEqual(project.get_participant_type(), ParticipantType.LIGHT_CLIENT)
+
+
+class ProjectParticipantsApi(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+
+ self.participant_1 = Participant(name='participant pro1', domain_name='fl-participant-1.com')
+ self.participant_2 = Participant(name='participant pro2', domain_name='fl-participant-2.com')
+ self.project_1 = Project(name='project 1')
+ self.relationship_11 = ProjectParticipant(project_id=1, participant_id=1)
+ self.relationship_12 = ProjectParticipant(project_id=1, participant_id=2)
+ with db.session_scope() as session:
+ session.add(self.participant_1)
+ session.flush()
+ sleep(1)
+ session.add(self.participant_2)
+ session.add(self.project_1)
+ session.add(self.relationship_11)
+ session.add(self.relationship_12)
+ session.commit()
+
+ def test_get_project_participants(self):
+ get_response = self.get_helper('/api/v2/projects/1/participants')
+ participants = self.get_response_data(get_response)
+ self.assertEqual(len(participants), 2)
+ self.assertEqual(participants[0]['name'], 'participant pro2')
+ self.assertEqual(participants[0]['pure_domain_name'], 'participant-2')
+ self.assertEqual(participants[1]['name'], 'participant pro1')
+ self.assertEqual(participants[1]['pure_domain_name'], 'participant-1')
+
+
+class PendingProjectsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.participant_1 = Participant(name='participant pro1', domain_name='fl-participant-1.com')
+ self.participant_2 = Participant(name='participant pro2', domain_name='fl-participant-2.com')
+ with db.session_scope() as session:
+ session.add(self.participant_1)
+ session.add(self.participant_2)
+ session.commit()
+
+ def test_post_pending_projects(self):
+ resp = self.post_helper('/api/v2/pending_projects',
+ data={
+ 'name': 'test-project',
+ 'comment': 'test comment',
+ 'config': {
+ 'variables': []
+ },
+ 'participant_ids': [self.participant_1.id, self.participant_2.id]
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ resp = self.post_helper('/api/v2/pending_projects',
+ data={
+ 'name': 'test-project',
+ 'comment': 'test comment',
+ 'participant_ids': []
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ with db.session_scope() as session:
+ pending_project: PendingProject = session.query(PendingProject).filter_by(name='test-project').first()
+ self.assertEqual(pending_project.role, ProjectRole.COORDINATOR)
+ self.assertEqual(pending_project.state, PendingProjectState.ACCEPTED)
+ self.assertEqual(pending_project.creator_username, 'ada')
+ expected_info = ParticipantsInfo(
+ participants_map={
+ 'participant-1':
+ ParticipantInfo(name='participant pro1',
+ role=ProjectRole.PARTICIPANT.name,
+ state=PendingProjectState.PENDING.name,
+ type=ParticipantType.PLATFORM.name),
+ 'participant-2':
+ ParticipantInfo(name='participant pro2',
+ role=ProjectRole.PARTICIPANT.name,
+ state=PendingProjectState.PENDING.name,
+ type=ParticipantType.PLATFORM.name),
+ '':
+ ParticipantInfo(role=ProjectRole.COORDINATOR.name,
+ state=PendingProjectState.ACCEPTED.name,
+ type=ParticipantType.PLATFORM.name)
+ })
+ self.assertEqual(pending_project.get_participants_info(), expected_info)
+
+ @patch('fedlearner_webconsole.project.apis.PendingProjectService.list_pending_projects')
+ def test_get_pending_projects(self, mock_list):
+ created_at = datetime(2022, 5, 10, 0, 0, 0)
+ updated_at = datetime(2022, 5, 10, 0, 0, 0)
+ pending_proj = PendingProject(id=123,
+ name='test',
+ uuid='uuid',
+ state=PendingProjectState.ACCEPTED,
+ role=ProjectRole.PARTICIPANT,
+ comment='test',
+ created_at=created_at,
+ updated_at=updated_at,
+ ticket_status=TicketStatus.PENDING)
+ mock_list.return_value.get_items.return_value = [pending_proj]
+ mock_list.return_value.get_metadata.return_value = {
+ 'current_page': 1,
+ 'page_size': 1,
+ 'total_pages': 1,
+ 'total_items': 1
+ }
+ resp = self.get_helper('/api/v2/pending_projects')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(len(self.get_response_data(resp)), 1)
+
+
+class PendingProjectApiTest(BaseTestCase):
+
+ @patch('fedlearner_webconsole.project.apis.PendingProjectService.update_state_as_participant')
+ @patch('fedlearner_webconsole.project.apis.PendingProjectRpcController.sync_pending_project_state_to_coordinator')
+ def test_patch_pending_project(self, mock_sync: MagicMock, mock_update: MagicMock):
+ pending_project = PendingProject(name='test', state=PendingProjectState.PENDING, role=ProjectRole.PARTICIPANT)
+ with db.session_scope() as session:
+ session.add(pending_project)
+ session.commit()
+ mock_sync.return_value = ParticipantResp(succeeded=True, resp=None, msg='')
+ mock_update.return_value = PendingProject(name='test',
+ state=PendingProjectState.PENDING,
+ role=ProjectRole.PARTICIPANT,
+ created_at=datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc),
+ updated_at=datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc),
+ ticket_status=TicketStatus.APPROVED)
+ resp = self.patch_helper(f'/api/v2/pending_project/{pending_project.id}',
+ data={
+ 'state': PendingProjectState.ACCEPTED.name,
+ })
+ mock_sync.assert_called_once_with(uuid=pending_project.uuid, state=PendingProjectState.ACCEPTED)
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ mock_update.assert_called_once_with(pending_project.id, PendingProjectState.ACCEPTED.name)
+
+ @patch('fedlearner_webconsole.project.apis.PendingProjectRpcController.send_to_participants')
+ def test_delete_pending_project(self, mock_sync_delete: MagicMock):
+ pending_project = PendingProject(name='test', state=PendingProjectState.PENDING, role=ProjectRole.PARTICIPANT)
+ with db.session_scope() as session:
+ session.add(pending_project)
+ session.commit()
+ mock_sync_delete.return_value = {
+ 'a': ParticipantResp(succeeded=True, resp=None, msg=''),
+ 'b': ParticipantResp(succeeded=False, resp=None, msg='aa')
+ }
+ resp = self.delete_helper(f'/api/v2/pending_project/{pending_project.id}')
+ self.assertEqual(resp.status_code, 500)
+ mock_sync_delete.assert_called_once()
+ with db.session_scope() as session:
+ self.assertIsNotNone(session.query(PendingProject).get(pending_project.id))
+ mock_sync_delete.return_value = {
+ 'a': ParticipantResp(succeeded=True, resp=None, msg=''),
+ 'b': ParticipantResp(succeeded=True, resp=None, msg='aa')
+ }
+ self.delete_helper(f'/api/v2/pending_project/{pending_project.id}')
+
+ self.assertEqual(mock_sync_delete.call_count, 2)
+ with db.session_scope() as session:
+ self.assertIsNone(session.query(PendingProject).get(pending_project.id))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/project/controllers.py b/web_console_v2/api/fedlearner_webconsole/project/controllers.py
new file mode 100644
index 000000000..7398ed5b6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/controllers.py
@@ -0,0 +1,77 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Callable, Dict, NamedTuple
+
+import grpc
+from google.protobuf.empty_pb2 import Empty
+from google.protobuf.message import Message
+
+from fedlearner_webconsole.participant.models import ParticipantType
+from fedlearner_webconsole.project.models import PendingProject, ProjectRole, PendingProjectState
+from fedlearner_webconsole.rpc.v2.project_service_client import ProjectServiceClient
+
+
+class ParticipantResp(NamedTuple):
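+    """Result of a single RPC to one participant: success flag, response message and error text."""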
+ succeeded: bool
+ resp: Message
+ msg: str
+
+
+def _get_domain_name(pure_domain_name: str) -> str:
+ """Get domain name from pure_domain_name
+
+ Args:
+ pure_domain_name (str): pure_domain_name
+
+ Returns:
+ str: domain name, like fl-ali-test.com
+ """
+ return f'fl-{pure_domain_name}.com'
+
+
+def _get_resp(pure_domain_name: str, method: Callable) -> ParticipantResp:
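+    """Invokes `method` with the participant's ProjectServiceClient and wraps the result,
+    or the gRPC error, into a ParticipantResp."""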
+ client = ProjectServiceClient.from_participant(_get_domain_name(pure_domain_name))
+ try:
+ resp = ParticipantResp(True, method(client), '')
+ except grpc.RpcError as e:
+ resp = ParticipantResp(False, Empty(), str(e))
+ return resp
+
+
+class PendingProjectRpcController(object):
+ """A helper to Send Grpc request via participants_info in pending project."""
+
+ def __init__(self, pending_project: PendingProject = None):
+ self._pending_project = pending_project
+
+ def send_to_participants(self, method: Callable) -> Dict[str, ParticipantResp]:
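+        """Sends `method` to every platform participant of the pending project.
+
+        The coordinator entry and light clients are skipped; the result is a map from
+        each participant's pure domain name to its ParticipantResp. Callers usually
+        build `method` with functools.partial, e.g.
+        partial(ProjectServiceClient.create_pending_project, pending_project=p).
+        """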
+ if self._pending_project.role == ProjectRole.PARTICIPANT:
+            # While a project is still pending, the proxy that participants normally use to
+            # reach each other via the coordinator is not available, so only the coordinator
+            # may broadcast requests to participants.
+            raise ValueError('a participant cannot send requests to other participants of a pending project')
+ resp_map = {}
+ for pure_domain_name, p_info in self._pending_project.get_participants_info().participants_map.items():
+ if p_info.role == ProjectRole.COORDINATOR.name or p_info.type == ParticipantType.LIGHT_CLIENT.name:
+ continue
+
+ resp_map[pure_domain_name] = _get_resp(pure_domain_name, method)
+ return resp_map
+
+ def sync_pending_project_state_to_coordinator(self, uuid: str, state: PendingProjectState) -> ParticipantResp:
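+        """Reports this participant's state for the pending project to the coordinator.
+
+        Only the participant side may call this; the coordinator is located through
+        get_coordinator_info() and reached via its ProjectServiceClient.
+        """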
+ assert self._pending_project.role == ProjectRole.PARTICIPANT
+ pure_domain, _ = self._pending_project.get_coordinator_info()
+ return _get_resp(pure_domain,
+ lambda client: ProjectServiceClient.sync_pending_project_state(client, uuid, state))
diff --git a/web_console_v2/api/fedlearner_webconsole/project/controllers_test.py b/web_console_v2/api/fedlearner_webconsole/project/controllers_test.py
new file mode 100644
index 000000000..f05d9496f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/controllers_test.py
@@ -0,0 +1,96 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import Mock, call, patch
+
+import grpc
+from google.protobuf.empty_pb2 import Empty
+
+from fedlearner_webconsole.participant.models import ParticipantType
+from fedlearner_webconsole.project.controllers import PendingProjectRpcController, ParticipantResp, _get_domain_name
+from fedlearner_webconsole.project.models import PendingProject, ProjectRole, PendingProjectState
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.rpc.v2.client_base import ParticipantRpcClient
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class FakeRpcClient(ParticipantRpcClient):
+
+ def __init__(self):
+ super().__init__(None)
+
+ def fake_method(self, request: str, succeeded: bool = True):
+ if succeeded:
+ return request
+ raise grpc.RpcError
+
+ def sync_pending_project_state(self, uuid: str, state: PendingProjectState):
+ del uuid, state
+ return Empty()
+
+
+class ProjectControllerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.pending_project = PendingProject(id=1, role=ProjectRole.COORDINATOR)
+ self.pending_project.set_participants_info(
+ ParticipantsInfo(
+ participants_map={
+ 'coordinator': ParticipantInfo(role=ProjectRole.COORDINATOR.name),
+ 'part1': ParticipantInfo(role=ProjectRole.PARTICIPANT.name),
+ 'part2': ParticipantInfo(role=ProjectRole.PARTICIPANT.name),
+ 'part3': ParticipantInfo(role=ProjectRole.PARTICIPANT.name, type=ParticipantType.LIGHT_CLIENT.name),
+ }))
+
+ @patch('fedlearner_webconsole.project.controllers.ProjectServiceClient.from_participant')
+ def test_send_to_all(self, mock_from_participant: Mock):
+ mock_from_participant.return_value = FakeRpcClient()
+ result = PendingProjectRpcController(
+ self.pending_project).send_to_participants(lambda client: FakeRpcClient.fake_method(client, request='test'))
+ mock_from_participant.assert_has_calls([call('fl-part1.com'), call('fl-part2.com')], any_order=True)
+ self.assertEqual(mock_from_participant.call_count, 2)
+ self.assertEqual(
+ result, {
+ 'part1': ParticipantResp(succeeded=True, resp='test', msg=''),
+ 'part2': ParticipantResp(succeeded=True, resp='test', msg='')
+ })
+ # Failed case
+ result = PendingProjectRpcController(self.pending_project).send_to_participants(
+ lambda client: client.fake_method(request='test', succeeded=False))
+ self.assertEqual(
+ result, {
+ 'part1': ParticipantResp(succeeded=False, resp=Empty(), msg=''),
+ 'part2': ParticipantResp(succeeded=False, resp=Empty(), msg='')
+ })
+
+ @patch('fedlearner_webconsole.project.controllers.ProjectServiceClient.from_participant')
+ @patch('fedlearner_webconsole.project.controllers.ProjectServiceClient.sync_pending_project_state',
+ FakeRpcClient.sync_pending_project_state)
+ def test_send_to_coordinator(self, mock_from_participant: Mock):
+ mock_from_participant.return_value = FakeRpcClient()
+ self.pending_project.role = ProjectRole.PARTICIPANT
+ result = PendingProjectRpcController(self.pending_project).sync_pending_project_state_to_coordinator(
+ uuid='test', state=PendingProjectState.ACCEPTED)
+ mock_from_participant.assert_called_once_with('fl-coordinator.com')
+ self.assertEqual(result, ParticipantResp(succeeded=True, resp=Empty(), msg=''))
+
+ def test_get_domain_name(self):
+ self.assertEqual(_get_domain_name('bytedance'), 'fl-bytedance.com')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/project/models.py b/web_console_v2/api/fedlearner_webconsole/project/models.py
index 464d2877d..3b054d22d 100644
--- a/web_console_v2/api/fedlearner_webconsole/project/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/project/models.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,67 +13,217 @@
# limitations under the License.
# coding: utf-8
+import enum
+from typing import Optional, List, Tuple
+from google.protobuf import text_format
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import Index, UniqueConstraint
-from fedlearner_webconsole.utils.mixins import to_dict_mixin
-from fedlearner_webconsole.db import db
+
+from fedlearner_webconsole.proto.project_pb2 import ProjectRef, ParticipantsInfo, ProjectConfig, ParticipantInfo
+from fedlearner_webconsole.utils.base_model.review_ticket_model import ReviewTicketModel
+from fedlearner_webconsole.utils.base_model.softdelete_model import SoftDeleteModel
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.db import db, default_table_args
from fedlearner_webconsole.proto import project_pb2
+from fedlearner_webconsole.participant.models import ParticipantType
+
+
+class PendingProjectState(enum.Enum):
+ PENDING = 'PENDING'
+ ACCEPTED = 'ACCEPTED'
+ FAILED = 'FAILED'
+ CLOSED = 'CLOSED'
+
+
+class ProjectRole(enum.Enum):
+ COORDINATOR = 'COORDINATOR'
+ PARTICIPANT = 'PARTICIPANT'
+
+
+class Action(enum.Enum):
+ ID_ALIGNMENT = 'ID_ALIGNMENT'
+ DATA_ALIGNMENT = 'DATA_ALIGNMENT'
+ HORIZONTAL_TRAIN = 'HORIZONTAL_TRAIN'
+ VERTICAL_TRAIN = 'VERTICAL_TRAIN'
+ VERTICAL_EVAL = 'VERTICAL_EVAL'
+ VERTICAL_PRED = 'VERTICAL_PRED'
+ VERTICAL_SERVING = 'VERTICAL_SERVING'
+ WORKFLOW = 'WORKFLOW'
+ TEE_SERVICE = 'TEE_SERVICE'
+    TEE_RESULT_EXPORT = 'TEE_RESULT_EXPORT'
-@to_dict_mixin(ignores=['certificate'],
- extras={'config': (lambda project: project.get_config())})
class Project(db.Model):
__tablename__ = 'projects_v2'
- __table_args__ = (UniqueConstraint('name', name='idx_name'),
- Index('idx_token', 'token'), {
- 'comment': 'webconsole projects',
- 'mysql_engine': 'innodb',
- 'mysql_charset': 'utf8mb4',
- })
- id = db.Column(db.Integer,
- primary_key=True,
- autoincrement=True,
- comment='id')
+ __table_args__ = (UniqueConstraint('name', name='idx_name'), Index('idx_token', 'token'), {
+ 'comment': 'webconsole projects',
+ 'mysql_engine': 'innodb',
+ 'mysql_charset': 'utf8mb4',
+ })
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ role = db.Column(db.Enum(ProjectRole, length=32, native_enum=False, create_constraint=False),
+ default=ProjectRole.PARTICIPANT,
+                     comment='project role')
+ participants_info = db.Column(db.Text(), comment='participants info')
+
name = db.Column(db.String(255), comment='name')
token = db.Column(db.String(64), comment='token')
config = db.Column(db.LargeBinary(), comment='config')
- certificate = db.Column(db.LargeBinary(), comment='certificate')
comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
- created_at = db.Column(db.DateTime(timezone=True),
- server_default=func.now(),
- comment='created at')
+ creator = db.Column(db.String(255), comment='creator')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created at')
updated_at = db.Column(db.DateTime(timezone=True),
onupdate=func.now(),
server_default=func.now(),
comment='updated at')
deleted_at = db.Column(db.DateTime(timezone=True), comment='deleted at')
+ participants = db.relationship('Participant',
+ secondary='projects_participants_v2',
+ primaryjoin='Project.id == foreign(ProjectParticipant.project_id)',
+ secondaryjoin='Participant.id == foreign(ProjectParticipant.participant_id)')
- def set_config(self, proto):
+ def set_config(self, proto: project_pb2.ProjectConfig):
self.config = proto.SerializeToString()
- def get_config(self):
- if self.config is None:
- return None
- proto = project_pb2.Project()
- proto.ParseFromString(self.config)
- return proto
+ def _get_config(self) -> project_pb2.ProjectConfig:
+ config = project_pb2.ProjectConfig()
+ if self.config:
+ config.ParseFromString(self.config)
+ return config
+
+ def get_variables(self) -> List[Variable]:
+ return list(self._get_config().variables)
+
+ def set_variables(self, variables: List[Variable]):
+ config = self._get_config()
+ del config.variables[:]
+ config.variables.extend(variables)
+ self.set_config(config)
- def set_certificate(self, proto):
- self.certificate = proto.SerializeToString()
+ def get_storage_root_path(self, dft_value: str) -> str:
+ variables = self.get_variables()
+ for variable in variables:
+ if variable.name == 'storage_root_path':
+ return variable.value
+ return dft_value
- def get_certificate(self):
- if self.certificate is None:
+ def get_participant_type(self) -> Optional[ParticipantType]:
+ if len(self.participants) == 0:
return None
- proto = project_pb2.CertificateStorage()
- proto.ParseFromString(self.certificate)
+ return self.participants[0].get_type()
+
+ def set_participants_info(self, proto: ParticipantsInfo):
+ self.participants_info = text_format.MessageToString(proto)
+
+ def get_participants_info(self) -> ParticipantsInfo:
+ if self.participants_info is not None:
+ return text_format.Parse(self.participants_info, ParticipantsInfo())
+ return ParticipantsInfo()
+
+ def to_ref(self) -> ProjectRef:
+ participant_type = self.get_participant_type()
+ ref = ProjectRef(id=self.id,
+ name=self.name,
+ creator=self.creator,
+ created_at=to_timestamp(self.created_at),
+ participant_type=participant_type.name if participant_type else None,
+ participants_info=self.get_participants_info(),
+ role=self.role.name if self.role else None)
+ for participant in self.participants:
+ ref.participants.append(participant.to_proto())
+ return ref
+
+ def to_proto(self) -> project_pb2.Project:
+ participant_type = self.get_participant_type()
+ proto = project_pb2.Project(id=self.id,
+ name=self.name,
+ creator=self.creator,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ participant_type=participant_type.name if participant_type else None,
+ token=self.token,
+ comment=self.comment,
+ variables=self.get_variables(),
+ participants_info=self.get_participants_info(),
+ config=self._get_config(),
+ role=self.role.name if self.role else None)
+ for participant in self.participants:
+ proto.participants.append(participant.to_proto())
return proto
- def get_namespace(self):
- config = self.get_config()
- if config is not None:
- variables = self.get_config().variables
- for variable in variables:
- if variable.name == 'namespace':
- return variable.value
- return 'default'
+
+class PendingProject(db.Model, SoftDeleteModel, ReviewTicketModel):
+ __tablename__ = 'pending_projects_v2'
+ __table_args__ = (default_table_args('This is webconsole pending_project table'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ name = db.Column(db.String(255), comment='name')
+ uuid = db.Column(db.String(64), comment='uuid')
+ config = db.Column(db.Text(), comment='config')
+ state = db.Column(db.Enum(PendingProjectState, length=32, native_enum=False, create_constraint=False),
+ nullable=False,
+ default=PendingProjectState.PENDING,
+ comment='pending project stage state')
+ participants_info = db.Column(db.Text(), comment='participants info')
+ role = db.Column(db.Enum(ProjectRole, length=32, native_enum=False, create_constraint=False),
+ nullable=False,
+ default=ProjectRole.PARTICIPANT,
+ comment='pending project role')
+
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
+ creator_username = db.Column(db.String(255), comment='creator')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created at')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ onupdate=func.now(),
+ server_default=func.now(),
+ comment='updated at')
+
+ def set_participants_info(self, proto: ParticipantsInfo):
+ self.participants_info = text_format.MessageToString(proto)
+
+ def get_participants_info(self) -> ParticipantsInfo:
+ if self.participants_info is not None:
+ return text_format.Parse(self.participants_info, ParticipantsInfo())
+ return ParticipantsInfo()
+
+ def set_config(self, proto: ProjectConfig):
+ self.config = text_format.MessageToString(proto)
+
+ def get_config(self) -> ProjectConfig:
+ if self.config is not None:
+ return text_format.Parse(self.config, ProjectConfig())
+ return ProjectConfig()
+
+ def to_proto(self) -> project_pb2.PendingProjectPb:
+ return project_pb2.PendingProjectPb(id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ config=self.get_config(),
+ state=self.state.name,
+ participants_info=self.get_participants_info(),
+ role=self.role.name,
+ comment=self.comment,
+ creator_username=self.creator_username,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ ticket_uuid=self.ticket_uuid,
+ ticket_status=self.ticket_status.name,
+ participant_type=self.get_participant_type())
+
+ def get_participant_info(self, pure_domain: str) -> Optional[ParticipantInfo]:
+ return self.get_participants_info().participants_map.get(pure_domain)
+
+ def get_coordinator_info(self) -> Tuple[str, ParticipantInfo]:
+ for pure_domain, p_info in self.get_participants_info().participants_map.items():
+ if p_info.role == ProjectRole.COORDINATOR.name:
+ return pure_domain, p_info
+        raise ValueError(f'coordinator not found in pending project {self.id}')
+
+ def get_participant_type(self) -> str:
+        # In the short term a project only has one type of participant, so as a temporary hack
+        # the pending project's type is taken from its first (non-coordinator) participant.
+ for info in self.get_participants_info().participants_map.values():
+ if info.role == ProjectRole.PARTICIPANT.name:
+ return info.type
+ return ParticipantType.LIGHT_CLIENT.name
diff --git a/web_console_v2/api/fedlearner_webconsole/project/models_test.py b/web_console_v2/api/fedlearner_webconsole/project/models_test.py
new file mode 100644
index 000000000..c852a9db3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/models_test.py
@@ -0,0 +1,187 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+
+from google.protobuf.struct_pb2 import Value
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant, ParticipantType
+from fedlearner_webconsole.project.models import Project, PendingProject, PendingProjectState, ProjectRole
+from fedlearner_webconsole.proto import project_pb2
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.project_pb2 import ProjectConfig, ProjectRef, ParticipantInfo, \
+ PendingProjectPb, ParticipantsInfo
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ProjectTest(NoWebServerTestCase):
+
+ def test_set_and_get_variables(self):
+ config = ProjectConfig(variables=[
+ Variable(name='old_var', value='old', access_mode=Variable.PEER_READABLE),
+ ])
+ with db.session_scope() as session:
+ project = Project(id=111, config=config.SerializeToString())
+ session.add(project)
+ session.commit()
+ self.assertEqual(len(project.get_variables()), 1)
+ self.assertEqual(project.get_variables()[0].name, 'old_var')
+ project.set_variables([Variable(name='new_var', value='new', access_mode=Variable.PEER_WRITABLE)])
+ session.commit()
+ with db.session_scope() as session:
+ project = session.query(Project).get(111)
+ self.assertEqual(len(project.get_variables()), 1)
+ self.assertEqual(project.get_variables()[0].name, 'new_var')
+
+ def test_get_storage_root_path(self):
+ project = Project(id=111)
+ self.assertEqual(project.get_storage_root_path('not found'), 'not found')
+ project.set_variables(
+ [Variable(name='storage_root_path', value='root path', access_mode=Variable.PEER_READABLE)])
+ self.assertEqual(project.get_storage_root_path('not found'), 'root path')
+
+ def test_to_ref(self):
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ project_id = 66666
+ participant_id = 2
+ with db.session_scope() as session:
+ project = Project(id=project_id, name='test project', creator='test_user', created_at=created_at)
+ participant = Participant(id=participant_id, name='test part', domain_name='fl-test.com')
+ relation = ProjectParticipant(project_id=project.id, participant_id=participant.id)
+ session.add_all([project, participant, relation])
+ session.commit()
+ with db.session_scope() as session:
+ project = session.query(Project).get(project_id)
+ participant = session.query(Participant).get(participant_id)
+ self.assertEqual(
+ project.to_ref(),
+ ProjectRef(id=project_id,
+ name='test project',
+ creator='test_user',
+ participant_type='PLATFORM',
+ created_at=int(created_at.timestamp()),
+ participants=[participant.to_proto()],
+ participants_info=ParticipantsInfo(),
+ role=ProjectRole.PARTICIPANT.name),
+ )
+
+ def test_to_proto(self):
+ created_at = datetime(2022, 5, 1, 10, 10, tzinfo=timezone.utc)
+ project_id = 12356
+ participant_id = 22
+ variable = Variable(name='test_var', access_mode=Variable.PEER_READABLE, typed_value=Value(string_value='jjjj'))
+ with db.session_scope() as session:
+ project = Project(id=project_id,
+ name='test project',
+ creator='test_user',
+ created_at=created_at,
+ comment='test comment',
+ token='test token')
+ project.set_variables([variable])
+ participant = Participant(id=participant_id, name='test part', domain_name='fl-test.com')
+ relation = ProjectParticipant(project_id=project.id, participant_id=participant.id)
+ participants_info = ParticipantsInfo(participants_map={'test': ParticipantInfo(name='test part')})
+ project.set_participants_info(participants_info)
+ session.add_all([project, participant, relation])
+ session.commit()
+ with db.session_scope() as session:
+ project = session.query(Project).get(project_id)
+ participant = session.query(Participant).get(participant_id)
+ actual = project.to_proto()
+ self.assertEqual(
+ actual,
+ project_pb2.Project(id=project_id,
+ name='test project',
+ token='test token',
+ comment='test comment',
+ creator='test_user',
+ participant_type='PLATFORM',
+ created_at=int(created_at.timestamp()),
+ updated_at=actual.updated_at,
+ variables=[variable],
+ participants=[participant.to_proto()],
+ participants_info=participants_info,
+ config=ProjectConfig(variables=[variable]),
+ role=ProjectRole.PARTICIPANT.name),
+ )
+
+
+class PendingProjectTest(NoWebServerTestCase):
+
+ def test_to_proto(self):
+ created_at = datetime(2022, 5, 10, 0, 0, 0)
+ updated_at = datetime(2022, 5, 10, 0, 0, 0)
+ pending_proj = PendingProject(id=123,
+ name='test',
+ uuid='uuid',
+ state=PendingProjectState.ACCEPTED,
+ role=ProjectRole.PARTICIPANT,
+ comment='test',
+ created_at=created_at,
+ updated_at=updated_at,
+ ticket_status=TicketStatus.PENDING)
+ pending_proj.set_config(ProjectConfig())
+ participants_infos = ParticipantsInfo(
+ participants_map={'test': ParticipantInfo(name='test', role=PendingProjectState.ACCEPTED.name)})
+ pending_proj.set_participants_info(participants_infos)
+ self.assertEqual(
+ pending_proj.to_proto(),
+ PendingProjectPb(id=123,
+ name='test',
+ uuid='uuid',
+ state=PendingProjectState.ACCEPTED.name,
+ role=ProjectRole.PARTICIPANT.name,
+ comment='test',
+ created_at=1652140800,
+ updated_at=1652140800,
+ config=ProjectConfig(),
+ participants_info=participants_infos,
+ ticket_status=TicketStatus.PENDING.name,
+ participant_type=ParticipantType.LIGHT_CLIENT.name))
+
+ def test_get_participant_info(self):
+ pending_proj = PendingProject(id=123, name='test', uuid='uuid')
+ pending_proj.set_config(ProjectConfig())
+ participants_infos = ParticipantsInfo(participants_map={'test': ParticipantInfo(name='test')})
+ pending_proj.set_participants_info(participants_infos)
+ self.assertEqual(pending_proj.get_participant_info('test'), ParticipantInfo(name='test'))
+ self.assertEqual(pending_proj.get_participant_info('test1'), None)
+
+ def test_get_coordinator_info(self):
+ pending_proj = PendingProject(id=123, name='test', uuid='uuid')
+ pending_proj.set_config(ProjectConfig())
+ participants_infos = ParticipantsInfo(
+ participants_map={'test': ParticipantInfo(name='test', role=ProjectRole.COORDINATOR.name)})
+ pending_proj.set_participants_info(participants_infos)
+ self.assertEqual(pending_proj.get_coordinator_info(),
+ ('test', ParticipantInfo(name='test', role=ProjectRole.COORDINATOR.name)))
+ pending_proj.set_participants_info(ParticipantsInfo())
+ with self.assertRaises(ValueError):
+ pending_proj.get_coordinator_info()
+
+ def test_get_participant_type(self):
+ pending_proj = PendingProject(id=123, name='test', uuid='uuid')
+ participants_infos = ParticipantsInfo(participants_map={
+ 'test': ParticipantInfo(name='test', role=ProjectRole.PARTICIPANT.name, type=ParticipantType.PLATFORM.name)
+ })
+ pending_proj.set_participants_info(participants_infos)
+ self.assertEqual(pending_proj.get_participant_type(), ParticipantType.PLATFORM.name)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/project/project_scheduler.py b/web_console_v2/api/fedlearner_webconsole/project/project_scheduler.py
new file mode 100644
index 000000000..80065f6e4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/project_scheduler.py
@@ -0,0 +1,127 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from functools import partial
+from typing import Tuple, List
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.controllers import PendingProjectRpcController
+from fedlearner_webconsole.project.models import PendingProject, ProjectRole, PendingProjectState
+from fedlearner_webconsole.project.services import PendingProjectService
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput
+from fedlearner_webconsole.rpc.v2.project_service_client import ProjectServiceClient
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+
+
+def _get_ids_needed_schedule(session: Session) -> List[int]:
+ return [
+ p.id for p in session.query(PendingProject.id).filter_by(role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED).all()
+ ]
+
+
+def _if_pending_project_needed_create(p: PendingProject) -> bool:
+ part_info_list = p.get_participants_info().participants_map.values()
+    # The coordinator can move to the next step only after every participant has made its choice.
+ if any(part_info.state == PendingProjectState.PENDING.name for part_info in part_info_list):
+ return False
+ return any(part_info.state == PendingProjectState.ACCEPTED.name
+ for part_info in part_info_list
+ if part_info.role == ProjectRole.PARTICIPANT.name)
+
+
+def _if_all_participants_closed(p: PendingProject) -> bool:
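+    """Returns True when there is at least one non-coordinator participant and all of them are CLOSED."""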
+ closed_part_count = 0
+ part_info_list = p.get_participants_info().participants_map.values()
+ for part_info in part_info_list:
+ if part_info.role == ProjectRole.COORDINATOR.name:
+ continue
+ if part_info.state == PendingProjectState.CLOSED.name:
+ closed_part_count += 1
+ return closed_part_count > 0 and closed_part_count == len(part_info_list) - 1
+
+
+class ScheduleProjectRunner(IRunnerV2):
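+    """Runs on the coordinator side to drive approved pending projects through their lifecycle.
+
+    For every pending project that this side coordinates, whose ticket is approved and whose
+    state is ACCEPTED, each run: (1) pushes the pending project to all platform participants,
+    (2) syncs the latest participants info to them, (3) creates the real project once every
+    participant has responded and at least one accepted, and (4) marks the pending project
+    FAILED when all participants closed it.
+    """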
+
+ @staticmethod
+ def _create_pending_project(ids: List[int]):
+ for pid in ids:
+ with db.session_scope() as session:
+ p = session.query(PendingProject).get(pid)
+ PendingProjectRpcController(p).send_to_participants(
+ partial(ProjectServiceClient.create_pending_project, pending_project=p))
+ return ids
+
+ @staticmethod
+ def _update_all_participants(ids: List[int]) -> List[int]:
+ for pid in ids:
+ with db.session_scope() as session:
+ p = session.query(PendingProject).get(pid)
+ PendingProjectRpcController(p).send_to_participants(
+ partial(ProjectServiceClient.update_pending_project,
+ uuid=p.uuid,
+ participants_map=p.get_participants_info().participants_map))
+ return ids
+
+ @staticmethod
+ def _create_project(ids: List[int]) -> List[str]:
+ p_needed_create = []
+ with db.session_scope() as session:
+ for pid in ids:
+ p = session.query(PendingProject).get(pid)
+ if _if_pending_project_needed_create(p):
+ p_needed_create.append(p)
+ for p in p_needed_create:
+ result = PendingProjectRpcController(p).send_to_participants(
+ partial(ProjectServiceClient.create_project, uuid=p.uuid))
+ if all(resp.succeeded for resp in result.values()):
+            # The coordinator's project must be created last, after all participants have
+            # finished, so that the scheduler can retry when some participant fails.
+ with db.session_scope() as session:
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ PendingProjectService(session).create_project_locally(p.uuid)
+ session.commit()
+ else:
+ logging.error(f'create project {p.uuid} failed: {result}')
+ return [p.uuid for p in p_needed_create]
+
+ @staticmethod
+ def _fail_pending_project(ids: List[int]):
+ failed_ids = []
+ for p_id in ids:
+ with db.session_scope() as session:
+ p: PendingProject = session.query(PendingProject).get(p_id)
+ if _if_all_participants_closed(p):
+ p.state = PendingProjectState.FAILED
+ failed_ids.append(p.id)
+ session.commit()
+ return failed_ids
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ with db.session_scope() as session:
+ ids = _get_ids_needed_schedule(session)
+ output = RunnerOutput()
+ output.pending_project_scheduler_output.pending_project_created_ids.extend(self._create_pending_project(ids))
+ output.pending_project_scheduler_output.pending_project_updated_ids.extend(self._update_all_participants(ids))
+ output.pending_project_scheduler_output.projects_created_uuids.extend(self._create_project(ids))
+ output.pending_project_scheduler_output.pending_project_failed_ids.extend(self._fail_pending_project(ids))
+ return RunnerStatus.DONE, output
diff --git a/web_console_v2/api/fedlearner_webconsole/project/project_scheduler_test.py b/web_console_v2/api/fedlearner_webconsole/project/project_scheduler_test.py
new file mode 100644
index 000000000..02b4102b9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/project_scheduler_test.py
@@ -0,0 +1,214 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from unittest.mock import patch, MagicMock
+
+from google.protobuf.empty_pb2 import Empty
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.controllers import ParticipantResp
+from fedlearner_webconsole.project.models import PendingProject, ProjectRole, PendingProjectState
+from fedlearner_webconsole.project.project_scheduler import ScheduleProjectRunner, _get_ids_needed_schedule,\
+ _if_pending_project_needed_create, _if_all_participants_closed
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, RunnerOutput, PendingProjectSchedulerOutput
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ProjectSchedulerTest(NoWebServerTestCase):
+
+ def test_get_ids_needed_schedule(self):
+ p1 = PendingProject(uuid='a',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ p2 = PendingProject(uuid='b',
+ role=ProjectRole.PARTICIPANT,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ p3 = PendingProject(uuid='a',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.CLOSED)
+ with db.session_scope() as session:
+ session.add(p1)
+ session.add(p2)
+ session.add(p3)
+ session.commit()
+ with db.session_scope() as session:
+ self.assertEqual(_get_ids_needed_schedule(session), [p1.id]) # pylint: disable=protected-access
+
+ def test_if_pending_project_needed_create(self):
+ p = PendingProject()
+
+ # Test case: all accepted
+ part_infos = ParticipantsInfo(
+ participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.PARTICIPANT.name),
+ 'b': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.PARTICIPANT.name)
+ })
+ p.set_participants_info(part_infos)
+ self.assertEqual(_if_pending_project_needed_create(p), True)
+
+ # Test case: one rejected one accepted
+ part_infos = ParticipantsInfo(
+ participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.CLOSED.name, role=ProjectRole.PARTICIPANT.name),
+ 'b': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.PARTICIPANT.name)
+ })
+ p.set_participants_info(part_infos)
+ self.assertEqual(_if_pending_project_needed_create(p), True)
+
+ # Test case: one pending one accepted
+ part_infos = ParticipantsInfo(
+ participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.PENDING.name, role=ProjectRole.PARTICIPANT.name),
+ 'b': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.PARTICIPANT.name)
+ })
+ p.set_participants_info(part_infos)
+ self.assertEqual(_if_pending_project_needed_create(p), False)
+
+ def test_if_all_participants_closed(self):
+ p = PendingProject()
+ only_coordinator = ParticipantsInfo(participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.COORDINATOR.name),
+ })
+ p.set_participants_info(only_coordinator)
+ self.assertEqual(_if_all_participants_closed(p), False)
+
+ part_all_accepted = ParticipantsInfo(
+ participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.COORDINATOR.name),
+ 'b': ParticipantInfo(state=PendingProjectState.CLOSED.name, role=ProjectRole.PARTICIPANT.name),
+ 'c': ParticipantInfo(state=PendingProjectState.CLOSED.name, role=ProjectRole.PARTICIPANT.name),
+ })
+ p.set_participants_info(part_all_accepted)
+ self.assertEqual(_if_all_participants_closed(p), True)
+ part_infos_one_pending = ParticipantsInfo(
+ participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.COORDINATOR.name),
+ 'b': ParticipantInfo(state=PendingProjectState.CLOSED.name, role=ProjectRole.PARTICIPANT.name),
+ 'c': ParticipantInfo(state=PendingProjectState.PENDING.name, role=ProjectRole.PARTICIPANT.name),
+ })
+ p.set_participants_info(part_infos_one_pending)
+ self.assertEqual(_if_all_participants_closed(p), False)
+
+ @patch('fedlearner_webconsole.project.project_scheduler.PendingProjectRpcController.send_to_participants')
+ def test_create_pending_project(self, mock_sent: MagicMock):
+ p1 = PendingProject(uuid='a',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ with db.session_scope() as session:
+ session.add(p1)
+ session.commit()
+ result = ScheduleProjectRunner()._create_pending_project([p1.id]) # pylint: disable=protected-access
+ mock_sent.assert_called_once()
+ self.assertEqual(result, [p1.id])
+
+ @patch('fedlearner_webconsole.project.project_scheduler.PendingProjectRpcController.send_to_participants')
+ def test_sync_all_participants(self, mock_sent: MagicMock):
+ p1 = PendingProject(uuid='a',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ with db.session_scope() as session:
+ session.add(p1)
+ session.commit()
+ result = ScheduleProjectRunner()._update_all_participants([p1.id]) # pylint: disable=protected-access
+ mock_sent.assert_called_once()
+ self.assertEqual(result, [p1.id])
+
+ @patch('fedlearner_webconsole.project.project_scheduler.PendingProjectRpcController.send_to_participants')
+ @patch('fedlearner_webconsole.project.project_scheduler.PendingProjectService.create_project_locally')
+ def test_create_project(self, mock_create: MagicMock, mock_sent: MagicMock):
+ p1 = PendingProject(uuid='a',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ p3 = PendingProject(uuid='c',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ part_infos = ParticipantsInfo(
+ participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.ACCEPTED.name),
+ 'b': ParticipantInfo(state=PendingProjectState.ACCEPTED.name)
+ })
+ p1.set_participants_info(part_infos)
+ p3.set_participants_info(
+ ParticipantsInfo(
+ participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.ACCEPTED.name),
+ 'b': ParticipantInfo(state=PendingProjectState.PENDING.name)
+ }))
+ with db.session_scope() as session:
+ session.add(p1)
+ session.add(p3)
+ session.commit()
+ mock_sent.return_value = {
+ 'a': ParticipantResp(succeeded=True, resp=Empty(), msg=''),
+ 'b': ParticipantResp(succeeded=True, resp=Empty(), msg='')
+ }
+ result = ScheduleProjectRunner()._create_project([p1.id, p3.id]) # pylint: disable=protected-access
+ mock_sent.assert_called_once()
+ mock_create.assert_called_once_with(p1.uuid)
+ self.assertEqual(result, [p1.uuid])
+
+ def test_fail_project(self):
+ p1 = PendingProject(uuid='a',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ p2 = PendingProject(uuid='b',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ part_infos = ParticipantsInfo(participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.ACCEPTED.name, role=ProjectRole.PARTICIPANT.name),
+ })
+ p1.set_participants_info(part_infos)
+ part_infos = ParticipantsInfo(participants_map={
+ 'a': ParticipantInfo(state=PendingProjectState.CLOSED.name, role=ProjectRole.PARTICIPANT.name),
+ })
+ p2.set_participants_info(part_infos)
+ with db.session_scope() as session:
+ session.add(p1)
+ session.add(p2)
+ session.commit()
+ result = ScheduleProjectRunner()._fail_pending_project([p1.id, p2.id]) # pylint: disable=protected-access
+ self.assertEqual(result, [p2.id])
+ with db.session_scope() as session:
+ p2 = session.query(PendingProject).get(p2.id)
+ self.assertEqual(p2.state, PendingProjectState.FAILED)
+
+ def test_run(self):
+ p1 = PendingProject(uuid='a',
+ role=ProjectRole.COORDINATOR,
+ ticket_status=TicketStatus.APPROVED,
+ state=PendingProjectState.ACCEPTED)
+ with db.session_scope() as session:
+ session.add(p1)
+ session.commit()
+ result: RunnerOutput = ScheduleProjectRunner().run(RunnerContext(1, RunnerInput()))
+ self.assertEqual(result[0], RunnerStatus.DONE)
+ self.assertEqual(
+ result[1].pending_project_scheduler_output,
+ PendingProjectSchedulerOutput(pending_project_created_ids=[p1.id],
+ pending_project_updated_ids=[p1.id],
+ projects_created_uuids=[p1.uuid]))
diff --git a/web_console_v2/api/fedlearner_webconsole/project/services.py b/web_console_v2/api/fedlearner_webconsole/project/services.py
new file mode 100644
index 000000000..e4f841f28
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/services.py
@@ -0,0 +1,197 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import List, Dict, Optional
+
+from google.protobuf.struct_pb2 import Value
+from sqlalchemy import func
+from sqlalchemy.orm import Session, joinedload
+
+from envs import Envs
+from fedlearner_webconsole.exceptions import ResourceConflictException
+from fedlearner_webconsole.participant.models import ProjectParticipant, Participant, ParticipantType
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.project.models import Project, PendingProject, PendingProjectState, ProjectRole
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterOp
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.project_pb2 import ProjectRef, ProjectConfig, ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.utils.filtering import FilterBuilder, SupportedField, FieldType
+from fedlearner_webconsole.utils.paginate import Pagination, paginate
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.workflow.models import Workflow
+
+
+class ProjectService(object):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def get_projects_by_participant(self, participant_id: int) -> List[Dict]:
+ projects = self._session.query(Project).join(
+ ProjectParticipant, ProjectParticipant.project_id == Project.id).filter(
+ ProjectParticipant.participant_id == participant_id). \
+ order_by(Project.created_at.desc()).all()
+ return projects
+
+ def get_projects(self) -> List[ProjectRef]:
+ """Gets all projects in the platform."""
+        # TODO(linfan.fine): do not count soft-deleted workflows
+ # Project left join workflow to get workflow counts
+ projects = self._session.query(
+ Project, func.count(Workflow.id).label('num_workflow')) \
+ .options(joinedload(Project.participants)) \
+ .outerjoin(Workflow, Workflow.project_id == Project.id) \
+ .group_by(Project.id) \
+ .order_by(Project.created_at.desc()) \
+ .all()
+ refs = []
+ for row in projects:
+ ref: ProjectRef = row.Project.to_ref()
+ ref.num_workflow = row.num_workflow
+ refs.append(ref)
+ return refs
+
+
+class PendingProjectService(object):
+ FILTER_FIELDS = {
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ 'role': SupportedField(type=FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'state': SupportedField(type=FieldType.STRING, ops={
+ FilterOp.EQUAL: None,
+ FilterOp.IN: None
+ }),
+ }
+
+ def __init__(self, session: Session):
+ self._session = session
+ self._filter_builder = FilterBuilder(model_class=PendingProject, supported_fields=self.FILTER_FIELDS)
+
+    def build_participants_info(self, participant_ids: List[int]) -> ParticipantsInfo:
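+        """Builds the ParticipantsInfo map for the given participant ids.
+
+        Platform participants start in PENDING state, light clients are marked ACCEPTED
+        directly, and the current system is appended as the ACCEPTED coordinator under
+        its own pure domain name.
+        """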
+ participants_info = ParticipantsInfo()
+ for p_id in participant_ids:
+ participant = self._session.query(Participant).get(p_id)
+ assert participant is not None, f'participant with id {p_id} is not found'
+ p_info = ParticipantInfo(name=participant.name,
+ role=ProjectRole.PARTICIPANT.name,
+ state=PendingProjectState.PENDING.name if participant.get_type()
+ == ParticipantType.PLATFORM else PendingProjectState.ACCEPTED.name,
+ type=participant.get_type().name)
+ participants_info.participants_map[participant.pure_domain_name()].CopyFrom(p_info)
+
+ sys_info = SettingService(self._session).get_system_info()
+ coordinator_info = ParticipantInfo(name=sys_info.name,
+ state=PendingProjectState.ACCEPTED.name,
+ role=ProjectRole.COORDINATOR.name,
+ type=ParticipantType.PLATFORM.name)
+ participants_info.participants_map[sys_info.pure_domain_name].CopyFrom(coordinator_info)
+ return participants_info
+
+    def get_ids_from_participants_info(self, participants_info: ParticipantsInfo) -> List[int]:
+ participant_ids = []
+ for pure_domain, p_info in participants_info.participants_map.items():
+            if p_info.role == ProjectRole.COORDINATOR.name:
+ continue
+ participant = ParticipantService(self._session).get_participant_by_pure_domain_name(pure_domain)
+ if participant is not None:
+ participant_ids.append(participant.id)
+ return participant_ids
+
+ def create_pending_project(self,
+ name: str,
+ config: ProjectConfig,
+                               participants_info: ParticipantsInfo,
+ comment: str,
+ creator_username: str,
+ uuid: Optional[str] = None,
+ role: ProjectRole = ProjectRole.PARTICIPANT,
+ state: PendingProjectState = PendingProjectState.PENDING) -> PendingProject:
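+        """Creates and flushes a new PendingProject; a uuid is generated when none is supplied."""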
+ pending_project = PendingProject(
+ name=name,
+ uuid=uuid if uuid else resource_uuid(),
+ comment=comment,
+ role=role,
+ state=state,
+ creator_username=creator_username,
+ )
+ pending_project.set_config(config)
+ pending_project.set_participants_info(participants_info)
+ self._session.add(pending_project)
+ self._session.flush()
+ return pending_project
+
+ def update_state_as_participant(self, pending_project_id: int, state: str) -> PendingProject:
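+        """Updates the pending project's state on the participant side.
+
+        Raises:
+            ResourceConflictException: when accepting while a project or an accepted
+                pending project with the same name already exists.
+        """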
+ pending_project: PendingProject = self._session.query(PendingProject).get(pending_project_id)
+ assert pending_project is not None, f'pending project with id {pending_project_id} is not found'
+ assert pending_project.role == ProjectRole.PARTICIPANT, 'only participant can accept or refuse'
+
+        # TODO(xiangyuxuan.prs): remove this check once a token, instead of the name, is used to keep projects consistent
+ if PendingProjectState(state) == PendingProjectState.ACCEPTED and self.duplicated_name_exists(
+ pending_project.name):
+ raise ResourceConflictException(f'{pending_project.name} has already existed')
+
+ pending_project.state = PendingProjectState(state)
+ return pending_project
+
+ def list_pending_projects(self,
+ page: Optional[int] = None,
+ page_size: Optional[int] = None,
+ filter_exp: Optional[FilterExpression] = None) -> Pagination:
+ """Lists pending project by filter expression and pagination.
+
+ Raises:
+ ValueError: if the expression is unsupported.
+ """
+ query = self._session.query(PendingProject)
+ if filter_exp:
+ query = self._filter_builder.build_query(query, filter_exp)
+ query = query.order_by(PendingProject.id.desc())
+ return paginate(query, page, page_size)
+
+ def create_project_locally(self, pending_project_uuid: str):
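+        """Materializes the pending project into a real Project on this side.
+
+        Copies its name, config and participants info, appends the storage_root_path
+        variable from the environment, links the participants found by pure domain
+        name, and finally marks the pending project as CLOSED.
+        """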
+ pending_project: PendingProject = self._session.query(PendingProject).filter_by(
+ uuid=pending_project_uuid).first()
+ project = Project(name=pending_project.name,
+ token=pending_project.uuid,
+ role=pending_project.role,
+ creator=pending_project.creator_username,
+ comment=pending_project.comment)
+ project.set_participants_info(pending_project.get_participants_info())
+ project_config: ProjectConfig = pending_project.get_config()
+        # Initialize the storage_root_path variable so that the project uses the environment's configuration by default.
+ project_config.variables.append(
+ Variable(name='storage_root_path',
+ typed_value=Value(string_value=Envs.STORAGE_ROOT),
+ value=Envs.STORAGE_ROOT))
+ project.set_config(project_config)
+ self._session.add(project)
+ self._session.flush()
+ part_ids = self.get_ids_from_participants_info(pending_project.get_participants_info())
+ for participant_id in part_ids:
+ # insert a relationship into the table
+ new_relationship = ProjectParticipant(project_id=project.id, participant_id=participant_id)
+ self._session.add(new_relationship)
+ self._session.flush()
+ pending_project.state = PendingProjectState.CLOSED
+
+ def duplicated_name_exists(self, name: str) -> bool:
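+        """Checks whether a Project, or an ACCEPTED PendingProject, with the same name already exists."""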
+ p = self._session.query(Project.id).filter_by(name=name).first()
+ if p is not None:
+ return True
+ pending_p = self._session.query(PendingProject.id).filter_by(name=name,
+ state=PendingProjectState.ACCEPTED).first()
+ if pending_p is not None:
+ return True
+ return False
diff --git a/web_console_v2/api/fedlearner_webconsole/project/services_test.py b/web_console_v2/api/fedlearner_webconsole/project/services_test.py
new file mode 100644
index 000000000..7a0c62f81
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/project/services_test.py
@@ -0,0 +1,253 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import time
+import unittest
+from unittest.mock import patch
+
+from google.protobuf.struct_pb2 import Value
+
+from envs import Envs
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import ResourceConflictException
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant, ParticipantType
+from fedlearner_webconsole.project.models import Project, ProjectRole, PendingProjectState, PendingProject
+from fedlearner_webconsole.project.services import ProjectService, PendingProjectService
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterExpressionKind, SimpleExpression, FilterOp
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo, ProjectConfig
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.workflow.models import Workflow
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ProjectServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project_1 = Project(id=1, name='project 1', creator='user1')
+ participant_1 = Participant(id=1, name='participant 1', domain_name='fl-participant-1.com')
+ relation = ProjectParticipant(project_id=project_1.id, participant_id=participant_1.id)
+ session.add_all([project_1, participant_1, relation])
+ session.commit()
+ time.sleep(1)
+
+ project_2 = Project(id=2, name='project 2', creator='user2')
+ participant_2 = Participant(id=2, name='participant 2', domain_name='fl-participant-2.com')
+ relation_1 = ProjectParticipant(project_id=project_1.id, participant_id=participant_2.id)
+ relation_2 = ProjectParticipant(project_id=project_2.id, participant_id=participant_2.id)
+ session.add_all([project_2, participant_2, relation_1, relation_2])
+
+ session.commit()
+
+ def test_get_projects_by_participant_id(self):
+ with db.session_scope() as session:
+ service = ProjectService(session)
+ projects = service.get_projects_by_participant(1)
+ self.assertEqual(len(projects), 1)
+ self.assertEqual(projects[0].name, 'project 1')
+
+ projects = service.get_projects_by_participant(2)
+ self.assertEqual(len(projects), 2)
+ self.assertCountEqual([projects[0].name, projects[1].name], ['project 1', 'project 2'])
+
+ def test_get_projects(self):
+ with db.session_scope() as session:
+ workflow_1 = Workflow(name='workflow 1', project_id=1)
+ workflow_2 = Workflow(name='workflow 2', project_id=1)
+ session.add_all([workflow_1, workflow_2])
+ session.commit()
+ with db.session_scope() as session:
+ service = ProjectService(session)
+ projects = service.get_projects()
+ self.assertEqual(len(projects), 2)
+
+ self.assertEqual(projects[0].name, 'project 2')
+ self.assertEqual(len(projects[0].participants), 1)
+ self.assertEqual(projects[0].num_workflow, 0)
+
+ self.assertEqual(projects[1].name, 'project 1')
+ self.assertEqual(len(projects[1].participants), 2)
+ self.assertEqual(projects[1].num_workflow, 2)
+
+
+class PendingProjectServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ participant_1 = Participant(id=1,
+ name='participant 1',
+ domain_name='fl-participant-1.com',
+ type=ParticipantType.LIGHT_CLIENT)
+ participant_2 = Participant(id=2,
+ name='participant 2',
+ domain_name='fl-participant-2.com',
+ type=ParticipantType.PLATFORM)
+ session.add_all([participant_1, participant_2])
+ session.commit()
+
+ @patch('fedlearner_webconsole.project.services.SettingService.get_system_info')
+ def test_build_participants_info(self, mock_system_info):
+ mock_system_info.return_value = SystemInfo(pure_domain_name='self', name='self_name')
+ with db.session_scope() as session:
+ info = PendingProjectService(session).build_participants_info([1, 2])
+ self.assertEqual(
+ info,
+ ParticipantsInfo(
+ participants_map={
+ 'participant-1':
+ ParticipantInfo(name='participant 1',
+ role=ProjectRole.PARTICIPANT.name,
+ state=PendingProjectState.ACCEPTED.name,
+ type=ParticipantType.LIGHT_CLIENT.name),
+ 'participant-2':
+ ParticipantInfo(name='participant 2',
+ role=ProjectRole.PARTICIPANT.name,
+ state=PendingProjectState.PENDING.name,
+ type=ParticipantType.PLATFORM.name),
+ 'self':
+ ParticipantInfo(name='self_name',
+ role=ProjectRole.COORDINATOR.name,
+ state=PendingProjectState.ACCEPTED.name,
+ type=ParticipantType.PLATFORM.name)
+ }))
+
+ def test_get_ids_from_participants_info(self):
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'participant-1': ParticipantInfo(name='participant 1'),
+ 'participant-2': ParticipantInfo(name='participant 2'),
+ 'self': ParticipantInfo(name='self', role=ProjectRole.COORDINATOR.name),
+ 'no connection': ParticipantInfo(name='no connection')
+ })
+ with db.session_scope() as session:
+ ids = PendingProjectService(session).get_ids_from_participants_info(participants_info)
+ self.assertCountEqual(ids, [1, 2])
+
+ def test_create_pending_project(self):
+ with db.session_scope() as session:
+ pending_project = PendingProjectService(session).create_pending_project(
+ name='test',
+ config=ProjectConfig(variables=[Variable(name='test')]),
+ participants_info=ParticipantsInfo(
+ participants_map={'self': ParticipantInfo(name='self', role=ProjectRole.COORDINATOR.name)}),
+ comment='test',
+ creator_username='test',
+ uuid='uuid')
+ session.commit()
+ with db.session_scope() as session:
+ result: PendingProject = session.query(PendingProject).get(pending_project.id)
+ self.assertEqual(result.get_config(), ProjectConfig(variables=[Variable(name='test')]))
+ self.assertEqual(
+ result.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={'self': ParticipantInfo(name='self', role=ProjectRole.COORDINATOR.name)}))
+ self.assertEqual(result.name, 'test')
+ self.assertEqual(result.uuid, 'uuid')
+
+ @patch('fedlearner_webconsole.project.services.PendingProjectService.duplicated_name_exists')
+ def test_update_state_as_participant(self, mock_dup):
+ pending_project = PendingProject(name='test', state=PendingProjectState.PENDING, role=ProjectRole.PARTICIPANT)
+
+ with db.session_scope() as session:
+ session.add(pending_project)
+ session.commit()
+
+ mock_dup.return_value = True
+ with db.session_scope() as session:
+ with self.assertRaises(ResourceConflictException):
+ PendingProjectService(session).update_state_as_participant(pending_project.id,
+ PendingProjectState.ACCEPTED.name)
+ with db.session_scope() as session:
+ PendingProjectService(session).update_state_as_participant(pending_project.id,
+ PendingProjectState.CLOSED.name)
+ session.commit()
+ with db.session_scope() as session:
+ result: PendingProject = session.query(PendingProject).get(pending_project.id)
+ self.assertEqual(result.state, PendingProjectState.CLOSED)
+
+ def test_create_project_locally(self):
+ with db.session_scope() as session:
+ pending_project = PendingProjectService(session).create_pending_project(
+ name='test',
+ config=ProjectConfig(variables=[Variable(name='test')]),
+ participants_info=ParticipantsInfo(
+ participants_map={'self': ParticipantInfo(name='self', role=ProjectRole.COORDINATOR.name)}),
+ comment='test',
+ creator_username='test',
+ uuid='uuid')
+ session.commit()
+
+ with db.session_scope() as session:
+ PendingProjectService(session).create_project_locally(pending_project.uuid)
+ session.commit()
+ with db.session_scope() as session:
+ project = session.query(Project).filter_by(name=pending_project.name, token=pending_project.uuid).first()
+ self.assertEqual(project.get_variables(), [
+ Variable(name='test'),
+ Variable(name='storage_root_path',
+ value=Envs.STORAGE_ROOT,
+ typed_value=Value(string_value=Envs.STORAGE_ROOT))
+ ])
+ pending_project = session.query(PendingProject).get(pending_project.id)
+ self.assertEqual(
+ project.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={'self': ParticipantInfo(name='self', role=ProjectRole.COORDINATOR.name)}))
+ self.assertEqual(project.creator, 'test')
+ self.assertEqual(pending_project.state, PendingProjectState.CLOSED)
+ self.assertEqual(project.comment, 'test')
+
+ def test_list_pending_projects(self):
+ pending_project = PendingProject(name='test', state=PendingProjectState.PENDING, role=ProjectRole.PARTICIPANT)
+ pending_project1 = PendingProject(name='test1',
+ state=PendingProjectState.ACCEPTED,
+ role=ProjectRole.COORDINATOR)
+ with db.session_scope() as session:
+ session.add(pending_project)
+ session.add(pending_project1)
+ session.commit()
+ with db.session_scope() as session:
+ result = PendingProjectService(session).list_pending_projects().get_items()
+ self.assertEqual(len(result), 2)
+ result = PendingProjectService(session).list_pending_projects(page=1, page_size=1).get_items()
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0].id, pending_project1.id)
+ exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='role',
+ op=FilterOp.EQUAL,
+ string_value='COORDINATOR'))
+ result = PendingProjectService(session).list_pending_projects(filter_exp=exp).get_items()
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0].id, pending_project1.id)
+
+ def test_duplicated_name_exists(self):
+ p = Project(name='test')
+ pending_p = PendingProject(name='test1', state=PendingProjectState.ACCEPTED)
+ pending_p2 = PendingProject(name='test2', state=PendingProjectState.CLOSED)
+ with db.session_scope() as session:
+ session.add_all([p, pending_p, pending_p2])
+ session.commit()
+ with db.session_scope() as session:
+ self.assertTrue(PendingProjectService(session).duplicated_name_exists(p.name))
+ self.assertTrue(PendingProjectService(session).duplicated_name_exists(pending_p.name))
+ self.assertFalse(PendingProjectService(session).duplicated_name_exists(pending_p2.name))
+ self.assertFalse(PendingProjectService(session).duplicated_name_exists('test0'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/common/extension_pb2.pyi b/web_console_v2/api/fedlearner_webconsole/proto/common/extension_pb2.pyi
new file mode 100644
index 000000000..bcc03e48f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/proto/common/extension_pb2.pyi
@@ -0,0 +1,9 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import google.protobuf.descriptor
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ...
+
+secret: google.protobuf.descriptor.FieldDescriptor = ...
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/jsonschemas/.gitignore b/web_console_v2/api/fedlearner_webconsole/proto/jsonschemas/.gitignore
new file mode 100644
index 000000000..5e7d2734c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/proto/jsonschemas/.gitignore
@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/job_service_pb2.pyi b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/job_service_pb2.pyi
new file mode 100644
index 000000000..0816a4b5d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/job_service_pb2.pyi
@@ -0,0 +1,304 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import builtins
+import fedlearner_webconsole.proto.dataset_pb2
+import fedlearner_webconsole.proto.mmgr_pb2
+import google.protobuf.descriptor
+import google.protobuf.message
+import typing
+import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ...
+
+class InformTrustedJobGroupRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ AUTH_STATUS_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ auth_status: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ auth_status : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"auth_status",b"auth_status",u"uuid",b"uuid"]) -> None: ...
+global___InformTrustedJobGroupRequest = InformTrustedJobGroupRequest
+
+class UpdateTrustedJobGroupRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ ALGORITHM_UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ algorithm_uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ algorithm_uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_uuid",b"algorithm_uuid",u"uuid",b"uuid"]) -> None: ...
+global___UpdateTrustedJobGroupRequest = UpdateTrustedJobGroupRequest
+
+class DeleteTrustedJobGroupRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___DeleteTrustedJobGroupRequest = DeleteTrustedJobGroupRequest
+
+class GetTrustedJobGroupRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___GetTrustedJobGroupRequest = GetTrustedJobGroupRequest
+
+class GetTrustedJobGroupResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ AUTH_STATUS_FIELD_NUMBER: builtins.int
+ auth_status: typing.Text = ...
+
+ def __init__(self,
+ *,
+ auth_status : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"auth_status",b"auth_status"]) -> None: ...
+global___GetTrustedJobGroupResponse = GetTrustedJobGroupResponse
+
+class InformTrustedJobRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ AUTH_STATUS_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ auth_status: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ auth_status : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"auth_status",b"auth_status",u"uuid",b"uuid"]) -> None: ...
+global___InformTrustedJobRequest = InformTrustedJobRequest
+
+class GetTrustedJobRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___GetTrustedJobRequest = GetTrustedJobRequest
+
+class GetTrustedJobResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ AUTH_STATUS_FIELD_NUMBER: builtins.int
+ auth_status: typing.Text = ...
+
+ def __init__(self,
+ *,
+ auth_status : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"auth_status",b"auth_status"]) -> None: ...
+global___GetTrustedJobResponse = GetTrustedJobResponse
+
+class CreateTrustedExportJobRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ NAME_FIELD_NUMBER: builtins.int
+ EXPORT_COUNT_FIELD_NUMBER: builtins.int
+ PARENT_UUID_FIELD_NUMBER: builtins.int
+ TICKET_UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ name: typing.Text = ...
+ export_count: builtins.int = ...
+ parent_uuid: typing.Text = ...
+ ticket_uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ name : typing.Text = ...,
+ export_count : builtins.int = ...,
+ parent_uuid : typing.Text = ...,
+ ticket_uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"export_count",b"export_count",u"name",b"name",u"parent_uuid",b"parent_uuid",u"ticket_uuid",b"ticket_uuid",u"uuid",b"uuid"]) -> None: ...
+global___CreateTrustedExportJobRequest = CreateTrustedExportJobRequest
+
+class CreateModelJobRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ NAME_FIELD_NUMBER: builtins.int
+ UUID_FIELD_NUMBER: builtins.int
+ GROUP_UUID_FIELD_NUMBER: builtins.int
+ MODEL_JOB_TYPE_FIELD_NUMBER: builtins.int
+ ALGORITHM_TYPE_FIELD_NUMBER: builtins.int
+ GLOBAL_CONFIG_FIELD_NUMBER: builtins.int
+ VERSION_FIELD_NUMBER: builtins.int
+ name: typing.Text = ...
+ uuid: typing.Text = ...
+ group_uuid: typing.Text = ...
+ model_job_type: typing.Text = ...
+ algorithm_type: typing.Text = ...
+ version: builtins.int = ...
+
+ @property
+ def global_config(self) -> fedlearner_webconsole.proto.mmgr_pb2.ModelJobGlobalConfig: ...
+
+ def __init__(self,
+ *,
+ name : typing.Text = ...,
+ uuid : typing.Text = ...,
+ group_uuid : typing.Text = ...,
+ model_job_type : typing.Text = ...,
+ algorithm_type : typing.Text = ...,
+ global_config : typing.Optional[fedlearner_webconsole.proto.mmgr_pb2.ModelJobGlobalConfig] = ...,
+ version : builtins.int = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"global_config",b"global_config"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_type",b"algorithm_type",u"global_config",b"global_config",u"group_uuid",b"group_uuid",u"model_job_type",b"model_job_type",u"name",b"name",u"uuid",b"uuid",u"version",b"version"]) -> None: ...
+global___CreateModelJobRequest = CreateModelJobRequest
+
+class CreateDatasetJobStageRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ DATASET_JOB_UUID_FIELD_NUMBER: builtins.int
+ DATASET_JOB_STAGE_UUID_FIELD_NUMBER: builtins.int
+ NAME_FIELD_NUMBER: builtins.int
+ EVENT_TIME_FIELD_NUMBER: builtins.int
+ dataset_job_uuid: typing.Text = ...
+ dataset_job_stage_uuid: typing.Text = ...
+ name: typing.Text = ...
+ event_time: builtins.int = ...
+
+ def __init__(self,
+ *,
+ dataset_job_uuid : typing.Text = ...,
+ dataset_job_stage_uuid : typing.Text = ...,
+ name : typing.Text = ...,
+ event_time : builtins.int = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"dataset_job_stage_uuid",b"dataset_job_stage_uuid",u"dataset_job_uuid",b"dataset_job_uuid",u"event_time",b"event_time",u"name",b"name"]) -> None: ...
+global___CreateDatasetJobStageRequest = CreateDatasetJobStageRequest
+
+class GetDatasetJobStageRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ DATASET_JOB_STAGE_UUID_FIELD_NUMBER: builtins.int
+ dataset_job_stage_uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ dataset_job_stage_uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"dataset_job_stage_uuid",b"dataset_job_stage_uuid"]) -> None: ...
+global___GetDatasetJobStageRequest = GetDatasetJobStageRequest
+
+class GetDatasetJobStageResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ DATASET_JOB_STAGE_FIELD_NUMBER: builtins.int
+
+ @property
+ def dataset_job_stage(self) -> fedlearner_webconsole.proto.dataset_pb2.DatasetJobStage: ...
+
+ def __init__(self,
+ *,
+ dataset_job_stage : typing.Optional[fedlearner_webconsole.proto.dataset_pb2.DatasetJobStage] = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"dataset_job_stage",b"dataset_job_stage"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"dataset_job_stage",b"dataset_job_stage"]) -> None: ...
+global___GetDatasetJobStageResponse = GetDatasetJobStageResponse
+
+class CreateModelJobGroupRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ NAME_FIELD_NUMBER: builtins.int
+ UUID_FIELD_NUMBER: builtins.int
+ ALGORITHM_TYPE_FIELD_NUMBER: builtins.int
+ DATASET_UUID_FIELD_NUMBER: builtins.int
+ ALGORITHM_PROJECT_LIST_FIELD_NUMBER: builtins.int
+ name: typing.Text = ...
+ uuid: typing.Text = ...
+ algorithm_type: typing.Text = ...
+ dataset_uuid: typing.Text = ...
+
+ @property
+ def algorithm_project_list(self) -> fedlearner_webconsole.proto.mmgr_pb2.AlgorithmProjectList: ...
+
+ def __init__(self,
+ *,
+ name : typing.Text = ...,
+ uuid : typing.Text = ...,
+ algorithm_type : typing.Text = ...,
+ dataset_uuid : typing.Text = ...,
+ algorithm_project_list : typing.Optional[fedlearner_webconsole.proto.mmgr_pb2.AlgorithmProjectList] = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"algorithm_project_list",b"algorithm_project_list"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_project_list",b"algorithm_project_list",u"algorithm_type",b"algorithm_type",u"dataset_uuid",b"dataset_uuid",u"name",b"name",u"uuid",b"uuid"]) -> None: ...
+global___CreateModelJobGroupRequest = CreateModelJobGroupRequest
+
+class GetModelJobRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___GetModelJobRequest = GetModelJobRequest
+
+class GetModelJobGroupRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___GetModelJobGroupRequest = GetModelJobGroupRequest
+
+class InformModelJobGroupRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ AUTH_STATUS_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ auth_status: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ auth_status : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"auth_status",b"auth_status",u"uuid",b"uuid"]) -> None: ...
+global___InformModelJobGroupRequest = InformModelJobGroupRequest
+
+class UpdateDatasetJobSchedulerStateRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ SCHEDULER_STATE_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ scheduler_state: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ scheduler_state : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"scheduler_state",b"scheduler_state",u"uuid",b"uuid"]) -> None: ...
+global___UpdateDatasetJobSchedulerStateRequest = UpdateDatasetJobSchedulerStateRequest
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/project_service_pb2.pyi b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/project_service_pb2.pyi
new file mode 100644
index 000000000..0352da8c0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/project_service_pb2.pyi
@@ -0,0 +1,149 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import builtins
+import fedlearner_webconsole.proto.project_pb2
+import fedlearner_webconsole.proto.workflow_definition_pb2
+import google.protobuf.descriptor
+import google.protobuf.internal.containers
+import google.protobuf.message
+import typing
+import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ...
+
+class CreatePendingProjectRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ NAME_FIELD_NUMBER: builtins.int
+ UUID_FIELD_NUMBER: builtins.int
+ PARTICIPANTS_INFO_FIELD_NUMBER: builtins.int
+ COMMENT_FIELD_NUMBER: builtins.int
+ CREATOR_USERNAME_FIELD_NUMBER: builtins.int
+ CONFIG_FIELD_NUMBER: builtins.int
+ TICKET_UUID_FIELD_NUMBER: builtins.int
+ name: typing.Text = ...
+ uuid: typing.Text = ...
+ comment: typing.Text = ...
+ creator_username: typing.Text = ...
+ ticket_uuid: typing.Text = ...
+
+ @property
+ def participants_info(self) -> fedlearner_webconsole.proto.project_pb2.ParticipantsInfo: ...
+
+ @property
+ def config(self) -> fedlearner_webconsole.proto.project_pb2.ProjectConfig: ...
+
+ def __init__(self,
+ *,
+ name : typing.Text = ...,
+ uuid : typing.Text = ...,
+ participants_info : typing.Optional[fedlearner_webconsole.proto.project_pb2.ParticipantsInfo] = ...,
+ comment : typing.Text = ...,
+ creator_username : typing.Text = ...,
+ config : typing.Optional[fedlearner_webconsole.proto.project_pb2.ProjectConfig] = ...,
+ ticket_uuid : typing.Text = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"config",b"config",u"participants_info",b"participants_info"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"comment",b"comment",u"config",b"config",u"creator_username",b"creator_username",u"name",b"name",u"participants_info",b"participants_info",u"ticket_uuid",b"ticket_uuid",u"uuid",b"uuid"]) -> None: ...
+global___CreatePendingProjectRequest = CreatePendingProjectRequest
+
+class UpdatePendingProjectRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ class ParticipantsMapEntry(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ KEY_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ key: typing.Text = ...
+
+ @property
+ def value(self) -> fedlearner_webconsole.proto.project_pb2.ParticipantInfo: ...
+
+ def __init__(self,
+ *,
+ key : typing.Text = ...,
+ value : typing.Optional[fedlearner_webconsole.proto.project_pb2.ParticipantInfo] = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"value",b"value"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"key",b"key",u"value",b"value"]) -> None: ...
+
+ UUID_FIELD_NUMBER: builtins.int
+ PARTICIPANTS_MAP_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ @property
+ def participants_map(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, fedlearner_webconsole.proto.project_pb2.ParticipantInfo]: ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ participants_map : typing.Optional[typing.Mapping[typing.Text, fedlearner_webconsole.proto.project_pb2.ParticipantInfo]] = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"participants_map",b"participants_map",u"uuid",b"uuid"]) -> None: ...
+global___UpdatePendingProjectRequest = UpdatePendingProjectRequest
+
+class SyncPendingProjectStateRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ STATE_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ state: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ state : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"state",b"state",u"uuid",b"uuid"]) -> None: ...
+global___SyncPendingProjectStateRequest = SyncPendingProjectStateRequest
+
+class CreateProjectRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___CreateProjectRequest = CreateProjectRequest
+
+class DeletePendingProjectRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___DeletePendingProjectRequest = DeletePendingProjectRequest
+
+class SendTemplateRevisionRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ CONFIG_FIELD_NUMBER: builtins.int
+ NAME_FIELD_NUMBER: builtins.int
+ COMMENT_FIELD_NUMBER: builtins.int
+ KIND_FIELD_NUMBER: builtins.int
+ REVISION_INDEX_FIELD_NUMBER: builtins.int
+ name: typing.Text = ...
+ comment: typing.Text = ...
+ kind: typing.Text = ...
+ revision_index: builtins.int = ...
+
+ @property
+ def config(self) -> fedlearner_webconsole.proto.workflow_definition_pb2.WorkflowDefinition: ...
+
+ def __init__(self,
+ *,
+ config : typing.Optional[fedlearner_webconsole.proto.workflow_definition_pb2.WorkflowDefinition] = ...,
+ name : typing.Text = ...,
+ comment : typing.Text = ...,
+ kind : typing.Text = ...,
+ revision_index : builtins.int = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"config",b"config"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"comment",b"comment",u"config",b"config",u"kind",b"kind",u"name",b"name",u"revision_index",b"revision_index"]) -> None: ...
+global___SendTemplateRevisionRequest = SendTemplateRevisionRequest
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/resource_service_pb2.pyi b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/resource_service_pb2.pyi
new file mode 100644
index 000000000..0b67b9ec3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/resource_service_pb2.pyi
@@ -0,0 +1,174 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import builtins
+import fedlearner_webconsole.proto.algorithm_pb2
+import fedlearner_webconsole.proto.dataset_pb2
+import fedlearner_webconsole.proto.filtering_pb2
+import google.protobuf.descriptor
+import google.protobuf.internal.containers
+import google.protobuf.message
+import typing
+import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ...
+
+class ListAlgorithmProjectsRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ FILTER_EXP_FIELD_NUMBER: builtins.int
+
+ @property
+ def filter_exp(self) -> fedlearner_webconsole.proto.filtering_pb2.FilterExpression: ...
+
+ def __init__(self,
+ *,
+ filter_exp : typing.Optional[fedlearner_webconsole.proto.filtering_pb2.FilterExpression] = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"filter_exp",b"filter_exp"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"filter_exp",b"filter_exp"]) -> None: ...
+global___ListAlgorithmProjectsRequest = ListAlgorithmProjectsRequest
+
+class ListAlgorithmProjectsResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ ALGORITHM_PROJECTS_FIELD_NUMBER: builtins.int
+
+ @property
+ def algorithm_projects(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[fedlearner_webconsole.proto.algorithm_pb2.AlgorithmProjectPb]: ...
+
+ def __init__(self,
+ *,
+ algorithm_projects : typing.Optional[typing.Iterable[fedlearner_webconsole.proto.algorithm_pb2.AlgorithmProjectPb]] = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_projects",b"algorithm_projects"]) -> None: ...
+global___ListAlgorithmProjectsResponse = ListAlgorithmProjectsResponse
+
+class ListAlgorithmsRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ ALGORITHM_PROJECT_UUID_FIELD_NUMBER: builtins.int
+ algorithm_project_uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ algorithm_project_uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_project_uuid",b"algorithm_project_uuid"]) -> None: ...
+global___ListAlgorithmsRequest = ListAlgorithmsRequest
+
+class ListAlgorithmsResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ ALGORITHMS_FIELD_NUMBER: builtins.int
+
+ @property
+ def algorithms(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[fedlearner_webconsole.proto.algorithm_pb2.AlgorithmPb]: ...
+
+ def __init__(self,
+ *,
+ algorithms : typing.Optional[typing.Iterable[fedlearner_webconsole.proto.algorithm_pb2.AlgorithmPb]] = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithms",b"algorithms"]) -> None: ...
+global___ListAlgorithmsResponse = ListAlgorithmsResponse
+
+class GetAlgorithmProjectRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ ALGORITHM_PROJECT_UUID_FIELD_NUMBER: builtins.int
+ algorithm_project_uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ algorithm_project_uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_project_uuid",b"algorithm_project_uuid"]) -> None: ...
+global___GetAlgorithmProjectRequest = GetAlgorithmProjectRequest
+
+class GetAlgorithmRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ ALGORITHM_UUID_FIELD_NUMBER: builtins.int
+ algorithm_uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ algorithm_uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_uuid",b"algorithm_uuid"]) -> None: ...
+global___GetAlgorithmRequest = GetAlgorithmRequest
+
+class GetAlgorithmFilesRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ ALGORITHM_UUID_FIELD_NUMBER: builtins.int
+ algorithm_uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ algorithm_uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"algorithm_uuid",b"algorithm_uuid"]) -> None: ...
+global___GetAlgorithmFilesRequest = GetAlgorithmFilesRequest
+
+class GetAlgorithmFilesResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ HASH_FIELD_NUMBER: builtins.int
+ CHUNK_FIELD_NUMBER: builtins.int
+ hash: typing.Text = ...
+ chunk: builtins.bytes = ...
+
+ def __init__(self,
+ *,
+ hash : typing.Text = ...,
+ chunk : builtins.bytes = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"chunk",b"chunk",u"hash",b"hash"]) -> None: ...
+global___GetAlgorithmFilesResponse = GetAlgorithmFilesResponse
+
+class InformDatasetRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ AUTH_STATUS_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ auth_status: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ auth_status : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"auth_status",b"auth_status",u"uuid",b"uuid"]) -> None: ...
+global___InformDatasetRequest = InformDatasetRequest
+
+class ListDatasetsRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ KIND_FIELD_NUMBER: builtins.int
+ STATE_FIELD_NUMBER: builtins.int
+ TIME_RANGE_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+ kind: typing.Text = ...
+ state: typing.Text = ...
+
+ @property
+ def time_range(self) -> fedlearner_webconsole.proto.dataset_pb2.TimeRange: ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ kind : typing.Text = ...,
+ state : typing.Text = ...,
+ time_range : typing.Optional[fedlearner_webconsole.proto.dataset_pb2.TimeRange] = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"time_range",b"time_range"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"kind",b"kind",u"state",b"state",u"time_range",b"time_range",u"uuid",b"uuid"]) -> None: ...
+global___ListDatasetsRequest = ListDatasetsRequest
+
+class ListDatasetsResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ PARTICIPANT_DATASETS_FIELD_NUMBER: builtins.int
+
+ @property
+ def participant_datasets(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[fedlearner_webconsole.proto.dataset_pb2.ParticipantDatasetRef]: ...
+
+ def __init__(self,
+ *,
+ participant_datasets : typing.Optional[typing.Iterable[fedlearner_webconsole.proto.dataset_pb2.ParticipantDatasetRef]] = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"participant_datasets",b"participant_datasets"]) -> None: ...
+global___ListDatasetsResponse = ListDatasetsResponse
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/review_service_pb2.pyi b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/review_service_pb2.pyi
new file mode 100644
index 000000000..43120bc48
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/review_service_pb2.pyi
@@ -0,0 +1,45 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import builtins
+import fedlearner_webconsole.proto.review_pb2
+import google.protobuf.descriptor
+import google.protobuf.message
+import typing
+import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ...
+
+class CreateTicketRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ TTYPE_FIELD_NUMBER: builtins.int
+ CREATOR_USERNAME_FIELD_NUMBER: builtins.int
+ DETAILS_FIELD_NUMBER: builtins.int
+ ttype: fedlearner_webconsole.proto.review_pb2.TicketType.V = ...
+ creator_username: typing.Text = ...
+
+ @property
+ def details(self) -> fedlearner_webconsole.proto.review_pb2.TicketDetails: ...
+
+ def __init__(self,
+ *,
+ ttype : fedlearner_webconsole.proto.review_pb2.TicketType.V = ...,
+ creator_username : typing.Text = ...,
+ details : typing.Optional[fedlearner_webconsole.proto.review_pb2.TicketDetails] = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"details",b"details"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"creator_username",b"creator_username",u"details",b"details",u"ttype",b"ttype"]) -> None: ...
+global___CreateTicketRequest = CreateTicketRequest
+
+class GetTicketRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ UUID_FIELD_NUMBER: builtins.int
+ uuid: typing.Text = ...
+
+ def __init__(self,
+ *,
+ uuid : typing.Text = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"uuid",b"uuid"]) -> None: ...
+global___GetTicketRequest = GetTicketRequest
diff --git a/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/system_service_pb2.pyi b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/system_service_pb2.pyi
new file mode 100644
index 000000000..fd4c5eca9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/proto/rpc/v2/system_service_pb2.pyi
@@ -0,0 +1,82 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import builtins
+import fedlearner_webconsole.proto.common_pb2
+import google.protobuf.descriptor
+import google.protobuf.message
+import google.protobuf.struct_pb2
+import typing
+import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor = ...
+
+class CheckHealthRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+
+ def __init__(self,
+ ) -> None: ...
+global___CheckHealthRequest = CheckHealthRequest
+
+class CheckHealthResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ APPLICATION_VERSION_FIELD_NUMBER: builtins.int
+ HEALTHY_FIELD_NUMBER: builtins.int
+ MESSAGE_FIELD_NUMBER: builtins.int
+ healthy: builtins.bool = ...
+ message: typing.Text = ...
+
+ @property
+ def application_version(self) -> fedlearner_webconsole.proto.common_pb2.ApplicationVersion: ...
+
+ def __init__(self,
+ *,
+ application_version : typing.Optional[fedlearner_webconsole.proto.common_pb2.ApplicationVersion] = ...,
+ healthy : builtins.bool = ...,
+ message : typing.Text = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"application_version",b"application_version"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"application_version",b"application_version",u"healthy",b"healthy",u"message",b"message"]) -> None: ...
+global___CheckHealthResponse = CheckHealthResponse
+
+class ListFlagsRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+
+ def __init__(self,
+ ) -> None: ...
+global___ListFlagsRequest = ListFlagsRequest
+
+class ListFlagsResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ FLAGS_FIELD_NUMBER: builtins.int
+
+ @property
+ def flags(self) -> google.protobuf.struct_pb2.Struct: ...
+
+ def __init__(self,
+ *,
+ flags : typing.Optional[google.protobuf.struct_pb2.Struct] = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal[u"flags",b"flags"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"flags",b"flags"]) -> None: ...
+global___ListFlagsResponse = ListFlagsResponse
+
+class CheckTeeEnabledRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+
+ def __init__(self,
+ ) -> None: ...
+global___CheckTeeEnabledRequest = CheckTeeEnabledRequest
+
+class CheckTeeEnabledResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor = ...
+ TEE_ENABLED_FIELD_NUMBER: builtins.int
+ tee_enabled: builtins.bool = ...
+
+ def __init__(self,
+ *,
+ tee_enabled : builtins.bool = ...,
+ ) -> None: ...
+ def ClearField(self, field_name: typing_extensions.Literal[u"tee_enabled",b"tee_enabled"]) -> None: ...
+global___CheckTeeEnabledResponse = CheckTeeEnabledResponse
diff --git a/web_console_v2/api/fedlearner_webconsole/review/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/review/BUILD.bazel
new file mode 100644
index 000000000..9c1c6149c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/review/BUILD.bazel
@@ -0,0 +1,86 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "common_lib",
+ srcs = [
+ "common.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_library(
+ name = "ticket_helper_lib",
+ srcs = [
+ "ticket_helper.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:review_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "ticket_helper_lib_test",
+ srcs = [
+ "ticket_helper_test.py",
+ ],
+ imports = ["../.."],
+ main = "ticket_helper_test.py",
+ deps = [
+ ":ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "runners_lib",
+ srcs = [
+ "runners.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":common_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:review_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "runners_test",
+ size = "small",
+ srcs = [
+ "runners_test.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":runners_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/review/common.py b/web_console_v2/api/fedlearner_webconsole/review/common.py
new file mode 100644
index 000000000..924cf7204
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/review/common.py
@@ -0,0 +1,33 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from fedlearner_webconsole.project.models import PendingProject
+from fedlearner_webconsole.proto import review_pb2
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.mmgr.models import ModelJobGroup
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob
+
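+# Maps each ticket type to the ORM model whose ticket_status / ticket_uuid columns track the review.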
+REVIEW_ORM_MAPPER = {
+ review_pb2.TicketType.CREATE_PARTICIPANT: Participant,
+ review_pb2.TicketType.CREATE_PROJECT: PendingProject,
+ review_pb2.TicketType.PUBLISH_DATASET: Dataset,
+ review_pb2.TicketType.CREATE_PROCESSED_DATASET: Dataset,
+ review_pb2.TicketType.CREATE_MODELJOB_GROUP: ModelJobGroup,
+ review_pb2.TicketType.TK_CREATE_TRUSTED_JOB_GROUP: TrustedJobGroup,
+ review_pb2.TicketType.TK_CREATE_TRUSTED_EXPORT_JOB: TrustedJob,
+}
+
+NO_CENTRAL_SERVER_UUID = 'no_central_server_uuid'
diff --git a/web_console_v2/api/fedlearner_webconsole/review/runners.py b/web_console_v2/api/fedlearner_webconsole/review/runners.py
new file mode 100644
index 000000000..138efad93
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/review/runners.py
@@ -0,0 +1,70 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+from fedlearner_webconsole.db import db
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.review.common import REVIEW_ORM_MAPPER
+from fedlearner_webconsole.rpc.v2.review_service_client import ReviewServiceClient
+from fedlearner_webconsole.proto import review_pb2
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, TicketHelperOutput
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+
+
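+# Composer runner that syncs pending ticket statuses from the central review service.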
+class TicketHelperRunner(IRunnerV2):
+
+ def __init__(self, domain_name: str):
+ self._domain_name = domain_name
+
+ def _update_ticket_status(self, resource) -> bool:
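+        # Fetch the latest ticket state from the review center; persist it and return True once it is
+        # terminal (approved or declined), otherwise leave the resource pending and return False.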
+ uuid = getattr(resource, 'ticket_uuid')
+ client = ReviewServiceClient.from_participant(domain_name=self._domain_name)
+ resp = client.get_ticket(uuid)
+ status = TicketStatus(review_pb2.ReviewStatus.Name(resp.status))
+ if status in [TicketStatus.APPROVED, TicketStatus.DECLINED]:
+ setattr(resource, 'ticket_status', status)
+ return True
+
+ return False
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ del context
+
+ ticket_output = TicketHelperOutput()
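+        # For every ticket-backed resource type, poll its pending resources and bucket their ids by outcome.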
+ for ttype, orm in REVIEW_ORM_MAPPER.items():
+ ttype_name = review_pb2.TicketType.Name(ttype)
+ updated_ticket = ticket_output.updated_ticket[ttype_name].ids
+ unupdated_ticket = ticket_output.unupdated_ticket[ttype_name].ids
+ failed_ticket = ticket_output.failed_ticket[ttype_name].ids
+ with db.session_scope() as session:
+ resources = session.query(orm).filter_by(ticket_status=TicketStatus.PENDING).all()
+ for resource in resources:
+ try:
+ if self._update_ticket_status(resource):
+ updated_ticket.append(resource.id)
+ else:
+ unupdated_ticket.append(resource.id)
+ except Exception: # pylint: disable=broad-except
+ failed_ticket.append(resource.id)
+ session.commit()
+            logging.info(f'ticket routine for {ttype_name}:')
+            logging.info(f'    updated_ticket {updated_ticket}')
+            logging.info(f'    unupdated_ticket {unupdated_ticket}')
+            logging.info(f'    failed_ticket {failed_ticket}')
+
+ return (RunnerStatus.DONE, RunnerOutput(ticket_helper_output=ticket_output))
diff --git a/web_console_v2/api/fedlearner_webconsole/review/runners_test.py b/web_console_v2/api/fedlearner_webconsole/review/runners_test.py
new file mode 100644
index 000000000..230228eeb
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/review/runners_test.py
@@ -0,0 +1,75 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock, call
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto import review_pb2
+from fedlearner_webconsole.review.runners import TicketHelperRunner
+from fedlearner_webconsole.composer.context import RunnerContext, RunnerInput
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+
+
+class TicketHelperRunnerTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.review.runners.ReviewServiceClient')
+ def test_run(self, mock_review_service_client: MagicMock):
+ with db.session_scope() as session:
+ session.add(
+ Participant(id=111,
+ name='p1',
+ domain_name='fl-test1.com',
+ ticket_status=TicketStatus.PENDING,
+ ticket_uuid='u12345'))
+ session.add(
+ Participant(id=222,
+ name='p2',
+ domain_name='fl-test2.com',
+ ticket_status=TicketStatus.APPROVED,
+ ticket_uuid='u22345'))
+ session.add(
+ Participant(id=333,
+ name='p3',
+ domain_name='fl-test3.com',
+ ticket_status=TicketStatus.DECLINED,
+ ticket_uuid='u32345'))
+ session.add(
+ Participant(id=444,
+ name='p1',
+ domain_name='fl-test4.com',
+ ticket_status=TicketStatus.PENDING,
+ ticket_uuid='u42345'))
+ session.commit()
+
+ client = MagicMock()
+ mock_review_service_client.from_participant.return_value = client
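+        # Only the two PENDING participants (ids 111 and 444) are polled, in id order.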
+ client.get_ticket.side_effect = [
+ review_pb2.Ticket(status=review_pb2.ReviewStatus.APPROVED),
+ review_pb2.Ticket(status=review_pb2.ReviewStatus.PENDING)
+ ]
+
+ runner = TicketHelperRunner(domain_name='fl-central.com')
+ _, output = runner.run(RunnerContext(0, RunnerInput()))
+
+ mock_review_service_client.from_participant.assert_called_with(domain_name='fl-central.com')
+ client.get_ticket.assert_has_calls(calls=[call('u12345'), call('u42345')])
+ self.assertEqual(output.ticket_helper_output.updated_ticket['CREATE_PARTICIPANT'].ids, [111])
+ self.assertEqual(output.ticket_helper_output.unupdated_ticket['CREATE_PARTICIPANT'].ids, [444])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/review/ticket_helper.py b/web_console_v2/api/fedlearner_webconsole/review/ticket_helper.py
new file mode 100644
index 000000000..e610d39d0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/review/ticket_helper.py
@@ -0,0 +1,122 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC, abstractmethod
+import json
+from typing import Callable
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.review import common
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.proto import review_pb2
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.rpc.v2.review_service_client import ReviewServiceClient
+from fedlearner_webconsole.utils.flask_utils import get_current_user
+
+
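+# Ticket helpers abstract the review flow: one implementation talks to a central review service,
+# the other auto-approves tickets locally when no review center is configured.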
+class ITicketHelper(ABC):
+
+ @abstractmethod
+ def create_ticket(self, ticket_type: review_pb2.TicketType, details: review_pb2.TicketDetails) -> review_pb2.Ticket:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_ticket(self, uuid: str) -> review_pb2.Ticket:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def validate_ticket(self, uuid: str, validate_fn: Callable[[review_pb2.Ticket], bool]) -> bool:
+ raise NotImplementedError()
+
+
+def _get_model_from_ticket_type(ticket_type: review_pb2.TicketType) -> db.Model:
+ model = common.REVIEW_ORM_MAPPER.get(ticket_type)
+ if model is None:
+ raise InvalidArgumentException(details=f'failed to get orm.Model for {review_pb2.TicketType.Name(ticket_type)}')
+ return model
+
+
+class NoCenterServerTicketHelper(ITicketHelper):
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def create_ticket(self, ticket_type: review_pb2.TicketType, details: review_pb2.TicketDetails) -> review_pb2.Ticket:
+ model = _get_model_from_ticket_type(ticket_type)
+        # TODO(wangsen.0914): support filtering related resources by fields other than uuid.
+ resource = self._session.query(model).filter_by(uuid=details.uuid).first()
+ if resource is None:
+            raise InvalidArgumentException(details=f'failed to get resource with uuid {details.uuid}')
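+        # No review center configured: auto-approve the ticket and tag the resource with the sentinel uuid.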
+ resource.ticket_status = TicketStatus(review_pb2.ReviewStatus.Name(review_pb2.ReviewStatus.APPROVED))
+ resource.ticket_uuid = common.NO_CENTRAL_SERVER_UUID
+ return review_pb2.Ticket(type=ticket_type,
+ details=details,
+ uuid=common.NO_CENTRAL_SERVER_UUID,
+ status=review_pb2.ReviewStatus.APPROVED,
+ review_strategy=review_pb2.ReviewStrategy.AUTO)
+
+ def get_ticket(self, uuid: str) -> review_pb2.Ticket:
+ return review_pb2.Ticket(uuid=uuid,
+ type=review_pb2.TicketType.UNKOWN_TYPE,
+ status=review_pb2.ReviewStatus.APPROVED,
+ review_strategy=review_pb2.ReviewStrategy.AUTO)
+
+ def validate_ticket(self, uuid: str, validate_fn: Callable[[review_pb2.Ticket], bool]) -> bool:
+ del validate_fn # ignore validate_fn, because center server is not configured.
+
+        return uuid == common.NO_CENTRAL_SERVER_UUID
+
+
+class CenterServerTicketHelper(ITicketHelper):
+
+ def __init__(self, session: Session, domain_name: str):
+ self._session = session
+ self._domain_name = domain_name
+
+ def create_ticket(self, ticket_type: review_pb2.TicketType, details: review_pb2.TicketDetails) -> review_pb2.Ticket:
+ model = _get_model_from_ticket_type(ticket_type)
+        # TODO(wangsen.0914): support filtering related resources by fields other than uuid.
+ resource = self._session.query(model).filter_by(uuid=details.uuid).first()
+ if resource is None:
+            raise InvalidArgumentException(details=f'failed to get resource with uuid {details.uuid}')
+ client = ReviewServiceClient.from_participant(domain_name=self._domain_name)
+ current_user = get_current_user()
+ creator_username = current_user.username if current_user else None
+ ticket = client.create_ticket(ticket_type, creator_username, details)
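+        # Mirror the review center's decision onto the local resource so later runs can poll for updates.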
+ resource.ticket_status = TicketStatus(review_pb2.ReviewStatus.Name(ticket.status))
+ resource.ticket_uuid = ticket.uuid
+ return ticket
+
+ def get_ticket(self, uuid: str) -> review_pb2.Ticket:
+ client = ReviewServiceClient.from_participant(domain_name=self._domain_name)
+ return client.get_ticket(uuid)
+
+ def validate_ticket(self, uuid: str, validate_fn: Callable[[review_pb2.Ticket], bool]) -> bool:
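+        # Valid only when the review center recognizes the uuid and the caller-provided predicate passes.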
+ ticket = self.get_ticket(uuid)
+ if ticket.uuid != uuid:
+ return False
+ return validate_fn(ticket)
+
+
+def get_ticket_helper(session: Session) -> ITicketHelper:
+ configuration = json.loads(Flag.REVIEW_CENTER_CONFIGURATION.value)
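+    # An empty REVIEW_CENTER_CONFIGURATION flag means no review center is deployed.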
+ if not configuration:
+ return NoCenterServerTicketHelper(session)
+ domain_name = configuration['domain_name']
+ return CenterServerTicketHelper(session, domain_name)
diff --git a/web_console_v2/api/fedlearner_webconsole/review/ticket_helper_test.py b/web_console_v2/api/fedlearner_webconsole/review/ticket_helper_test.py
new file mode 100644
index 000000000..a9d4dd6e4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/review/ticket_helper_test.py
@@ -0,0 +1,170 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+from unittest.mock import call, patch, MagicMock
+from fedlearner_webconsole.exceptions import InvalidArgumentException
+from fedlearner_webconsole.review.common import NO_CENTRAL_SERVER_UUID
+from fedlearner_webconsole.review.ticket_helper import (NoCenterServerTicketHelper, CenterServerTicketHelper,
+ get_ticket_helper)
+from fedlearner_webconsole.proto import review_pb2
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.utils.base_model.review_ticket_model import ReviewTicketModel, TicketStatus
+from testing.common import NoWebServerTestCase
+
+
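+# Minimal ticket-backed model used to patch REVIEW_ORM_MAPPER in these tests.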
+class FakeModel(db.Model, ReviewTicketModel):
+ __tablename__ = 'fake_model'
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ uuid = db.Column(db.String(255), nullable=True, comment='uuid')
+
+
+class NoCenterServerTicketHelperTest(NoWebServerTestCase):
+
+ def test_get_ticket(self):
+ with db.session_scope() as session:
+ ticket = NoCenterServerTicketHelper(session).get_ticket(uuid='u12345')
+ self.assertEqual(
+ ticket,
+ review_pb2.Ticket(uuid='u12345',
+ type=review_pb2.TicketType.UNKOWN_TYPE,
+ status=review_pb2.ReviewStatus.APPROVED,
+ review_strategy=review_pb2.ReviewStrategy.AUTO))
+
+ @patch('fedlearner_webconsole.review.common.REVIEW_ORM_MAPPER',
+ {review_pb2.TicketType.CREATE_PARTICIPANT: FakeModel})
+ def test_create_ticket(self):
+ with db.session_scope() as session:
+ self.assertRaises(InvalidArgumentException,
+ NoCenterServerTicketHelper(session).create_ticket,
+ ticket_type=review_pb2.TicketType.CREATE_NODE,
+ details=review_pb2.TicketDetails(uuid='u1234'))
+
+ with db.session_scope() as session:
+ fake_data = FakeModel(uuid='u1234')
+ session.add(fake_data)
+ session.flush()
+ ticket = NoCenterServerTicketHelper(session).create_ticket(
+ ticket_type=review_pb2.TicketType.CREATE_PARTICIPANT, details=review_pb2.TicketDetails(uuid='u1234'))
+ session.flush()
+ self.assertEqual(fake_data.ticket_status, TicketStatus.APPROVED)
+ session.commit()
+ with db.session_scope() as session:
+ resource = session.query(FakeModel).filter_by(uuid='u1234').first()
+ self.assertEqual(resource.ticket_status, TicketStatus.APPROVED)
+ self.assertEqual(resource.ticket_uuid, NO_CENTRAL_SERVER_UUID)
+ self.assertEqual(
+ ticket,
+ review_pb2.Ticket(uuid=NO_CENTRAL_SERVER_UUID,
+ type=review_pb2.TicketType.CREATE_PARTICIPANT,
+ details=review_pb2.TicketDetails(uuid='u1234'),
+ status=review_pb2.ReviewStatus.APPROVED,
+ review_strategy=review_pb2.ReviewStrategy.AUTO))
+
+ def test_validate_ticket(self):
+ with db.session_scope() as session:
+ # validate_fn is ignored because no center server is configured.
+ validate_fn = lambda _: False
+ self.assertFalse(NoCenterServerTicketHelper(session).validate_ticket('u1234', validate_fn))
+ self.assertTrue(NoCenterServerTicketHelper(session).validate_ticket(NO_CENTRAL_SERVER_UUID, validate_fn))
+
+
+class CenterServerTicketHelperTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.review.ticket_helper.ReviewServiceClient')
+ def test_validate_ticket(self, mock_review_service_client: MagicMock):
+ client = MagicMock()
+ mock_review_service_client.from_participant.return_value = client
+ client.get_ticket.side_effect = [
+ review_pb2.Ticket(uuid='u1234',
+ status=review_pb2.ReviewStatus.APPROVED,
+ details=review_pb2.TicketDetails(uuid='u1234')),
+ review_pb2.Ticket(uuid='u234',
+ status=review_pb2.ReviewStatus.APPROVED,
+ details=review_pb2.TicketDetails(uuid='u2345')),
+ review_pb2.Ticket(uuid='u1234',
+ status=review_pb2.ReviewStatus.APPROVED,
+ details=review_pb2.TicketDetails(uuid='u2345')),
+ ]
+
+ with db.session_scope() as session:
+ validate_fn = lambda t: t.details.uuid == 'u2345'
+ self.assertFalse(CenterServerTicketHelper(session, 'fl-central.com').validate_ticket('u1234', validate_fn))
+ self.assertFalse(CenterServerTicketHelper(session, 'fl-central.com').validate_ticket('u1234', validate_fn))
+ self.assertTrue(CenterServerTicketHelper(session, 'fl-central.com').validate_ticket('u1234', validate_fn))
+
+ mock_review_service_client.from_participant.assert_called_with(domain_name='fl-central.com')
+ client.get_ticket.assert_has_calls(calls=[call('u1234'), call('u1234'), call('u1234')])
+
+ @patch('fedlearner_webconsole.review.common.REVIEW_ORM_MAPPER',
+ {review_pb2.TicketType.CREATE_PARTICIPANT: FakeModel})
+ @patch('fedlearner_webconsole.review.ticket_helper.get_current_user', lambda: User(username='creator'))
+ @patch('fedlearner_webconsole.review.ticket_helper.ReviewServiceClient')
+ def test_create_ticket(self, mock_review_service_client: MagicMock):
+ with db.session_scope() as session:
+ self.assertRaises(InvalidArgumentException,
+ CenterServerTicketHelper(session, 'fl-central').create_ticket,
+ ticket_type=review_pb2.TicketType.CREATE_NODE,
+ details=review_pb2.TicketDetails(uuid='u1234'))
+
+ client = MagicMock()
+ mock_review_service_client.from_participant.return_value = client
+ client.create_ticket.return_value = review_pb2.Ticket(uuid='u4321',
+ status=review_pb2.ReviewStatus.PENDING,
+ details=review_pb2.TicketDetails(uuid='u1234'))
+
+ with db.session_scope() as session:
+ fake_data = FakeModel(uuid='u1234')
+ session.add(fake_data)
+ session.flush()
+ CenterServerTicketHelper(session, 'fl-central.com').create_ticket(
+ ticket_type=review_pb2.TicketType.CREATE_PARTICIPANT, details=review_pb2.TicketDetails(uuid='u1234'))
+ session.flush()
+ self.assertEqual(fake_data.ticket_status, TicketStatus.PENDING)
+ session.commit()
+
+ with db.session_scope() as session:
+ resource = session.query(FakeModel).filter_by(uuid='u1234').first()
+ self.assertEqual(resource.ticket_status, TicketStatus.PENDING)
+ self.assertEqual(resource.ticket_uuid, 'u4321')
+
+ mock_review_service_client.from_participant.assert_called_with(domain_name='fl-central.com')
+ self.assertEqual([call[0][1] for call in client.create_ticket.call_args_list][0], 'creator')
+
+
+class GetTicketHelperTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.flag.models.Flag.REVIEW_CENTER_CONFIGURATION.value', '{}')
+ def test_no_center_server(self):
+ with db.session_scope() as session:
+ self.assertIsInstance(get_ticket_helper(session), NoCenterServerTicketHelper)
+
+ @patch('fedlearner_webconsole.flag.models.Flag.REVIEW_CENTER_CONFIGURATION.value',
+ '{"domain_name": "fl-central.com"}')
+ def test_with_center_server(self):
+ with db.session_scope() as session:
+ self.assertIsInstance(get_ticket_helper(session), CenterServerTicketHelper)
+
+ @patch('fedlearner_webconsole.flag.models.Flag.REVIEW_CENTER_CONFIGURATION.value', '{"dom_name": "fl-central.com"}')
+ def test_with_invalid_center_server(self):
+ with db.session_scope() as session:
+ self.assertRaises(KeyError, get_ticket_helper, session)
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/rpc/BUILD.bazel
new file mode 100644
index 000000000..fac35558a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/BUILD.bazel
@@ -0,0 +1,152 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "auth_lib",
+ srcs = ["auth.py"],
+ imports = ["../.."],
+)
+
+py_test(
+ name = "auth_lib_test",
+ srcs = [
+ "auth_test.py",
+ ],
+ imports = ["../../.."],
+ main = "auth_test.py",
+ deps = [
+ ":auth_lib",
+ ],
+)
+
+py_library(
+ name = "client_interceptor_lib",
+ srcs = ["client_interceptor.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/middleware:request_id_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_library(
+ name = "client_lib",
+ srcs = ["client.py"],
+ imports = ["../.."],
+ deps = [
+ ":client_interceptor_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:client_base_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "client_lib_test",
+ # TODO(liuhehan): change it back to small when the dataset model is lightweight enough.
+ size = "medium",
+ srcs = [
+ "client_test.py",
+ ],
+ imports = ["../../.."],
+ main = "client_test.py",
+ deps = [
+ ":client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:request_id_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "server_lib",
+ srcs = ["server.py"],
+ imports = ["../.."],
+ deps = [
+ ":auth_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:request_id_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:auth_server_interceptor_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:project_service_server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:resource_service_server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/serving:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:handlers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:es_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:kibana_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_grpcio_reflection//:pkg",
+ ],
+)
+
+py_test(
+ name = "server_lib_test",
+ size = "small",
+ srcs = [
+ "server_test.py",
+ ],
+ imports = ["../../.."],
+ main = "server_test.py",
+ deps = [
+ ":auth_lib",
+ ":server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:request_id_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/auth.py b/web_console_v2/api/fedlearner_webconsole/rpc/auth.py
new file mode 100644
index 000000000..93412c039
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/auth.py
@@ -0,0 +1,41 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+X_HOST_HEADER = 'x-host'
+SSL_CLIENT_SUBJECT_DN_HEADER = 'ssl-client-subject-dn'
+PROJECT_NAME_HEADER = 'project-name'
+
+
+def get_common_name(subject_dn: str) -> Optional[str]:
+ """Gets common name from x.509
+
+ Args:
+ subject_dn (str): ssl-client-subject-dn from header
+
+ Returns:
+ Optional[str]: common name if exists
+ """
+
+ # ssl-client-subject-dn example:
+ # CN=*.fl-xxx.com,OU=security,O=security,L=beijing,ST=beijing,C=CN
+ for s in subject_dn.split(','):
+ if s.find('=') == -1:
+ return None
+ k, v = s.split('=', maxsplit=1)
+ if k == 'CN':
+ return v
+ return None
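+
+
+# Illustrative note (editor's sketch, not part of the original change): the CN is the
+# value of the first 'CN=' component in the comma-separated subject DN, e.g.
+# 'CN=*.fl-xxx.com,OU=security,C=CN' -> '*.fl-xxx.com'; a component without '=' makes
+# the function return None.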
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/auth_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/auth_test.py
new file mode 100644
index 000000000..59c3b6f98
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/auth_test.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.rpc.auth import get_common_name
+
+
+class AuthTest(unittest.TestCase):
+
+ def test_get_common_name(self):
+ self.assertIsNone(get_common_name('invalid'))
+ self.assertIsNone(get_common_name('CN*.fl-xxx.com,C=CN'))
+ self.assertEqual(get_common_name('CN=*.fl-xxx.com,OU=security,O=security,L=beijing,ST=beijing,C=CN'),
+ '*.fl-xxx.com')
+ self.assertEqual(get_common_name('CN=aaa.fedlearner.net,OU=security,O=security,L=beijing,ST=beijing,C=CN'),
+ 'aaa.fedlearner.net')
+ self.assertEqual(get_common_name('CN==*.fl-xxx.com,C=CN'), '=*.fl-xxx.com')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/client.py b/web_console_v2/api/fedlearner_webconsole/rpc/client.py
index 726568ad6..829778325 100644
--- a/web_console_v2/api/fedlearner_webconsole/rpc/client.py
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/client.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,159 +17,212 @@
import logging
from functools import wraps
+from typing import Optional
import grpc
+from google.protobuf import empty_pb2
from envs import Envs
-from fedlearner_webconsole.exceptions import (
- UnauthorizedException, InvalidArgumentException
-)
-from fedlearner_webconsole.proto import (
- service_pb2, service_pb2_grpc, common_pb2
-)
-from fedlearner_webconsole.utils.decorators import retry_fn
-
-
-def _build_channel(url, authority):
- """A helper function to build gRPC channel for easy testing."""
- return grpc.insecure_channel(
- target=url,
+from fedlearner_webconsole.utils.decorators.lru_cache import lru_cache
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+from fedlearner_webconsole.exceptions import (UnauthorizedException, InvalidArgumentException)
+from fedlearner_webconsole.proto import (dataset_pb2, service_pb2, service_pb2_grpc, common_pb2)
+from fedlearner_webconsole.proto.service_pb2_grpc import WebConsoleV2ServiceStub
+from fedlearner_webconsole.proto.serving_pb2 import ServingServiceType
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TwoPcAction, TransactionData
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.rpc.client_interceptor import ClientInterceptor
+from fedlearner_webconsole.rpc.v2.client_base import get_nginx_controller_url
+
+
+@lru_cache(timeout=60, maxsize=100)
+def _build_grpc_stub(egress_url: str, authority: str) -> WebConsoleV2ServiceStub:
+ """A helper function to build gRPC stub with cache.
+
+ Note that because the stub is cached, the channel may break if the nginx controller gets restarted.
+ Caching channels/stubs this way follows the official gRPC performance guidance: https://grpc.io/docs/guides/performance/
+
+ Args:
+ egress_url: nginx controller url in current cluster.
+ authority: ingress domain in current cluster.
+
+ Returns:
+ A grpc service stub to call API.
+ """
+ channel = grpc.insecure_channel(
+ target=egress_url,
# options defined at
# https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h
options=[('grpc.default_authority', authority)])
+ channel = grpc.intercept_channel(channel, ClientInterceptor())
+ return service_pb2_grpc.WebConsoleV2ServiceStub(channel)
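+
+
+# Illustrative note (editor's sketch, not part of the original change): thanks to
+# @lru_cache(timeout=60, maxsize=100) above, repeated calls with the same
+# (egress_url, authority) pair reuse one channel/stub; once the 60-second cache entry
+# expires, a fresh channel is built (see test_build_grpc_stub in client_test.py).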
+# TODO(linfan.fine): refactor catch_and_fallback
def catch_and_fallback(resp_class):
+
def decorator(f):
+
@wraps(f)
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except grpc.RpcError as e:
- return resp_class(status=common_pb2.Status(
- code=common_pb2.STATUS_UNKNOWN_ERROR, msg=repr(e)))
+ return resp_class(status=common_pb2.Status(code=common_pb2.STATUS_UNKNOWN_ERROR, msg=repr(e)))
return wrapper
return decorator
+def _need_retry_for_get(err: Exception) -> bool:
+ if not isinstance(err, grpc.RpcError):
+ return False
+ # No need to retry for NOT_FOUND
+ return err.code() != grpc.StatusCode.NOT_FOUND
+
+
+def _default_need_retry(err: Exception) -> bool:
+ return isinstance(err, grpc.RpcError)
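+
+
+# Illustrative note (editor's sketch, not part of the original change): combined with
+# @retry_fn(retry_times=3, ...), _default_need_retry retries any grpc.RpcError, while
+# _need_retry_for_get lets NOT_FOUND surface immediately so that "get"-style calls do
+# not retry on a missing resource.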
+
+
class RpcClient(object):
- def __init__(self, project_config, receiver_config):
- self._project = project_config
- self._receiver = receiver_config
- self._auth_info = service_pb2.ProjAuthInfo(
- project_name=self._project.name,
- target_domain=self._receiver.domain_name,
- auth_token=self._project.token)
-
- egress_url = 'fedlearner-stack-ingress-nginx-controller.default.svc:80'
- for variable in self._project.variables:
- if variable.name == 'EGRESS_URL':
- egress_url = variable.value
- break
- self._client = service_pb2_grpc.WebConsoleV2ServiceStub(
- _build_channel(egress_url, self._receiver.grpc_spec.authority))
+
+ def __init__(self,
+ egress_url: str,
+ authority: str,
+ x_host: str,
+ project_auth_info: Optional[service_pb2.ProjAuthInfo] = None):
+ """Inits rpc client.
+
+ Args:
+ egress_url: nginx controller url in current cluster.
+ authority: ingress domain in current cluster.
+ x_host: ingress domain in target cluster, nginx will handle the
+ rewriting.
+ project_auth_info: info for project level authentication.
+ """
+ self._x_host = x_host
+ self._project_auth_info = project_auth_info
+
+ self._client = _build_grpc_stub(egress_url, authority)
+
+ @classmethod
+ def from_project_and_participant(cls, project_name: str, project_token: str, domain_name: str):
+ # Builds auth info from project and receiver
+ auth_info = service_pb2.ProjAuthInfo(project_name=project_name,
+ target_domain=domain_name,
+ auth_token=project_token)
+ return cls(egress_url=get_nginx_controller_url(),
+ authority=gen_egress_authority(domain_name),
+ x_host=gen_x_host(domain_name),
+ project_auth_info=auth_info)
+
+ @classmethod
+ def from_participant(cls, domain_name: str):
+ return cls(egress_url=get_nginx_controller_url(),
+ authority=gen_egress_authority(domain_name),
+ x_host=gen_x_host(domain_name))
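+
+ # Illustrative usage (editor's sketch with a hypothetical domain, not part of the
+ # original change): RpcClient.from_participant('fl-demo.com') builds a stub against
+ # the URL returned by get_nginx_controller_url() with authority
+ # 'fl-demo-client-auth.com' and sends metadata ('x-host', 'fedlearner-webconsole-v2.fl-demo.com').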
def _get_metadata(self):
- metadata = []
- x_host_prefix = 'fedlearner-webconsole-v2'
- for variable in self._project.variables:
- if variable.name == 'X_HOST':
- x_host_prefix = variable.value
- break
- metadata.append(('x-host', '{}.{}'.format(x_host_prefix,
- self._receiver.domain_name)))
- for key, value in self._receiver.grpc_spec.extra_headers.items():
- metadata.append((key, value))
# metadata is a tuple of tuples
- return tuple(metadata)
+ return tuple([('x-host', self._x_host)])
@catch_and_fallback(resp_class=service_pb2.CheckConnectionResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
def check_connection(self):
- msg = service_pb2.CheckConnectionRequest(auth_info=self._auth_info)
- response = self._client.CheckConnection(
- request=msg,
- metadata=self._get_metadata(),
- timeout=Envs.GRPC_CLIENT_TIMEOUT)
+ msg = service_pb2.CheckConnectionRequest(auth_info=self._project_auth_info)
+ response = self._client.CheckConnection(request=msg,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
if response.status.code != common_pb2.STATUS_SUCCESS:
- logging.debug('check_connection request error: %s',
- response.status.msg)
+ logging.debug('check_connection request error: %s', response.status.msg)
+ return response
+
+ @catch_and_fallback(resp_class=service_pb2.CheckPeerConnectionResponse)
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def check_peer_connection(self):
+ # TODO(taoyanting): double check
+ msg = service_pb2.CheckPeerConnectionRequest()
+ response = self._client.CheckPeerConnection(request=msg,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
+ if response.status.code != common_pb2.STATUS_SUCCESS:
+ logging.debug('check_peer_connection request error: %s', response.status.msg)
return response
@catch_and_fallback(resp_class=service_pb2.UpdateWorkflowStateResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
- def update_workflow_state(self, name, state, target_state,
- transaction_state, uuid, forked_from_uuid,
- extra=''):
- msg = service_pb2.UpdateWorkflowStateRequest(
- auth_info=self._auth_info,
- workflow_name=name,
- state=state.value,
- target_state=target_state.value,
- transaction_state=transaction_state.value,
- uuid=uuid,
- forked_from_uuid=forked_from_uuid,
- extra=extra
- )
- response = self._client.UpdateWorkflowState(
- request=msg, metadata=self._get_metadata(),
- timeout=Envs.GRPC_CLIENT_TIMEOUT)
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def update_workflow_state(self, name, state, target_state, transaction_state, uuid, forked_from_uuid, extra=''):
+ msg = service_pb2.UpdateWorkflowStateRequest(auth_info=self._project_auth_info,
+ workflow_name=name,
+ state=state.value,
+ target_state=target_state.value,
+ transaction_state=transaction_state.value,
+ uuid=uuid,
+ forked_from_uuid=forked_from_uuid,
+ extra=extra)
+ response = self._client.UpdateWorkflowState(request=msg,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
if response.status.code != common_pb2.STATUS_SUCCESS:
- logging.debug('update_workflow_state request error: %s',
- response.status.msg)
+ logging.debug('update_workflow_state request error: %s', response.status.msg)
return response
@catch_and_fallback(resp_class=service_pb2.GetWorkflowResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
- def get_workflow(self, name):
- msg = service_pb2.GetWorkflowRequest(auth_info=self._auth_info,
- workflow_name=name)
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def get_workflow(self, uuid, name):
+ msg = service_pb2.GetWorkflowRequest(auth_info=self._project_auth_info, workflow_name=name, workflow_uuid=uuid)
response = self._client.GetWorkflow(request=msg,
metadata=self._get_metadata(),
timeout=Envs.GRPC_CLIENT_TIMEOUT)
if response.status.code != common_pb2.STATUS_SUCCESS:
- logging.debug('get_workflow request error: %s',
- response.status.msg)
+ logging.debug('get_workflow request error: %s', response.status.msg)
return response
@catch_and_fallback(resp_class=service_pb2.UpdateWorkflowResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
- def update_workflow(self, name, config):
- msg = service_pb2.UpdateWorkflowRequest(auth_info=self._auth_info,
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def update_workflow(self, uuid, name, config):
+ msg = service_pb2.UpdateWorkflowRequest(auth_info=self._project_auth_info,
workflow_name=name,
+ workflow_uuid=uuid,
config=config)
response = self._client.UpdateWorkflow(request=msg,
metadata=self._get_metadata(),
timeout=Envs.GRPC_CLIENT_TIMEOUT)
if response.status.code != common_pb2.STATUS_SUCCESS:
- logging.debug('update_workflow request error: %s',
- response.status.msg)
+ logging.debug('update_workflow request error: %s', response.status.msg)
+ return response
+
+ @catch_and_fallback(resp_class=service_pb2.InvalidateWorkflowResponse)
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def invalidate_workflow(self, uuid: str):
+ msg = service_pb2.InvalidateWorkflowRequest(auth_info=self._project_auth_info, workflow_uuid=uuid)
+ response = self._client.InvalidateWorkflow(request=msg,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
+ if response.status.code != common_pb2.STATUS_SUCCESS:
+ logging.debug('invalidate_workflow request error: %s', response.status.msg)
return response
@catch_and_fallback(resp_class=service_pb2.GetJobMetricsResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
def get_job_metrics(self, job_name):
- msg = service_pb2.GetJobMetricsRequest(auth_info=self._auth_info,
- job_name=job_name)
+ msg = service_pb2.GetJobMetricsRequest(auth_info=self._project_auth_info, job_name=job_name)
response = self._client.GetJobMetrics(request=msg,
metadata=self._get_metadata(),
timeout=Envs.GRPC_CLIENT_TIMEOUT)
if response.status.code != common_pb2.STATUS_SUCCESS:
- logging.debug('get_job_metrics request error: %s',
- response.status.msg)
+ logging.debug('get_job_metrics request error: %s', response.status.msg)
return response
@catch_and_fallback(resp_class=service_pb2.GetJobMetricsResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
def get_job_kibana(self, job_name, json_args):
- msg = service_pb2.GetJobKibanaRequest(auth_info=self._auth_info,
- job_name=job_name,
- json_args=json_args)
+ msg = service_pb2.GetJobKibanaRequest(auth_info=self._project_auth_info, job_name=job_name, json_args=json_args)
response = self._client.GetJobKibana(request=msg,
metadata=self._get_metadata(),
timeout=Envs.GRPC_CLIENT_TIMEOUT)
@@ -179,14 +232,13 @@ def get_job_kibana(self, job_name, json_args):
raise UnauthorizedException(status.msg)
if status.code == common_pb2.STATUS_INVALID_ARGUMENT:
raise InvalidArgumentException(status.msg)
- logging.debug('get_job_kibana request error: %s',
- response.status.msg)
+ logging.debug('get_job_kibana request error: %s', response.status.msg)
return response
@catch_and_fallback(resp_class=service_pb2.GetJobEventsResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
def get_job_events(self, job_name, start_time, max_lines):
- msg = service_pb2.GetJobEventsRequest(auth_info=self._auth_info,
+ msg = service_pb2.GetJobEventsRequest(auth_info=self._project_auth_info,
job_name=job_name,
start_time=start_time,
max_lines=max_lines)
@@ -195,21 +247,135 @@ def get_job_events(self, job_name, start_time, max_lines):
timeout=Envs.GRPC_CLIENT_TIMEOUT)
if response.status.code != common_pb2.STATUS_SUCCESS:
- logging.debug('get_job_events request error: %s',
- response.status.msg)
+ logging.debug('get_job_events request error: %s', response.status.msg)
return response
@catch_and_fallback(resp_class=service_pb2.CheckJobReadyResponse)
- @retry_fn(retry_times=3, needed_exceptions=[grpc.RpcError])
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
def check_job_ready(self, job_name: str) \
-> service_pb2.CheckJobReadyResponse:
- msg = service_pb2.CheckJobReadyRequest(auth_info=self._auth_info,
- job_name=job_name)
+ msg = service_pb2.CheckJobReadyRequest(auth_info=self._project_auth_info, job_name=job_name)
response = self._client.CheckJobReady(request=msg,
timeout=Envs.GRPC_CLIENT_TIMEOUT,
metadata=self._get_metadata())
if response.status.code != common_pb2.STATUS_SUCCESS:
- logging.debug('check_job_ready request error: %s',
- response.status.msg)
+ logging.debug('check_job_ready request error: %s', response.status.msg)
+ return response
+
+ @catch_and_fallback(resp_class=service_pb2.TwoPcResponse)
+ @retry_fn(retry_times=3, need_retry=_default_need_retry, delay=200, backoff=2)
+ def run_two_pc(self, transaction_uuid: str, two_pc_type: TwoPcType, action: TwoPcAction,
+ data: TransactionData) -> service_pb2.TwoPcResponse:
+ msg = service_pb2.TwoPcRequest(auth_info=self._project_auth_info,
+ transaction_uuid=transaction_uuid,
+ type=two_pc_type,
+ action=action,
+ data=data)
+ response = self._client.Run2Pc(request=msg, metadata=self._get_metadata(), timeout=Envs.GRPC_CLIENT_TIMEOUT)
+ return response
+
+ @catch_and_fallback(resp_class=service_pb2.ServingServiceResponse)
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def operate_serving_service(self, operation_type: ServingServiceType, serving_model_uuid: str, model_uuid: str,
+ name: str):
+ msg = service_pb2.ServingServiceRequest(auth_info=self._project_auth_info,
+ operation_type=operation_type,
+ serving_model_uuid=serving_model_uuid,
+ model_uuid=model_uuid,
+ serving_model_name=name)
+ response = self._client.ServingServiceManagement(request=msg,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ if response.status.code != common_pb2.STATUS_SUCCESS:
+ logging.debug('serving_service request error: %s', response.status.msg)
+ return response
+
+ @catch_and_fallback(resp_class=service_pb2.ServingServiceInferenceResponse)
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def inference_serving_service(self, serving_model_uuid: str, example_id: str):
+ msg = service_pb2.ServingServiceInferenceRequest(auth_info=self._project_auth_info,
+ serving_model_uuid=serving_model_uuid,
+ example_id=example_id)
+ response = self._client.ServingServiceInference(request=msg,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ if response.status.code != common_pb2.STATUS_SUCCESS:
+ logging.debug('serving_service request error: %s', response.status.msg)
return response
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def get_model_job(self, model_job_uuid: str, need_metrics: bool = False) -> service_pb2.GetModelJobResponse:
+ request = service_pb2.GetModelJobRequest(auth_info=self._project_auth_info,
+ uuid=model_job_uuid,
+ need_metrics=need_metrics)
+ return self._client.GetModelJob(request, metadata=self._get_metadata(), timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def get_model_job_group(self, model_job_group_uuid: str) -> service_pb2.GetModelJobGroupResponse:
+ request = service_pb2.GetModelJobGroupRequest(auth_info=self._project_auth_info, uuid=model_job_group_uuid)
+ return self._client.GetModelJobGroup(request, metadata=self._get_metadata(), timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def update_model_job_group(self, model_job_group_uuid: str,
+ config: WorkflowDefinition) -> service_pb2.UpdateModelJobGroupResponse:
+ request = service_pb2.UpdateModelJobGroupRequest(auth_info=self._project_auth_info,
+ uuid=model_job_group_uuid,
+ config=config)
+ return self._client.UpdateModelJobGroup(request,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def list_participant_datasets(self,
+ kind: Optional[str] = None,
+ uuid: Optional[str] = None) -> service_pb2.ListParticipantDatasetsResponse:
+ request = service_pb2.ListParticipantDatasetsRequest(auth_info=self._project_auth_info)
+ if kind is not None:
+ request.kind = kind
+ if uuid is not None:
+ request.uuid = uuid
+ return self._client.ListParticipantDatasets(request,
+ metadata=self._get_metadata(),
+ timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_dataset_job(self, uuid: str) -> service_pb2.GetDatasetJobResponse:
+ request = service_pb2.GetDatasetJobRequest(auth_info=self._project_auth_info, uuid=uuid)
+ return self._client.GetDatasetJob(request, metadata=self._get_metadata(), timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def create_dataset_job(self, dataset_job: dataset_pb2.DatasetJob, ticket_uuid: str,
+ dataset: dataset_pb2.Dataset) -> empty_pb2.Empty:
+ request = service_pb2.CreateDatasetJobRequest(auth_info=self._project_auth_info,
+ dataset_job=dataset_job,
+ ticket_uuid=ticket_uuid,
+ dataset=dataset)
+ return self._client.CreateDatasetJob(request, metadata=self._get_metadata(), timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+
+def gen_egress_authority(domain_name: str) -> str:
+ """generate egress host
+ Args:
+ domain_name:
+ ex: 'test-1.com'
+ Returns:
+ authority:
+ ex:'test-1-client-auth.com'
+ """
+ domain_name_prefix = domain_name.rpartition('.')[0]
+ return f'{domain_name_prefix}-client-auth.com'
+
+
+def gen_x_host(domain_name: str) -> str:
+ """generate x host
+ Args:
+ domain_name:
+ ex: 'test-1.com'
+ Returns:
+ x-host:
+ ex:'fedlearner-webconsole-v2.test-1.com'
+ """
+ return f'fedlearner-webconsole-v2.{domain_name}'
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/client_interceptor.py b/web_console_v2/api/fedlearner_webconsole/rpc/client_interceptor.py
new file mode 100644
index 000000000..9b9719c77
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/client_interceptor.py
@@ -0,0 +1,61 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import NamedTuple, Optional, Sequence, Tuple, Union
+
+import grpc
+
+from fedlearner_webconsole.middleware.request_id import GrpcRequestIdMiddleware
+
+
+# pylint: disable=line-too-long
+# Ref: https://github.com/d5h-foss/grpc-interceptor/blob/master/src/grpc_interceptor/client.py#L9
+class _ClientCallDetailsFields(NamedTuple):
+ method: str
+ timeout: Optional[float]
+ metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]]
+ credentials: Optional[grpc.CallCredentials]
+ wait_for_ready: Optional[bool]
+ compression: Optional[grpc.Compression]
+
+
+class ClientCallDetails(_ClientCallDetailsFields, grpc.ClientCallDetails):
+ pass
+
+
+class ClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
+ grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
+
+ def _intercept_call(self, continuation, client_call_details, request_or_iterator):
+ metadata = []
+ if client_call_details.metadata is not None:
+ metadata = list(client_call_details.metadata)
+ # Metadata of ClientCallDetails can not be set directly
+ new_details = ClientCallDetails(client_call_details.method, client_call_details.timeout,
+ GrpcRequestIdMiddleware.add_header(metadata), client_call_details.credentials,
+ client_call_details.wait_for_ready, client_call_details.compression)
+ return continuation(new_details, request_or_iterator)
+
+ def intercept_unary_unary(self, continuation, client_call_details, request):
+ return self._intercept_call(continuation, client_call_details, request)
+
+ def intercept_unary_stream(self, continuation, client_call_details, request):
+ return self._intercept_call(continuation, client_call_details, request)
+
+ def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
+ return self._intercept_call(continuation, client_call_details, request_iterator)
+
+ def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
+ return self._intercept_call(continuation, client_call_details, request_iterator)
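+
+
+# Illustrative usage (editor's sketch, not part of the original change): client.py
+# attaches this interceptor via
+#   channel = grpc.intercept_channel(channel, ClientInterceptor())
+# so every outgoing call carries the request-id header injected by
+# GrpcRequestIdMiddleware.add_header.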
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/client_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/client_test.py
new file mode 100644
index 000000000..e772d49bb
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/client_test.py
@@ -0,0 +1,281 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import patch
+
+import grpc_testing
+from grpc import StatusCode
+from datetime import datetime
+from google.protobuf.wrappers_pb2 import BoolValue
+from google.protobuf import empty_pb2
+
+from fedlearner_webconsole.middleware.request_id import GrpcRequestIdMiddleware
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TwoPcAction, TransactionData, CreateModelJobData
+from fedlearner_webconsole.participant.models import Participant
+
+from fedlearner_webconsole.proto.service_pb2 import DESCRIPTOR, CheckPeerConnectionRequest, \
+ CheckPeerConnectionResponse, CreateDatasetJobRequest, TwoPcRequest, TwoPcResponse
+from fedlearner_webconsole.rpc.client import RpcClient, _build_grpc_stub
+from fedlearner_webconsole.project.models import Project as ProjectModel
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.proto.common_pb2 import (Status, StatusCode as FedLearnerStatusCode)
+from fedlearner_webconsole.proto.dataset_pb2 import ParticipantDatasetRef
+from fedlearner_webconsole.proto.service_pb2 import (CheckConnectionRequest, ProjAuthInfo)
+from fedlearner_webconsole.proto.service_pb2 import CheckConnectionResponse, \
+ CheckJobReadyResponse, CheckJobReadyRequest, ListParticipantDatasetsRequest, ListParticipantDatasetsResponse, \
+ GetModelJobRequest, GetModelJobResponse
+from fedlearner_webconsole.proto import dataset_pb2, project_pb2
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.dataset.models import DatasetFormat, DatasetKindV2
+from testing.rpc.client import RpcClientTestCase
+from testing.fake_time_patcher import FakeTimePatcher
+
+TARGET_SERVICE = DESCRIPTOR.services_by_name['WebConsoleV2Service']
+
+
+class RpcClientTest(RpcClientTestCase):
+ _TEST_PROJECT_NAME = 'test-project'
+ _TEST_RECEIVER_NAME = 'test-receiver'
+ _TEST_URL = 'localhost:123'
+ _X_HOST_HEADER_KEY = 'x-host'
+ _TEST_X_HOST = 'fedlearner-webconsole-v2.fl-test.com'
+ _TEST_SELF_DOMAIN_NAME = 'fl-test-self.com'
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ participant = Participant(name=cls._TEST_RECEIVER_NAME, domain_name='fl-test.com')
+ job = Job(name='test-job')
+
+ cls._participant = participant
+ cls._project = ProjectModel(name=cls._TEST_PROJECT_NAME)
+ cls._job = job
+
+ # Builds a testing channel
+ cls._fake_channel = grpc_testing.channel(DESCRIPTOR.services_by_name.values(), grpc_testing.strict_real_time())
+ cls._fake_channel_patcher = patch('fedlearner_webconsole.rpc.client.grpc.insecure_channel')
+ cls._mock_build_channel = cls._fake_channel_patcher.start()
+ cls._mock_build_channel.return_value = cls._fake_channel
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._fake_channel_patcher.stop()
+ super().tearDownClass()
+
+ def setUp(self):
+ super().setUp()
+ self._client = RpcClient.from_project_and_participant(self._project.name, self._project.token,
+ self._participant.domain_name)
+
+ def test_build_grpc_stub(self):
+ fake_timer = FakeTimePatcher()
+ fake_timer.start()
+ authority = 'fl-test-client-auth.com'
+
+ # Verifies the channel built in setUp() via RpcClient.from_project_and_participant.
+ self._mock_build_channel.assert_called_once_with(
+ options=[('grpc.default_authority', 'fl-test-client-auth.com')],
+ target='fedlearner-stack-ingress-nginx-controller.default.svc:80')
+ self._mock_build_channel.reset_mock()
+
+ _build_grpc_stub(self._TEST_URL, authority)
+ self._mock_build_channel.assert_called_once_with(options=[('grpc.default_authority', authority)],
+ target=self._TEST_URL)
+ _build_grpc_stub(self._TEST_URL, authority)
+ self.assertEqual(self._mock_build_channel.call_count, 1)
+ # Ticks 61 seconds to timeout
+ fake_timer.interrupt(61)
+ _build_grpc_stub(self._TEST_URL, authority)
+ self.assertEqual(self._mock_build_channel.call_count, 2)
+ fake_timer.stop()
+
+ @patch('fedlearner_webconsole.middleware.request_id.get_current_request_id')
+ def test_request_id_in_metadata(self, mock_get_current_request_id):
+ mock_get_current_request_id.return_value = 'test-request-id'
+
+ call = self.client_execution_pool.submit(self._client.check_connection)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['CheckConnection'])
+ self.assertIn((GrpcRequestIdMiddleware.REQUEST_HEADER_NAME, 'test-request-id'), invocation_metadata)
+ # We don't care about the result
+ rpc.terminate(response=CheckConnectionResponse(), code=StatusCode.OK, trailing_metadata=(), details=None)
+
+ def test_check_connection(self):
+ call = self.client_execution_pool.submit(self._client.check_connection)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['CheckConnection'])
+
+ self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST), invocation_metadata)
+ self.assertEqual(
+ request,
+ CheckConnectionRequest(auth_info=ProjAuthInfo(project_name=self._project.name,
+ target_domain=self._participant.domain_name,
+ auth_token=self._project.token)))
+
+ expected_status = Status(code=FedLearnerStatusCode.STATUS_SUCCESS, msg='test')
+ rpc.terminate(response=CheckConnectionResponse(status=expected_status),
+ code=StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ self.assertEqual(call.result().status, expected_status)
+
+ def test_check_peer_connection(self):
+ call = self.client_execution_pool.submit(self._client.check_peer_connection)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['CheckPeerConnection'])
+
+ self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST), invocation_metadata)
+
+ self.assertEqual(request, CheckPeerConnectionRequest())
+
+ expected_status = Status(code=FedLearnerStatusCode.STATUS_SUCCESS, msg='received check request successfully!')
+ rpc.terminate(response=CheckPeerConnectionResponse(status=expected_status,
+ application_version={'version': '2.0.1.5'}),
+ code=StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ self.assertEqual(call.result().status, expected_status)
+ self.assertEqual(call.result().application_version.version, '2.0.1.5')
+
+ def test_check_job_ready(self):
+ call = self.client_execution_pool.submit(self._client.check_job_ready, self._job.name)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['CheckJobReady'])
+
+ self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST), invocation_metadata)
+ self.assertEqual(
+ request,
+ CheckJobReadyRequest(job_name=self._job.name,
+ auth_info=ProjAuthInfo(project_name=self._project.name,
+ target_domain=self._participant.domain_name,
+ auth_token=self._project.token)))
+
+ expected_status = Status(code=FedLearnerStatusCode.STATUS_SUCCESS, msg='test')
+ rpc.terminate(response=CheckJobReadyResponse(status=expected_status),
+ code=StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ self.assertEqual(call.result().status, expected_status)
+
+ def test_run_two_pc(self):
+ transaction_data = TransactionData(create_model_job_data=CreateModelJobData(model_job_name='test model name'))
+ call = self.client_execution_pool.submit(self._client.run_two_pc,
+ transaction_uuid='test-id',
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['Run2Pc'])
+
+ self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST), invocation_metadata)
+ self.assertEqual(
+ request,
+ TwoPcRequest(auth_info=ProjAuthInfo(project_name=self._project.name,
+ target_domain=self._participant.domain_name,
+ auth_token=self._project.token),
+ transaction_uuid='test-id',
+ type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data))
+
+ expected_status = Status(code=FedLearnerStatusCode.STATUS_SUCCESS, msg='test run two pc')
+ rpc.terminate(response=TwoPcResponse(status=expected_status),
+ code=StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ self.assertEqual(call.result().status, expected_status)
+
+ def test_list_participant_datasets(self):
+ call = self.client_execution_pool.submit(self._client.list_participant_datasets)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['ListParticipantDatasets'])
+
+ self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST), invocation_metadata)
+ self.assertEqual(
+ request,
+ ListParticipantDatasetsRequest(auth_info=ProjAuthInfo(project_name=self._project.name,
+ target_domain=self._participant.domain_name,
+ auth_token=self._project.token)))
+
+ dataref = ParticipantDatasetRef(uuid='1',
+ name='dataset',
+ format=DatasetFormat.TABULAR.name,
+ file_size=0,
+ updated_at=to_timestamp(datetime(2012, 1, 14, 12, 0, 5)),
+ dataset_kind=DatasetKindV2.RAW.name)
+ rpc.terminate(response=ListParticipantDatasetsResponse(participant_datasets=[dataref]),
+ code=StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ self.assertEqual(call.result(), ListParticipantDatasetsResponse(participant_datasets=[dataref]))
+
+ def test_get_model_job(self):
+ call = self.client_execution_pool.submit(self._client.get_model_job, model_job_uuid='uuid', need_metrics=True)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['GetModelJob'])
+
+ self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST), invocation_metadata)
+ self.assertEqual(
+ request,
+ GetModelJobRequest(auth_info=ProjAuthInfo(project_name=self._project.name,
+ target_domain=self._participant.domain_name,
+ auth_token=self._project.token),
+ uuid='uuid',
+ need_metrics=True))
+
+ expected_metric = BoolValue(value=True)
+ resp = GetModelJobResponse(metric_is_public=expected_metric)
+ rpc.terminate(response=resp, code=StatusCode.OK, trailing_metadata=(), details=None)
+ self.assertEqual(call.result(), GetModelJobResponse(metric_is_public=expected_metric))
+
+ def test_create_dataset_job(self):
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_participant_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_participant_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ call = self.client_execution_pool.submit(self._client.create_dataset_job,
+ dataset_job=dataset_pb2.DatasetJob(uuid='test'),
+ ticket_uuid='test ticket_uuid',
+ dataset=dataset_pb2.Dataset(participants_info=participants_info))
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ TARGET_SERVICE.methods_by_name['CreateDatasetJob'])
+
+ self.assertIn((self._X_HOST_HEADER_KEY, self._TEST_X_HOST), invocation_metadata)
+ self.assertEqual(
+ request,
+ CreateDatasetJobRequest(auth_info=ProjAuthInfo(project_name=self._project.name,
+ target_domain=self._participant.domain_name,
+ auth_token=self._project.token),
+ dataset_job=dataset_pb2.DatasetJob(uuid='test'),
+ ticket_uuid='test ticket_uuid',
+ dataset=dataset_pb2.Dataset(participants_info=participants_info)))
+
+ resp = empty_pb2.Empty()
+ rpc.terminate(response=resp, code=StatusCode.OK, trailing_metadata=(), details=None)
+ self.assertEqual(call.result(), empty_pb2.Empty())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/server.py b/web_console_v2/api/fedlearner_webconsole/rpc/server.py
index 19b9ac285..99d154b75 100644
--- a/web_console_v2/api/fedlearner_webconsole/rpc/server.py
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/server.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,47 +13,115 @@
# limitations under the License.
# coding: utf-8
-# pylint: disable=broad-except, cyclic-import
+# pylint: disable=broad-except
+from datetime import timedelta
+import inspect
import time
import logging
import json
-import os
import sys
import threading
import traceback
from concurrent import futures
+from functools import wraps
+from envs import Envs
+
import grpc
from grpc_reflection.v1alpha import reflection
-from fedlearner_webconsole.proto import (
- service_pb2, service_pb2_grpc,
- common_pb2, workflow_definition_pb2
-)
+from google.protobuf import empty_pb2
+from google.protobuf.wrappers_pb2 import BoolValue
+from fedlearner_webconsole.middleware.request_id import GrpcRequestIdMiddleware
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.proto import (dataset_pb2, service_pb2, service_pb2_grpc, common_pb2,
+ workflow_definition_pb2)
+from fedlearner_webconsole.proto.review_pb2 import ReviewStatus
+from fedlearner_webconsole.proto.rpc.v2 import system_service_pb2_grpc, system_service_pb2, project_service_pb2_grpc
+from fedlearner_webconsole.proto.service_pb2 import (TwoPcRequest, TwoPcResponse)
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+from fedlearner_webconsole.review.common import NO_CENTRAL_SERVER_UUID
+from fedlearner_webconsole.rpc.auth import get_common_name
+from fedlearner_webconsole.rpc.v2.system_service_server import SystemGrpcService
+from fedlearner_webconsole.serving.services import NegotiatorServingService
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.two_pc.handlers import run_two_pc_action
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
from fedlearner_webconsole.utils.es import es
-from fedlearner_webconsole.db import db, get_session
+from fedlearner_webconsole.db import db
from fedlearner_webconsole.utils.kibana import Kibana
from fedlearner_webconsole.project.models import Project
-from fedlearner_webconsole.workflow.models import (
- Workflow, WorkflowState, TransactionState,
- _merge_workflow_config
-)
-
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.workflow.models import (Workflow, WorkflowState, TransactionState)
+from fedlearner_webconsole.workflow.resource_manager import \
+ merge_workflow_config, ResourceManager
+from fedlearner_webconsole.workflow.service import WorkflowService
+from fedlearner_webconsole.workflow.workflow_controller import invalidate_workflow_locally
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.utils.proto import to_json, to_dict
from fedlearner_webconsole.job.models import Job
from fedlearner_webconsole.job.service import JobService
from fedlearner_webconsole.job.metrics import JobMetricsBuilder
-from fedlearner_webconsole.exceptions import (
- UnauthorizedException, InvalidArgumentException
-)
-from envs import Envs
-
-
+from fedlearner_webconsole.mmgr.models import ModelJobGroup
+from fedlearner_webconsole.exceptions import (UnauthorizedException, InvalidArgumentException)
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.mmgr.models import ModelJob
+from fedlearner_webconsole.mmgr.service import ModelJobService
+from fedlearner_webconsole.dataset.services import DatasetService, DatasetJobService, BatchService
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobKind, \
+ Dataset, ProcessedDataset, DatasetKindV2, DatasetFormat, ResourceState
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.dataset.job_configer.dataset_job_configer import DatasetJobConfiger
+from fedlearner_webconsole.proto.rpc.v2 import job_service_pb2_grpc, job_service_pb2, resource_service_pb2_grpc,\
+ resource_service_pb2
+from fedlearner_webconsole.rpc.v2.auth_server_interceptor import AuthServerInterceptor
+from fedlearner_webconsole.rpc.v2.job_service_server import JobServiceServicer
+from fedlearner_webconsole.rpc.v2.resource_service_server import ResourceServiceServicer
+from fedlearner_webconsole.rpc.v2.project_service_server import ProjectGrpcService
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.audit.decorators import emits_rpc_event, get_two_pc_request_uuid
+
+
+def _set_request_id_for_all_methods():
+ """A hack way to wrap all gRPC methods to set request id in context.
+
+ Why not service interceptor?
+ The request id is attached on thread local, but interceptor is not sharing
+ the same thread with service handler, we are not able to set the context on
+ thread local in interceptor as it will not work in service handler."""
+
+ def set_request_id_in_context(fn):
+
+ @wraps(fn)
+ def wrapper(self, request, context):
+ GrpcRequestIdMiddleware.set_request_id_in_context(context)
+ return fn(self, request, context)
+
+ return wrapper
+
+ def decorate(cls):
+ # A hack to get all methods
+ grpc_methods = service_pb2.DESCRIPTOR.services_by_name['WebConsoleV2Service'].methods_by_name
+ for name, fn in inspect.getmembers(cls, inspect.isfunction):
+ # If this is a gRPC method
+ if name in grpc_methods:
+ setattr(cls, name, set_request_id_in_context(fn))
+ return cls
+
+ return decorate
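+
+
+# Illustrative note (editor's sketch, not part of the original change): the decorator
+# application below wraps every servicer method named in WebConsoleV2Service (e.g.
+# GetWorkflow) so that GrpcRequestIdMiddleware.set_request_id_in_context(context) runs
+# before the original handler.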
+
+
+@_set_request_id_for_all_methods()
class RPCServerServicer(service_pb2_grpc.WebConsoleV2ServiceServicer):
+
def __init__(self, server):
self._server = server
def _secure_exc(self):
exc_type, exc_obj, exc_tb = sys.exc_info()
# filter out exc_obj to protect sensitive info
- secure_exc = 'Error %s at '%exc_type
+ secure_exc = f'Error {exc_type} at '
secure_exc += ''.join(traceback.format_tb(exc_tb))
return secure_exc
@@ -61,199 +129,254 @@ def _try_handle_request(self, func, request, context, resp_class):
try:
return func(request, context)
except UnauthorizedException as e:
- return resp_class(
- status=common_pb2.Status(
- code=common_pb2.STATUS_UNAUTHORIZED,
- msg='Invalid auth: %s'%repr(request.auth_info)))
+ return resp_class(status=common_pb2.Status(code=common_pb2.STATUS_UNAUTHORIZED,
+ msg=f'Invalid auth: {repr(request.auth_info)}'))
except Exception as e:
logging.error('%s rpc server error: %s', func.__name__, repr(e))
- return resp_class(
- status=common_pb2.Status(
- code=common_pb2.STATUS_UNKNOWN_ERROR,
- msg=self._secure_exc()))
+ return resp_class(status=common_pb2.Status(code=common_pb2.STATUS_UNKNOWN_ERROR, msg=self._secure_exc()))
def CheckConnection(self, request, context):
- return self._try_handle_request(
- self._server.check_connection, request, context,
- service_pb2.CheckConnectionResponse)
+ return self._try_handle_request(self._server.check_connection, request, context,
+ service_pb2.CheckConnectionResponse)
- def Ping(self, request, context):
- return self._try_handle_request(
- self._server.ping, request, context,
- service_pb2.PingResponse)
+ def CheckPeerConnection(self, request, context):
+ return self._try_handle_request(self._server.check_peer_connection, request, context,
+ service_pb2.CheckPeerConnectionResponse)
+ @emits_rpc_event(resource_type=Event.ResourceType.WORKFLOW,
+ op_type=Event.OperationType.UPDATE_STATE,
+ resource_name_fn=lambda request: request.uuid)
def UpdateWorkflowState(self, request, context):
- return self._try_handle_request(
- self._server.update_workflow_state, request, context,
- service_pb2.UpdateWorkflowStateResponse)
+ return self._try_handle_request(self._server.update_workflow_state, request, context,
+ service_pb2.UpdateWorkflowStateResponse)
def GetWorkflow(self, request, context):
- return self._try_handle_request(
- self._server.get_workflow, request, context,
- service_pb2.GetWorkflowResponse)
+ return self._try_handle_request(self._server.get_workflow, request, context, service_pb2.GetWorkflowResponse)
+ @emits_rpc_event(resource_type=Event.ResourceType.WORKFLOW,
+ op_type=Event.OperationType.UPDATE,
+ resource_name_fn=lambda request: request.workflow_uuid)
def UpdateWorkflow(self, request, context):
- return self._try_handle_request(
- self._server.update_workflow, request, context,
- service_pb2.UpdateWorkflowResponse)
+ return self._try_handle_request(self._server.update_workflow, request, context,
+ service_pb2.UpdateWorkflowResponse)
+
+ @emits_rpc_event(resource_type=Event.ResourceType.WORKFLOW,
+ op_type=Event.OperationType.INVALIDATE,
+ resource_name_fn=lambda request: request.workflow_uuid)
+ def InvalidateWorkflow(self, request, context):
+ return self._try_handle_request(self._server.invalidate_workflow, request, context,
+ service_pb2.InvalidateWorkflowResponse)
def GetJobMetrics(self, request, context):
- return self._try_handle_request(
- self._server.get_job_metrics, request, context,
- service_pb2.GetJobMetricsResponse)
+ return self._try_handle_request(self._server.get_job_metrics, request, context,
+ service_pb2.GetJobMetricsResponse)
def GetJobKibana(self, request, context):
- return self._try_handle_request(
- self._server.get_job_kibana, request, context,
- service_pb2.GetJobKibanaResponse
- )
+ return self._try_handle_request(self._server.get_job_kibana, request, context, service_pb2.GetJobKibanaResponse)
def GetJobEvents(self, request, context):
- return self._try_handle_request(
- self._server.get_job_events, request, context,
- service_pb2.GetJobEventsResponse)
+ return self._try_handle_request(self._server.get_job_events, request, context, service_pb2.GetJobEventsResponse)
def CheckJobReady(self, request, context):
- return self._try_handle_request(
- self._server.check_job_ready, request, context,
- service_pb2.CheckJobReadyResponse)
-
-
+ return self._try_handle_request(self._server.check_job_ready, request, context,
+ service_pb2.CheckJobReadyResponse)
+
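+ # Runs a two-pc (two-phase-commit) action within a single DB session; the outcome
+ # is reported via the `succeeded`/`message` fields while the status code stays SUCCESS.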
+ def _run_2pc(self, request: TwoPcRequest, context: grpc.ServicerContext) -> TwoPcResponse:
+ with db.session_scope() as session:
+ project, _ = self._server.check_auth_info(request.auth_info, context, session)
+ succeeded, message = run_two_pc_action(session=session,
+ tid=request.transaction_uuid,
+ two_pc_type=request.type,
+ action=request.action,
+ data=request.data)
+ session.commit()
+ return TwoPcResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ transaction_uuid=request.transaction_uuid,
+ type=request.type,
+ action=request.action,
+ succeeded=succeeded,
+ message=message)
+
+ @emits_rpc_event(resource_type=Event.ResourceType.UNKNOWN_RESOURCE_TYPE,
+ op_type=Event.OperationType.UNKNOWN_OPERATION_TYPE,
+ resource_name_fn=get_two_pc_request_uuid)
+ def Run2Pc(self, request: TwoPcRequest, context: grpc.ServicerContext):
+ return self._try_handle_request(self._run_2pc, request, context, service_pb2.TwoPcResponse)
+
+ @emits_rpc_event(resource_type=Event.ResourceType.SERVING_SERVICE,
+ op_type=Event.OperationType.OPERATE,
+ resource_name_fn=lambda request: request.serving_model_uuid)
+ def ServingServiceManagement(self, request: service_pb2.ServingServiceRequest,
+ context: grpc.ServicerContext) -> service_pb2.ServingServiceResponse:
+ return self._try_handle_request(self._server.operate_serving_service, request, context,
+ service_pb2.ServingServiceResponse)
+
+ @emits_rpc_event(resource_type=Event.ResourceType.SERVING_SERVICE,
+ op_type=Event.OperationType.INFERENCE,
+ resource_name_fn=lambda request: request.serving_model_uuid)
+ def ServingServiceInference(self, request: service_pb2.ServingServiceInferenceRequest,
+ context: grpc.ServicerContext) -> service_pb2.ServingServiceInferenceResponse:
+ return self._try_handle_request(self._server.inference_serving_service, request, context,
+ service_pb2.ServingServiceInferenceResponse)
+
+ def ClientHeartBeat(self, request, context):
+ return self._server.client_heart_beat(request, context)
+
+ def GetModelJob(self, request, context):
+ return self._server.get_model_job(request, context)
+
+ def GetModelJobGroup(self, request, context):
+ return self._server.get_model_job_group(request, context)
+
+ @emits_rpc_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP,
+ op_type=Event.OperationType.UPDATE,
+ resource_name_fn=lambda request: request.uuid)
+ def UpdateModelJobGroup(self, request, context):
+ return self._server.update_model_job_group(request, context)
+
+ def ListParticipantDatasets(self, request, context):
+ return self._server.list_participant_datasets(request, context)
+
+ def GetDatasetJob(self, request: service_pb2.GetDatasetJobRequest,
+ context: grpc.ServicerContext) -> service_pb2.GetDatasetJobResponse:
+ return self._server.get_dataset_job(request, context)
+
+ @emits_rpc_event(resource_type=Event.ResourceType.DATASET_JOB,
+ op_type=Event.OperationType.CREATE,
+ resource_name_fn=lambda request: request.dataset_job.uuid)
+ def CreateDatasetJob(self, request: service_pb2.CreateDatasetJobRequest,
+ context: grpc.ServicerContext) -> empty_pb2.Empty:
+ return self._server.create_dataset_job(request, context)
+
+
+# TODO(wangsen.0914): make the rpc server clean, move business logic out
class RpcServer(object):
+
def __init__(self):
+ self.started = False
self._lock = threading.Lock()
- self._started = False
self._server = None
- self._app = None
- def start(self, app):
- assert not self._started, 'Already started'
- self._app = app
- listen_port = app.config.get('GRPC_LISTEN_PORT', 1999)
+ def start(self, port: int):
+ assert not self.started, 'Already started'
with self._lock:
- self._server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=20))
- service_pb2_grpc.add_WebConsoleV2ServiceServicer_to_server(
- RPCServerServicer(self), self._server)
- # reflection support server find the proto file path automatically
- # when using grpcurl
- reflection.enable_server_reflection(
- service_pb2.DESCRIPTOR.services_by_name, self._server)
- self._server.add_insecure_port('[::]:%d' % listen_port)
+ self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=30),
+ interceptors=[AuthServerInterceptor()])
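+ # Register the legacy WebConsoleV2 service alongside the newer rpc.v2 services
+ # (system/job/resource/project).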
+ service_pb2_grpc.add_WebConsoleV2ServiceServicer_to_server(RPCServerServicer(self), self._server)
+ system_service_pb2_grpc.add_SystemServiceServicer_to_server(SystemGrpcService(), self._server)
+ job_service_pb2_grpc.add_JobServiceServicer_to_server(JobServiceServicer(), self._server)
+ resource_service_pb2_grpc.add_ResourceServiceServicer_to_server(ResourceServiceServicer(), self._server)
+ project_service_pb2_grpc.add_ProjectServiceServicer_to_server(ProjectGrpcService(), self._server)
+ # Reflection lets clients (e.g. grpcurl) discover services by URL, e.g. /SystemService.CheckHealth
+ reflection.enable_server_reflection(service_pb2.DESCRIPTOR.services_by_name, self._server)
+ reflection.enable_server_reflection(system_service_pb2.DESCRIPTOR.services_by_name, self._server)
+ reflection.enable_server_reflection(job_service_pb2.DESCRIPTOR.services_by_name, self._server)
+ reflection.enable_server_reflection(resource_service_pb2.DESCRIPTOR.services_by_name, self._server)
+ self._server.add_insecure_port(f'[::]:{port}')
self._server.start()
- self._started = True
+ self.started = True
def stop(self):
- if not self._started:
+ if not self.started:
return
with self._lock:
self._server.stop(None).wait()
del self._server
- self._started = False
+ self.started = False
- def check_auth_info(self, auth_info, context):
+ def check_auth_info(self, auth_info, context, session):
logging.debug('auth_info: %s', auth_info)
- project = Project.query.filter_by(
- name=auth_info.project_name).first()
+ project = session.query(Project).filter_by(name=auth_info.project_name).first()
if project is None:
- raise UnauthorizedException('Invalid project')
- project_config = project.get_config()
+ raise UnauthorizedException(f'Invalid project {auth_info.project_name}')
# TODO: fix token verification
# if project_config.token != auth_info.auth_token:
# raise UnauthorizedException('Invalid token')
- # Use first participant to mock for unit test
+ service = ParticipantService(session)
+ participants = service.get_participants_by_project(project.id)
# TODO: Fix for multi-peer
- source_party = project_config.participants[0]
- if os.environ.get('FLASK_ENV') == 'production':
+ source_party = participants[0]
+ if Envs.FLASK_ENV == 'production':
+ source_party = None
metadata = dict(context.invocation_metadata())
- # ssl-client-subject-dn example:
- # CN=*.fl-xxx.com,OU=security,O=security,L=beijing,ST=beijing,C=CN
- cn = metadata.get('ssl-client-subject-dn').split(',')[0][5:]
- for party in project_config.participants:
- if party.domain_name == cn:
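+ # Identify the peer by the common name (CN) in its client certificate and match
+ # it against the project's participants by pure domain name.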
+ cn = get_common_name(metadata.get('ssl-client-subject-dn'))
+ if not cn:
+ raise UnauthorizedException('Failed to get domain name from certs')
+ pure_domain_name = get_pure_domain_name(cn)
+ for party in participants:
+ if get_pure_domain_name(party.domain_name) == pure_domain_name:
source_party = party
if source_party is None:
- raise UnauthorizedException('Invalid domain')
+ raise UnauthorizedException(f'Invalid domain {pure_domain_name}')
return project, source_party
def check_connection(self, request, context):
- with self._app.app_context():
- _, party = self.check_auth_info(request.auth_info, context)
- logging.debug(
- 'received check_connection from %s', party.domain_name)
- return service_pb2.CheckConnectionResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS))
-
- def ping(self, request, context):
- return service_pb2.PingResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- msg='Pong!')
+ with db.session_scope() as session:
+ _, party = self.check_auth_info(request.auth_info, context, session)
+ logging.debug('received check_connection from %s', party.domain_name)
+ return service_pb2.CheckConnectionResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS))
+
+ def check_peer_connection(self, request, context):
+ logging.debug('received request: check peer connection')
+ with db.session_scope() as session:
+ service = SettingService(session)
+ version = service.get_application_version()
+ return service_pb2.CheckPeerConnectionResponse(status=common_pb2.Status(
+ code=common_pb2.STATUS_SUCCESS, msg='participant received check request successfully!'),
+ application_version=version.to_proto())
def update_workflow_state(self, request, context):
- with self._app.app_context():
- project, party = self.check_auth_info(request.auth_info, context)
- logging.debug(
- 'received update_workflow_state from %s: %s',
- party.domain_name, request)
+ with db.session_scope() as session:
+ project, party = self.check_auth_info(request.auth_info, context, session)
+ logging.debug('received update_workflow_state from %s: %s', party.domain_name, request)
name = request.workflow_name
uuid = request.uuid
forked_from_uuid = request.forked_from_uuid
- forked_from = Workflow.query.filter_by(
+ forked_from = session.query(Workflow).filter_by(
uuid=forked_from_uuid).first().id if forked_from_uuid else None
state = WorkflowState(request.state)
target_state = WorkflowState(request.target_state)
transaction_state = TransactionState(request.transaction_state)
- workflow = Workflow.query.filter_by(
- name=request.workflow_name,
- project_id=project.id).first()
+ workflow = session.query(Workflow).filter_by(name=request.workflow_name, project_id=project.id).first()
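+ # If the workflow does not exist locally yet, it is created on the first
+ # NEW -> READY state update from the peer.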
if workflow is None:
assert state == WorkflowState.NEW
assert target_state == WorkflowState.READY
- workflow = Workflow(
- name=name,
- project_id=project.id,
- state=state, target_state=target_state,
- transaction_state=transaction_state,
- uuid=uuid,
- forked_from=forked_from,
- extra=request.extra
- )
- db.session.add(workflow)
- db.session.commit()
- db.session.refresh(workflow)
-
- workflow.update_state(
- state, target_state, transaction_state)
- db.session.commit()
- return service_pb2.UpdateWorkflowStateResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- state=workflow.state.value,
- target_state=workflow.target_state.value,
- transaction_state=workflow.transaction_state.value)
+ workflow = Workflow(name=name,
+ project_id=project.id,
+ state=state,
+ target_state=target_state,
+ transaction_state=transaction_state,
+ uuid=uuid,
+ forked_from=forked_from,
+ extra=request.extra)
+ session.add(workflow)
+ session.commit()
+ session.refresh(workflow)
+
+ ResourceManager(session, workflow).update_state(state, target_state, transaction_state)
+ session.commit()
+ return service_pb2.UpdateWorkflowStateResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ state=workflow.state.value,
+ target_state=workflow.target_state.value,
+ transaction_state=workflow.transaction_state.value)
def _filter_workflow(self, workflow, modes):
# filter peer-readable and peer-writable variables
if workflow is None:
return None
- new_wf = workflow_definition_pb2.WorkflowDefinition(
- group_alias=workflow.group_alias,
- is_left=workflow.is_left)
+ new_wf = workflow_definition_pb2.WorkflowDefinition(group_alias=workflow.group_alias)
for var in workflow.variables:
if var.access_mode in modes:
new_wf.variables.append(var)
for job_def in workflow.job_definitions:
# keep yaml template private
- new_jd = workflow_definition_pb2.JobDefinition(
- name=job_def.name,
- job_type=job_def.job_type,
- is_federated=job_def.is_federated,
- dependencies=job_def.dependencies)
+ new_jd = workflow_definition_pb2.JobDefinition(name=job_def.name,
+ job_type=job_def.job_type,
+ is_federated=job_def.is_federated,
+ dependencies=job_def.dependencies)
for var in job_def.variables:
if var.access_mode in modes:
new_jd.variables.append(var)
@@ -261,147 +384,300 @@ def _filter_workflow(self, workflow, modes):
return new_wf
def get_workflow(self, request, context):
- with self._app.app_context():
- project, party = self.check_auth_info(request.auth_info, context)
- workflow = Workflow.query.filter_by(
- name=request.workflow_name,
- project_id=project.id).first()
+ with db.session_scope() as session:
+ project, party = self.check_auth_info(request.auth_info, context, session)
+ # TODO(hangweiqiang): remove workflow name
+ # kept for compatibility with previous versions that look up workflows by name
+ if request.workflow_uuid:
+ workflow = session.query(Workflow).filter_by(uuid=request.workflow_uuid, project_id=project.id).first()
+ else:
+ workflow = session.query(Workflow).filter_by(name=request.workflow_name, project_id=project.id).first()
assert workflow is not None, 'Workflow not found'
config = workflow.get_config()
- config = self._filter_workflow(
- config,
- [
- common_pb2.Variable.PEER_READABLE,
- common_pb2.Variable.PEER_WRITABLE
- ])
+ config = self._filter_workflow(config,
+ [common_pb2.Variable.PEER_READABLE, common_pb2.Variable.PEER_WRITABLE])
# job details
- jobs = [service_pb2.JobDetail(
- name=job.name,
- state=job.get_state_for_frontend(),
- pods=json.dumps(
- job.get_pods_for_frontend(include_private_info=False)))
- for job in workflow.get_jobs()]
+ jobs = [
+ service_pb2.JobDetail(
+ name=job.name,
+ state=job.state.name,
+ created_at=to_timestamp(job.created_at),
+ pods=json.dumps([to_dict(pod)
+ for pod in JobService.get_pods(job, include_private_info=False)]))
+ for job in workflow.get_jobs(session)
+ ]
# fork info
forked_from = ''
if workflow.forked_from:
- forked_from = Workflow.query.get(workflow.forked_from).name
- return service_pb2.GetWorkflowResponse(
- name=request.workflow_name,
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- config=config,
- jobs=jobs,
- state=workflow.state.value,
- target_state=workflow.target_state.value,
- transaction_state=workflow.transaction_state.value,
- forkable=workflow.forkable,
- forked_from=forked_from,
- create_job_flags=workflow.get_create_job_flags(),
- peer_create_job_flags=workflow.get_peer_create_job_flags(),
- fork_proposal_config=workflow.get_fork_proposal_config(),
- uuid=workflow.uuid,
- metric_is_public=workflow.metric_is_public)
+ forked_from = session.query(Workflow).get(workflow.forked_from).name
+ return service_pb2.GetWorkflowResponse(name=workflow.name,
+ status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ config=config,
+ jobs=jobs,
+ state=workflow.state.value,
+ target_state=workflow.target_state.value,
+ transaction_state=workflow.transaction_state.value,
+ forkable=workflow.forkable,
+ forked_from=forked_from,
+ create_job_flags=workflow.get_create_job_flags(),
+ peer_create_job_flags=workflow.get_peer_create_job_flags(),
+ fork_proposal_config=workflow.get_fork_proposal_config(),
+ uuid=workflow.uuid,
+ metric_is_public=workflow.metric_is_public,
+ is_finished=workflow.is_finished())
def update_workflow(self, request, context):
- with self._app.app_context():
- project, party = self.check_auth_info(request.auth_info, context)
- workflow = Workflow.query.filter_by(
- name=request.workflow_name,
- project_id=project.id).first()
+ with db.session_scope() as session:
+ project, party = self.check_auth_info(request.auth_info, context, session)
+ # TODO(hangweiqiang): remove workflow name
+ # kept for compatibility with previous versions that look up workflows by name
+ if request.workflow_uuid:
+ workflow = session.query(Workflow).filter_by(uuid=request.workflow_uuid, project_id=project.id).first()
+ else:
+ workflow = session.query(Workflow).filter_by(name=request.workflow_name, project_id=project.id).first()
assert workflow is not None, 'Workflow not found'
config = workflow.get_config()
- _merge_workflow_config(
- config, request.config,
- [common_pb2.Variable.PEER_WRITABLE])
- workflow.set_config(config)
- db.session.commit()
-
- config = self._filter_workflow(
- config,
- [
- common_pb2.Variable.PEER_READABLE,
- common_pb2.Variable.PEER_WRITABLE
- ])
- return service_pb2.UpdateWorkflowResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- workflow_name=request.workflow_name,
- config=config)
+ merge_workflow_config(config, request.config, [common_pb2.Variable.PEER_WRITABLE])
+ WorkflowService(session).update_config(workflow, config)
+ session.commit()
+
+ config = self._filter_workflow(config,
+ [common_pb2.Variable.PEER_READABLE, common_pb2.Variable.PEER_WRITABLE])
+ # kept for compatibility with previous versions that identify workflows by name
+ if request.workflow_uuid:
+ return service_pb2.UpdateWorkflowResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ workflow_uuid=request.workflow_uuid,
+ config=config)
+ return service_pb2.UpdateWorkflowResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ workflow_name=request.workflow_name,
+ config=config)
+
+ def invalidate_workflow(self, request, context):
+ with db.session_scope() as session:
+ project, party = self.check_auth_info(request.auth_info, context, session)
+ workflow = session.query(Workflow).filter_by(uuid=request.workflow_uuid, project_id=project.id).first()
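+ # A missing workflow is not treated as an RPC error; report succeeded=False with a SUCCESS status.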
+ if workflow is None:
+ logging.error(f'Failed to find workflow: {request.workflow_uuid}')
+ return service_pb2.InvalidateWorkflowResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ succeeded=False)
+ invalidate_workflow_locally(session, workflow)
+ session.commit()
+ return service_pb2.InvalidateWorkflowResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ succeeded=True)
def _check_metrics_public(self, request, context):
- project, party = self.check_auth_info(request.auth_info, context)
- job = db.session.query(Job).filter_by(name=request.job_name,
- project_id=project.id).first()
- assert job is not None, f'job {request.job_name} not found'
- workflow = job.workflow
- if not workflow.metric_is_public:
- raise UnauthorizedException('Metric is private!')
- return job
+ with db.session_scope() as session:
+ project, party = self.check_auth_info(request.auth_info, context, session)
+ job = session.query(Job).filter_by(name=request.job_name, project_id=project.id).first()
+ assert job is not None, f'job {request.job_name} not found'
+ workflow = job.workflow
+ if not workflow.metric_is_public:
+ raise UnauthorizedException('Metric is private!')
+ return job
def get_job_metrics(self, request, context):
- with self._app.app_context():
+ with db.session_scope():
job = self._check_metrics_public(request, context)
metrics = JobMetricsBuilder(job).plot_metrics()
- return service_pb2.GetJobMetricsResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- metrics=json.dumps(metrics))
+ return service_pb2.GetJobMetricsResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ metrics=json.dumps(metrics))
def get_job_kibana(self, request, context):
- with self._app.app_context():
+ with db.session_scope():
job = self._check_metrics_public(request, context)
try:
- metrics = Kibana.remote_query(job,
- json.loads(request.json_args))
+ metrics = Kibana.remote_query(job, json.loads(request.json_args))
except UnauthorizedException as ua_e:
return service_pb2.GetJobKibanaResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_UNAUTHORIZED,
- msg=ua_e.message))
+ status=common_pb2.Status(code=common_pb2.STATUS_UNAUTHORIZED, msg=ua_e.message))
except InvalidArgumentException as ia_e:
return service_pb2.GetJobKibanaResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_INVALID_ARGUMENT,
- msg=ia_e.message))
- return service_pb2.GetJobKibanaResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- metrics=json.dumps(metrics))
+ status=common_pb2.Status(code=common_pb2.STATUS_INVALID_ARGUMENT, msg=ia_e.message))
+ return service_pb2.GetJobKibanaResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ metrics=json.dumps(metrics))
def get_job_events(self, request, context):
- with self._app.app_context():
- project, party = self.check_auth_info(request.auth_info, context)
- job = Job.query.filter_by(name=request.job_name,
- project_id=project.id).first()
+ with db.session_scope() as session:
+ project, party = self.check_auth_info(request.auth_info, context, session)
+ job = session.query(Job).filter_by(name=request.job_name, project_id=project.id).first()
assert job is not None, \
f'Job {request.job_name} not found'
- result = es.query_events('filebeat-*', job.name,
- 'fedlearner-operator',
- request.start_time,
- int(time.time() * 1000),
- Envs.OPERATOR_LOG_MATCH_PHRASE
- )[:request.max_lines][::-1]
+ result = es.query_events('filebeat-*', job.name, 'fedlearner-operator', request.start_time,
+ int(time.time() * 1000), Envs.OPERATOR_LOG_MATCH_PHRASE)[:request.max_lines][::-1]
- return service_pb2.GetJobEventsResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- logs=result)
+ return service_pb2.GetJobEventsResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ logs=result)
def check_job_ready(self, request, context):
- with self._app.app_context():
- project, _ = self.check_auth_info(request.auth_info, context)
- job = db.session.query(Job).filter_by(name=request.job_name,
- project_id=project.id).first()
+ with db.session_scope() as session:
+ project, _ = self.check_auth_info(request.auth_info, context, session)
+ job = session.query(Job).filter_by(name=request.job_name, project_id=project.id).first()
assert job is not None, \
f'Job {request.job_name} not found'
- with get_session(db.get_engine()) as session:
- is_ready = JobService(session).is_ready(job)
- return service_pb2.CheckJobReadyResponse(
- status=common_pb2.Status(
- code=common_pb2.STATUS_SUCCESS),
- is_ready=is_ready)
+ is_ready = JobService(session).is_ready(job)
+ return service_pb2.CheckJobReadyResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ is_ready=is_ready)
+
+ def operate_serving_service(self, request, context) -> service_pb2.ServingServiceResponse:
+ with db.session_scope() as session:
+ project, _ = self.check_auth_info(request.auth_info, context, session)
+ return NegotiatorServingService(session).handle_participant_request(request, project)
+
+ def inference_serving_service(self, request, context) -> service_pb2.ServingServiceInferenceResponse:
+ with db.session_scope() as session:
+ project, _ = self.check_auth_info(request.auth_info, context, session)
+ return NegotiatorServingService(session).handle_participant_inference_request(request, project)
+
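+ # Heartbeat from a peer: record the participant's last connected time.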
+ def client_heart_beat(self, request: service_pb2.ClientHeartBeatRequest, context):
+ with db.session_scope() as session:
+ party: Participant = session.query(Participant).filter_by(domain_name=request.domain_name).first()
+ if party is None:
+ return service_pb2.ClientHeartBeatResponse(succeeded=False)
+ party.last_connected_at = now()
+ session.commit()
+ return service_pb2.ClientHeartBeatResponse(succeeded=True)
+
+ def get_model_job(self, request: service_pb2.GetModelJobRequest, context) -> service_pb2.GetModelJobResponse:
+ with db.session_scope() as session:
+ project, _ = self.check_auth_info(request.auth_info, context, session)
+ model_job: ModelJob = session.query(ModelJob).filter_by(uuid=request.uuid).first()
+ group_uuid = None
+ if model_job.group:
+ group_uuid = model_job.group.uuid
+ config = model_job.workflow.get_config()
+ config = self._filter_workflow(config,
+ [common_pb2.Variable.PEER_READABLE, common_pb2.Variable.PEER_WRITABLE])
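+ # Metrics are only returned when requested and when the model job marks them as public.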
+ metrics = None
+ if request.need_metrics and model_job.job is not None and model_job.metric_is_public:
+ metrics = to_json(ModelJobService(session).query_metrics(model_job))
+ return service_pb2.GetModelJobResponse(name=model_job.name,
+ uuid=model_job.uuid,
+ algorithm_type=model_job.algorithm_type.name,
+ model_job_type=model_job.model_job_type.name,
+ state=model_job.state.name,
+ group_uuid=group_uuid,
+ config=config,
+ metrics=metrics,
+ metric_is_public=BoolValue(value=model_job.metric_is_public))
+
+ def get_model_job_group(self, request: service_pb2.GetModelJobGroupRequest, context):
+ with db.session_scope() as session:
+ project, _ = self.check_auth_info(request.auth_info, context, session)
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(uuid=request.uuid).first()
+ return service_pb2.GetModelJobGroupResponse(name=group.name,
+ uuid=group.uuid,
+ role=group.role.name,
+ authorized=group.authorized,
+ algorithm_type=group.algorithm_type.name,
+ config=group.get_config())
+
+ def update_model_job_group(self, request: service_pb2.UpdateModelJobGroupRequest, context):
+ with db.session_scope() as session:
+ project, _ = self.check_auth_info(request.auth_info, context, session)
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(uuid=request.uuid).first()
+ if not group.authorized:
+ raise UnauthorizedException(f'group {group.name} is not authorized for editing')
+ group.set_config(request.config)
+ session.commit()
+ return service_pb2.UpdateModelJobGroupResponse(uuid=group.uuid, config=group.get_config())
+
+ # TODO(liuhehan): delete after all participants support new rpc
+ def list_participant_datasets(self, request: service_pb2.ListParticipantDatasetsRequest, context):
+ kind = DatasetKindV2(request.kind) if request.kind else None
+ uuid = request.uuid if request.uuid else None
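+ # Only published datasets in SUCCEEDED state are exposed to the peer.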
+ state = ResourceState.SUCCEEDED
+ with db.session_scope() as session:
+ project, _ = self.check_auth_info(request.auth_info, context, session)
+ datasets = DatasetService(session=session).get_published_datasets(project.id, kind, uuid, state)
+ return service_pb2.ListParticipantDatasetsResponse(participant_datasets=datasets)
+
+ def get_dataset_job(self, request: service_pb2.GetDatasetJobRequest,
+ context: grpc.ServicerContext) -> service_pb2.GetDatasetJobResponse:
+ with db.session_scope() as session:
+ self.check_auth_info(request.auth_info, context, session)
+ dataset_job_model = session.query(DatasetJob).filter(DatasetJob.uuid == request.uuid).first()
+ if dataset_job_model is None:
+ context.abort(code=grpc.StatusCode.NOT_FOUND, details=f'could not find dataset {request.uuid}')
+ dataset_job = dataset_job_model.to_proto()
+ dataset_job.workflow_definition.MergeFrom(
+ DatasetJobConfiger.from_kind(dataset_job_model.kind, session).get_config())
+ return service_pb2.GetDatasetJobResponse(dataset_job=dataset_job)
+
+ def create_dataset_job(self, request: service_pb2.CreateDatasetJobRequest,
+ context: grpc.ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project, participant = self.check_auth_info(request.auth_info, context, session)
+
+ # A hack that allows a missing ticket_uuid; remove it once all customers have updated.
+ ticket_uuid = request.ticket_uuid if request.ticket_uuid else NO_CENTRAL_SERVER_UUID
+ ticket_helper = get_ticket_helper(session=session)
+ validate = ticket_helper.validate_ticket(
+ ticket_uuid, lambda ticket: ticket.details.uuid == request.dataset_job.result_dataset_uuid and ticket.
+ status == ReviewStatus.APPROVED)
+ if not validate:
+ message = f'[create_dataset_job]: ticket status is not approved, ticket_uuid: {request.ticket_uuid}'
+ logging.warning(message)
+ context.abort(code=grpc.StatusCode.PERMISSION_DENIED, details=message)
+
+ processed_dataset = session.query(ProcessedDataset).filter_by(
+ uuid=request.dataset_job.result_dataset_uuid).first()
+ if processed_dataset is None:
+ # create processed dataset
+ domain_name = SettingService.get_system_info().pure_domain_name
+ dataset_job_config = request.dataset_job.global_configs.global_configs.get(domain_name)
+ dataset = session.query(Dataset).filter_by(uuid=dataset_job_config.dataset_uuid).first()
+ dataset_param = dataset_pb2.DatasetParameter(
+ name=request.dataset_job.result_dataset_name,
+ type=dataset.dataset_type.value,
+ project_id=project.id,
+ kind=DatasetKindV2.PROCESSED.value,
+ format=DatasetFormat(dataset.dataset_format).name,
+ uuid=request.dataset_job.result_dataset_uuid,
+ is_published=True,
+ creator_username=request.dataset.creator_username,
+ )
+ participants_info = request.dataset.participants_info
+ if not Flag.DATASET_AUTH_STATUS_CHECK_ENABLED.value:
+ # auto set participant auth_status and cache to authorized if no need check
+ dataset_param.auth_status = AuthStatus.AUTHORIZED.name
+ participants_info.participants_map[domain_name].auth_status = AuthStatus.AUTHORIZED.name
+ processed_dataset = DatasetService(session=session).create_dataset(dataset_param)
+ processed_dataset.ticket_uuid = request.ticket_uuid
+ processed_dataset.ticket_status = TicketStatus.APPROVED
+ session.flush([processed_dataset])
+ # Old-style dataset jobs create the data_batch at the gRPC level;
+ # new-style dataset jobs create it before creating the dataset_job_stage.
+ if not request.dataset_job.has_stages:
+ batch_parameter = dataset_pb2.BatchParameter(dataset_id=processed_dataset.id)
+ BatchService(session).create_batch(batch_parameter)
+
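+ # Create the dataset_job on this side only if it does not exist yet.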
+ dataset_job = session.query(DatasetJob).filter_by(uuid=request.dataset_job.uuid).first()
+ if dataset_job is None:
+ time_range = timedelta(days=request.dataset_job.time_range.days,
+ hours=request.dataset_job.time_range.hours)
+ dataset_job = DatasetJobService(session=session).create_as_participant(
+ project_id=project.id,
+ kind=DatasetJobKind(request.dataset_job.kind),
+ global_configs=request.dataset_job.global_configs,
+ config=request.dataset_job.workflow_definition,
+ output_dataset_id=processed_dataset.id,
+ coordinator_id=participant.id,
+ uuid=request.dataset_job.uuid,
+ creator_username=request.dataset_job.creator_username,
+ time_range=time_range if time_range else None)
+ session.flush()
+ AuthService(session=session, dataset_job=dataset_job).initialize_participants_info_as_participant(
+ participants_info=request.dataset.participants_info)
+
+ session.commit()
+ return empty_pb2.Empty()
+
+ def wait_for_termination(self):
+ if not self.started:
+ logging.warning('gRPC service is not started yet, unable to wait')
+ return
+ self._server.wait_for_termination()
rpc_server = RpcServer()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/server_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/server_test.py
new file mode 100644
index 000000000..d43fb8fef
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/server_test.py
@@ -0,0 +1,396 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from typing import Optional
+import unittest
+import grpc
+from datetime import datetime, timedelta
+from concurrent import futures
+
+from unittest.mock import patch
+
+from google.protobuf.struct_pb2 import Value
+from google.protobuf.wrappers_pb2 import BoolValue
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.dataset import FakeDatasetJobConfiger
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.rpc.v2 import system_service_pb2_grpc
+from fedlearner_webconsole.proto.rpc.v2.system_service_pb2 import CheckHealthRequest
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto import dataset_pb2, project_pb2, service_pb2, service_pb2_grpc
+from fedlearner_webconsole.proto.metrics_pb2 import ModelJobMetrics
+from fedlearner_webconsole.auth.models import Session
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.rpc.v2.auth_server_interceptor import AuthServerInterceptor
+from fedlearner_webconsole.rpc.server import RPCServerServicer, RpcServer, rpc_server
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetJobState, DatasetKindV2, \
+ DatasetType, ProcessedDataset
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobGroup, ModelJobType
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.utils.proto import to_json
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.review.common import NO_CENTRAL_SERVER_UUID
+
+_FAKE_SYSTEM_INFO = SystemInfo(
+ name='test',
+ domain_name='fl-participant.com',
+ pure_domain_name='participant',
+)
+
+
+def make_check_auth_info(self, auth_info: Optional[service_pb2.ProjAuthInfo], context: Optional[grpc.ServicerContext],
+ session: Session):
+ project = Project(id=1, name='test')
+ participant = Participant(id=1, name='participant', domain_name='test_domain')
+ return project, participant
+
+
+class ServerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ listen_port = 1991
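+ # Spin up a real gRPC server with the auth interceptor and talk to it over an insecure channel.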
+ self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=20), interceptors=[AuthServerInterceptor()])
+ service_pb2_grpc.add_WebConsoleV2ServiceServicer_to_server(RPCServerServicer(RpcServer()), self._server)
+ self._server.add_insecure_port(f'[::]:{listen_port}')
+ self._server.start()
+
+ self._stub = service_pb2_grpc.WebConsoleV2ServiceStub(grpc.insecure_channel(target=f'localhost:{listen_port}'))
+
+ def tearDown(self):
+ self._server.stop(5)
+ return super().tearDown()
+
+ @patch('fedlearner_webconsole.rpc.server.RpcServer.check_auth_info', lambda *args: ('test_project', 'party_1'))
+ def test_get_dataset_job_unexist(self):
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.GetDatasetJob(service_pb2.GetDatasetJobRequest(auth_info=None, uuid='u1234'))
+
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobConfiger.from_kind',
+ lambda *args: FakeDatasetJobConfiger(None))
+ @patch('fedlearner_webconsole.rpc.server.RpcServer.check_auth_info', lambda *args: ('test_project', 'party_1'))
+ def test_get_dataset_job(self):
+ request = service_pb2.GetDatasetJobRequest(auth_info=None, uuid='dataset_job_uuid')
+ # no dataset_job
+ with self.assertRaises(grpc.RpcError):
+ self._stub.GetDatasetJob(request)
+ # check the case where the output dataset and workflow are missing
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job_uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=0,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.RUNNING,
+ coordinator_id=0,
+ workflow_id=0)
+ session.add(dataset_job)
+ session.commit()
+ resp = self._stub.GetDatasetJob(request)
+ self.assertEqual(resp.dataset_job.uuid, dataset_job.uuid)
+ self.assertEqual(resp.dataset_job.result_dataset_uuid, '')
+ self.assertEqual(resp.dataset_job.is_ready, False)
+ # check is_ready succeeded
+ with db.session_scope() as session:
+ dataset = Dataset(id=2,
+ name='output dataset',
+ uuid='result_dataset_uuid',
+ path='/data/dataset/321',
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ dataset_kind=DatasetKindV2.PROCESSED)
+ session.add(dataset)
+ workflow = Workflow(id=1, project_id=1, name='test workflow', uuid='dataset_job_uuid')
+ session.add(workflow)
+ dataset_job = session.query(DatasetJob).get(1)
+ dataset_job.workflow_id = 1
+ dataset_job.output_dataset_id = 2
+ session.commit()
+ resp = self._stub.GetDatasetJob(request)
+ self.assertEqual(resp.dataset_job.uuid, dataset_job.uuid)
+ self.assertEqual(resp.dataset_job.result_dataset_uuid, 'result_dataset_uuid')
+ self.assertEqual(resp.dataset_job.is_ready, True)
+
+ @patch('fedlearner_webconsole.rpc.server.RpcServer.check_auth_info', make_check_auth_info)
+ @patch('fedlearner_webconsole.rpc.server.SettingService.get_system_info', lambda: _FAKE_SYSTEM_INFO)
+ def test_create_dataset_job(self):
+ with db.session_scope() as session:
+ project = Project(id=1, name='test')
+ session.add(project)
+ dataset = Dataset(id=1,
+ name='input dataset',
+ uuid='raw_dataset_uuid',
+ path='/data/dataset/321',
+ is_published=True,
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ dataset_kind=DatasetKindV2.RAW)
+ session.add(dataset)
+ session.commit()
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job_uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.RUNNING,
+ coordinator_id=0,
+ workflow_id=1)
+ dataset_job_parameter = dataset_pb2.DatasetJob(
+ uuid=dataset_job.uuid,
+ kind=dataset_job.kind.value,
+ global_configs=dataset_pb2.DatasetJobGlobalConfigs(global_configs={
+ _FAKE_SYSTEM_INFO.pure_domain_name: dataset_pb2.DatasetJobConfig(dataset_uuid='raw_dataset_uuid')
+ }),
+ workflow_definition=WorkflowDefinition(group_alias='test'),
+ result_dataset_uuid='dataset_uuid',
+ result_dataset_name='dataset_name',
+ creator_username='test user')
+ request = service_pb2.CreateDatasetJobRequest(auth_info=None,
+ dataset_job=dataset_job_parameter,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID)
+ self._stub.CreateDatasetJob(request)
+ with db.session_scope() as session:
+ dataset = session.query(ProcessedDataset).filter_by(uuid='dataset_uuid').first()
+ self.assertEqual(dataset.name, 'dataset_name')
+ self.assertEqual(dataset.dataset_kind, DatasetKindV2.PROCESSED)
+ self.assertEqual(len(dataset.data_batches), 1)
+ self.assertEqual(dataset.is_published, True)
+ dataset_job = session.query(DatasetJob).filter_by(uuid='dataset_job_uuid').first()
+ self.assertIsNotNone(dataset_job)
+ self.assertEqual(dataset_job.output_dataset_id, dataset.id)
+ self.assertEqual(dataset_job.creator_username, 'test user')
+ self.assertIsNone(dataset_job.time_range)
+
+ # test with time_range
+ dataset_job = DatasetJob(id=2,
+ uuid='dataset_job_uuid with time_range',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.RUNNING,
+ coordinator_id=0,
+ workflow_id=1,
+ time_range=timedelta(days=1))
+ dataset_job_parameter = dataset_pb2.DatasetJob(
+ uuid=dataset_job.uuid,
+ kind=dataset_job.kind.value,
+ global_configs=dataset_pb2.DatasetJobGlobalConfigs(global_configs={
+ _FAKE_SYSTEM_INFO.pure_domain_name: dataset_pb2.DatasetJobConfig(dataset_uuid='raw_dataset_uuid')
+ }),
+ workflow_definition=WorkflowDefinition(group_alias='test'),
+ result_dataset_uuid='dataset_uuid with time_range',
+ result_dataset_name='dataset_name',
+ creator_username='test user',
+ time_range=dataset_job.time_range_pb)
+ request = service_pb2.CreateDatasetJobRequest(auth_info=None,
+ dataset_job=dataset_job_parameter,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID)
+ self._stub.CreateDatasetJob(request)
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).filter_by(uuid='dataset_job_uuid with time_range').first()
+ self.assertEqual(dataset_job.time_range, timedelta(days=1))
+
+ @patch('fedlearner_webconsole.rpc.server.RpcServer.check_auth_info', make_check_auth_info)
+ @patch('fedlearner_webconsole.rpc.server.SettingService.get_system_info', lambda: _FAKE_SYSTEM_INFO)
+ def test_create_dataset_job_has_stage(self):
+ with db.session_scope() as session:
+ project = Project(id=1, name='test')
+ session.add(project)
+ streaming_dataset = Dataset(id=1,
+ name='input streaming_dataset',
+ uuid='raw_dataset_uuid',
+ path='/data/dataset/321',
+ is_published=True,
+ project_id=1,
+ created_at=datetime(2012, 1, 14, 12, 0, 7),
+ dataset_kind=DatasetKindV2.RAW,
+ dataset_type=DatasetType.STREAMING.value)
+ session.add(streaming_dataset)
+ session.commit()
+ dataset_job_parameter = dataset_pb2.DatasetJob(
+ uuid='dataset_job_uuid',
+ kind=DatasetJobKind.DATA_ALIGNMENT.value,
+ global_configs=dataset_pb2.DatasetJobGlobalConfigs(global_configs={
+ _FAKE_SYSTEM_INFO.pure_domain_name: dataset_pb2.DatasetJobConfig(dataset_uuid='raw_dataset_uuid')
+ }),
+ workflow_definition=WorkflowDefinition(group_alias='test'),
+ result_dataset_uuid='dataset_uuid',
+ result_dataset_name='dataset_name',
+ has_stages=True,
+ )
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_participant_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_participant_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'participant': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name)
+ })
+ dataset_parameter = dataset_pb2.Dataset(participants_info=participants_info, creator_username='test user')
+ request = service_pb2.CreateDatasetJobRequest(auth_info=None,
+ dataset_job=dataset_job_parameter,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID,
+ dataset=dataset_parameter)
+ self._stub.CreateDatasetJob(request)
+ with db.session_scope() as session:
+ dataset = session.query(ProcessedDataset).filter_by(uuid='dataset_uuid').first()
+ self.assertEqual(dataset.name, 'dataset_name')
+ self.assertEqual(dataset.dataset_kind, DatasetKindV2.PROCESSED)
+ self.assertEqual(len(dataset.data_batches), 0)
+ self.assertEqual(dataset.is_published, True)
+ self.assertEqual(dataset.dataset_type, DatasetType.STREAMING)
+ self.assertEqual(dataset.creator_username, 'test user')
+ dataset_job = session.query(DatasetJob).filter_by(uuid='dataset_job_uuid').first()
+ self.assertIsNotNone(dataset_job)
+ self.assertEqual(dataset_job.output_dataset_id, dataset.id)
+ self.assertEqual(dataset.ticket_uuid, NO_CENTRAL_SERVER_UUID)
+ self.assertEqual(dataset.ticket_status, TicketStatus.APPROVED)
+ expected_participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_participant_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_participant_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'participant': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ self.assertEqual(dataset.get_participants_info(), expected_participants_info)
+
+ @patch('fedlearner_webconsole.rpc.server.RpcServer.check_auth_info', make_check_auth_info)
+ def test_update_workflow(self):
+ uuid = 'u1ff0ab5596bb487e96c'
+ with db.session_scope() as session:
+ var1 = Variable(name='hello',
+ value_type=Variable.NUMBER,
+ typed_value=Value(number_value=1),
+ access_mode=Variable.PEER_WRITABLE)
+ var2 = Variable(name='hello',
+ value_type=Variable.NUMBER,
+ typed_value=Value(number_value=1),
+ access_mode=Variable.PEER_READABLE)
+ jd = JobDefinition(name='test1', yaml_template='{}', variables=[var1, var2])
+ wd = WorkflowDefinition(job_definitions=[jd])
+ workflow = Workflow(
+ name='test-workflow',
+ uuid=uuid,
+ project_id=1,
+ config=wd.SerializeToString(),
+ )
+ session.add(workflow)
+ session.flush()
+ job = Job(name='test_job',
+ config=jd.SerializeToString(),
+ workflow_id=workflow.id,
+ job_type=JobType(1),
+ project_id=1,
+ is_disabled=False)
+ session.add(job)
+ session.flush()
+ workflow.job_ids = str(job.id)
+ session.commit()
+ var1 = Variable(name='hello',
+ value_type=Variable.NUMBER,
+ typed_value=Value(number_value=2),
+ access_mode=Variable.PEER_WRITABLE)
+ var2 = Variable(name='hello',
+ value_type=Variable.NUMBER,
+ typed_value=Value(number_value=2),
+ access_mode=Variable.PEER_READABLE)
+ jd = JobDefinition(name='test1', yaml_template='{}', variables=[var1, var2])
+ wd = WorkflowDefinition(job_definitions=[jd])
+ request = service_pb2.UpdateWorkflowRequest(auth_info=None, workflow_uuid=uuid, config=wd)
+ self._stub.UpdateWorkflow(request)
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).filter_by(uuid=uuid).first()
+ self.assertEqual(workflow.get_config().job_definitions[0].variables[0].typed_value, Value(number_value=2))
+ self.assertEqual(workflow.get_config().job_definitions[0].variables[1].typed_value, Value(number_value=1))
+ jd = workflow.get_jobs(session)[0].get_config()
+ self.assertEqual(jd.variables[0].typed_value, Value(number_value=2))
+ self.assertEqual(jd.variables[1].typed_value, Value(number_value=1))
+
+ @patch('fedlearner_webconsole.rpc.server.RpcServer.check_auth_info', make_check_auth_info)
+ @patch('fedlearner_webconsole.mmgr.service.ModelJobService.query_metrics')
+ def test_get_model_job(self, mock_query_metrics):
+ metrics = ModelJobMetrics()
+ metric = metrics.train.get_or_create('acc')
+ metric.steps.extend([1, 2, 3])
+ metric.values.extend([1.0, 2.0, 3.0])
+ mock_query_metrics.return_value = metrics
+ with db.session_scope() as session:
+ job = Job(name='uuid-job',
+ project_id=1,
+ workflow_id=1,
+ job_type=JobType.NN_MODEL_TRANINING,
+ state=JobState.COMPLETED)
+ workflow = Workflow(id=1, name='workflow', uuid='uuid')
+ group = ModelJobGroup(id=1, name='group', uuid='uuid', project_id=1)
+ model_job = ModelJob(id=1,
+ name='job',
+ uuid='uuid',
+ group_id=1,
+ project_id=1,
+ metric_is_public=False,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ model_job_type=ModelJobType.TRAINING,
+ workflow_uuid='uuid',
+ job_name='uuid-job')
+ session.add_all([job, workflow, group, model_job])
+ session.commit()
+ request = service_pb2.GetModelJobRequest(auth_info=None, uuid='uuid', need_metrics=True)
+ resp = self._stub.GetModelJob(request)
+ mock_query_metrics.assert_not_called()
+ expected_resp = service_pb2.GetModelJobResponse(name='job',
+ uuid='uuid',
+ algorithm_type='NN_VERTICAL',
+ model_job_type='TRAINING',
+ group_uuid='uuid',
+ state='INVALID',
+ metric_is_public=BoolValue(value=False))
+ self.assertEqual(resp, expected_resp)
+ with db.session_scope() as session:
+ model_job: ModelJob = session.query(ModelJob).get(1)
+ model_job.metric_is_public = True
+ session.commit()
+ resp = self._stub.GetModelJob(request)
+ mock_query_metrics.assert_called()
+ expected_resp.metric_is_public.MergeFrom(BoolValue(value=True))
+ expected_resp.metrics = to_json(metrics)
+ self.assertEqual(resp, expected_resp)
+
+
+class RpcServerTest(NoWebServerTestCase):
+
+ def test_smoke_test(self):
+ rpc_server.start(13546)
+ # Waits for server ready
+ time.sleep(2)
+ stub = system_service_pb2_grpc.SystemServiceStub(grpc.insecure_channel(target='localhost:13546'))
+ self.assertIsNotNone(
+ stub.CheckHealth(CheckHealthRequest(),
+ metadata=[('ssl-client-subject-dn',
+ 'CN=aaa.fedlearner.net,OU=security,O=security,L=beijing,ST=beijing,C=CN')]))
+ rpc_server.stop()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/rpc/v2/BUILD.bazel
new file mode 100644
index 000000000..6dd8e5680
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/BUILD.bazel
@@ -0,0 +1,483 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "auth_client_interceptor_lib",
+ srcs = ["auth_client_interceptor.py"],
+ imports = ["../../.."],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:auth_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "auth_client_interceptor_lib_test",
+ srcs = [
+ "auth_client_interceptor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "auth_client_interceptor_test.py",
+ deps = [
+ ":auth_client_interceptor_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:auth_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/testing:py_proto",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "auth_server_interceptor_lib",
+ srcs = ["auth_server_interceptor.py"],
+ imports = ["../../.."],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:auth_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "auth_server_interceptor_lib_test",
+ srcs = [
+ "auth_server_interceptor_test.py",
+ ],
+ imports = ["../../.."],
+ main = "auth_server_interceptor_test.py",
+ deps = [
+ ":auth_server_interceptor_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/testing:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "//web_console_v2/api/testing/rpc:service_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_library(
+ name = "client_base_lib",
+ srcs = ["client_base.py"],
+ imports = ["../../.."],
+ deps = [
+ ":auth_client_interceptor_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "client_base_lib_test",
+ srcs = [
+ "client_base_test.py",
+ ],
+ imports = ["../../.."],
+ main = "client_base_test.py",
+ deps = [
+ ":client_base_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/testing:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@common_grpcio_testing//:pkg",
+ ],
+)
+
+py_library(
+ name = "system_service_client_lib",
+ srcs = ["system_service_client.py"],
+ imports = ["../../.."],
+ deps = [
+ ":client_base_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "system_service_client_lib_test",
+ srcs = [
+ "system_service_client_test.py",
+ ],
+ imports = ["../../.."],
+ main = "system_service_client_test.py",
+ deps = [
+ ":system_service_client_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_grpcio_testing//:pkg",
+ ],
+)
+
+py_library(
+ name = "system_service_server_lib",
+ srcs = ["system_service_server.py"],
+ imports = ["../../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:services_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "system_service_server_lib_test",
+ srcs = [
+ "system_service_server_test.py",
+ ],
+ imports = ["../../.."],
+ main = "system_service_server_test.py",
+ deps = [
+ ":system_service_server_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:app_version_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "job_service_client_lib",
+ srcs = ["job_service_client.py"],
+ imports = ["../../.."],
+ deps = [
+ ":client_base_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "job_service_client_lib_test",
+ size = "medium",
+ srcs = [
+ "job_service_client_test.py",
+ ],
+ imports = ["../../.."],
+ main = "job_service_client_test.py",
+ deps = [
+ ":job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils.py"],
+ imports = ["../../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:auth_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_base64_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "utils_lib_test",
+ srcs = [
+ "utils_test.py",
+ ],
+ imports = ["../../.."],
+ main = "utils_test.py",
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:auth_lib",
+ ],
+)
+
+py_library(
+ name = "job_service_server_lib",
+ srcs = ["job_service_server.py"],
+ imports = ["../../.."],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:local_controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset/job_configer",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "job_service_server_lib_test",
+ size = "small",
+ srcs = [
+ "job_service_server_test.py",
+ ],
+ imports = ["../../.."],
+ main = "job_service_server_test.py",
+ deps = [
+ ":job_service_server_lib",
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:auth_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "resource_service_server_lib",
+ srcs = ["resource_service_server.py"],
+ imports = ["../../.."],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm/transmit",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "resource_service_server_lib_test",
+ size = "small",
+ srcs = [
+ "resource_service_server_test.py",
+ ],
+ imports = ["../../.."],
+ main = "resource_service_server_test.py",
+ deps = [
+ ":resource_service_server_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "resource_service_client_lib",
+ srcs = ["resource_service_client.py"],
+ imports = ["../../.."],
+ deps = [
+ ":client_base_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "resource_service_client_lib_test",
+ size = "small",
+ srcs = [
+ "resource_service_client_test.py",
+ ],
+ imports = ["../../.."],
+ main = "resource_service_client_test.py",
+ deps = [
+ ":resource_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_grpcio_testing//:pkg",
+ ],
+)
+
+py_library(
+ name = "project_service_server_lib",
+ srcs = ["project_service_server.py"],
+ imports = ["../../.."],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:service_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "project_service_server_lib_test",
+ size = "small",
+ srcs = [
+ "project_service_server_test.py",
+ ],
+ imports = ["../../.."],
+ main = "project_service_server_test.py",
+ deps = [
+ ":project_service_server_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "project_service_client_lib",
+ srcs = ["project_service_client.py"],
+ imports = ["../../.."],
+ deps = [
+ ":client_base_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "project_service_client_lib_test",
+ size = "small",
+ srcs = [
+ "project_service_client_test.py",
+ ],
+ imports = ["../../.."],
+ main = "project_service_client_test.py",
+ deps = [
+ ":project_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_grpcio_testing//:pkg",
+ ],
+)
+
+py_library(
+ name = "review_service_client_lib",
+ srcs = ["review_service_client.py"],
+ imports = ["../../.."],
+ deps = [
+ ":client_base_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_test(
+ name = "review_service_client_lib_test",
+ srcs = [
+ "review_service_client_test.py",
+ ],
+ imports = ["../../.."],
+ main = "review_service_client_test.py",
+ deps = [
+ ":review_service_client_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_grpcio_testing//:pkg",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_client_interceptor.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_client_interceptor.py
new file mode 100644
index 000000000..b8d8d9310
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_client_interceptor.py
@@ -0,0 +1,111 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, Callable, Iterator, NamedTuple, Optional, Sequence, Tuple, Union
+
+import grpc
+
+from fedlearner_webconsole.rpc.auth import PROJECT_NAME_HEADER, X_HOST_HEADER
+from fedlearner_webconsole.rpc.v2.utils import encode_project_name
+
+
+# Ref: https://github.com/d5h-foss/grpc-interceptor/blob/master/src/grpc_interceptor/client.py#L9
+class _ClientCallDetailsFields(NamedTuple):
+ method: str
+ timeout: Optional[float]
+ metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]]
+ credentials: Optional[grpc.CallCredentials]
+ wait_for_ready: Optional[bool]
+ compression: Optional[grpc.Compression]
+
+
+class ClientCallDetails(_ClientCallDetailsFields, grpc.ClientCallDetails):
+ pass
+
+
+RequestOrIterator = Union[Any, Iterator[Any]]
+
+
+class _RpcErrorOutcome(grpc.RpcError, grpc.Future):
+
+ def __init__(self, rpc_error: grpc.RpcError):
+ super().__init__()
+ self._error = rpc_error
+
+ def cancel(self):
+ return False
+
+ def cancelled(self):
+ return False
+
+ def running(self):
+ return False
+
+ def done(self):
+ return True
+
+ def result(self, timeout=None):
+ raise self._error
+
+ def exception(self, timeout=None):
+ return self._error
+
+ def traceback(self, timeout=None):
+ return ''
+
+ def add_done_callback(self, fn):
+ fn(self)
+
+
+class AuthClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
+ grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
+
+ def __init__(self, x_host: str, project_name: Optional[str] = None):
+ super().__init__()
+ self.x_host = x_host
+ self.project_name = project_name
+
+ def _intercept_call(self, continuation: Callable[[ClientCallDetails, RequestOrIterator], Any],
+ client_call_details: ClientCallDetails, request_or_iterator: RequestOrIterator):
+ metadata = []
+ if client_call_details.metadata is not None:
+ metadata = list(client_call_details.metadata)
+ metadata.append((X_HOST_HEADER, self.x_host))
+ if self.project_name is not None:
+ metadata.append((PROJECT_NAME_HEADER, encode_project_name(self.project_name)))
+
+ # Metadata of ClientCallDetails cannot be set directly, so we rebuild the call details with the new metadata
+ new_details = ClientCallDetails(client_call_details.method, client_call_details.timeout, metadata,
+ client_call_details.credentials, client_call_details.wait_for_ready,
+ client_call_details.compression)
+ response_future = continuation(new_details, request_or_iterator)
+ # Testing-only hack: with a testing channel, the interceptor machinery returns a grpc error as
+ # a regular response instead of raising an exception, and its interface differs from a real
+ # channel's. The behavior was introduced in https://github.com/grpc/grpc/pull/17317
+ if isinstance(response_future, grpc.RpcError) and not isinstance(response_future, grpc.Future):
+ return _RpcErrorOutcome(response_future)
+ return response_future
+
+ def intercept_unary_unary(self, continuation, client_call_details, request):
+ return self._intercept_call(continuation, client_call_details, request)
+
+ def intercept_unary_stream(self, continuation, client_call_details, request):
+ return self._intercept_call(continuation, client_call_details, request)
+
+ def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
+ return self._intercept_call(continuation, client_call_details, request_iterator)
+
+ def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
+ return self._intercept_call(continuation, client_call_details, request_iterator)
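+
+
+# A minimal usage sketch (the host and project name below are examples only, taken from the tests):
+#
+#   channel = grpc.intercept_channel(
+#       channel,
+#       AuthClientInterceptor(x_host='fedlearner-webconsole-v2.fl-test.com',
+#                             project_name='test-project'))
+#   # Every call made through stubs built on this channel then carries the
+#   # X_HOST_HEADER metadata entry, plus the encoded PROJECT_NAME_HEADER entry
+#   # when project_name is set.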
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_client_interceptor_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_client_interceptor_test.py
new file mode 100644
index 000000000..552b667c9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_client_interceptor_test.py
@@ -0,0 +1,109 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+import unittest
+
+import grpc
+import grpc_testing
+
+from fedlearner_webconsole.proto.testing import service_pb2, service_pb2_grpc
+from fedlearner_webconsole.rpc.auth import PROJECT_NAME_HEADER, X_HOST_HEADER
+from fedlearner_webconsole.rpc.v2.auth_client_interceptor import AuthClientInterceptor
+from google.protobuf.descriptor import ServiceDescriptor
+from testing.rpc.client import RpcClientTestCase
+
+_TEST_SERVICE_DESCRIPTOR: ServiceDescriptor = service_pb2.DESCRIPTOR.services_by_name['TestService']
+
+
+class AuthClientInterceptorTest(RpcClientTestCase):
+ _X_HOST = 'fedlearner-webconsole-v2.fl-test.com'
+
+ def set_up(self, project_name: Optional[str] = None) -> service_pb2_grpc.TestServiceStub:
+ self._fake_channel: grpc_testing.Channel = grpc_testing.channel([_TEST_SERVICE_DESCRIPTOR],
+ grpc_testing.strict_real_time())
+ channel = grpc.intercept_channel(self._fake_channel,
+ AuthClientInterceptor(x_host=self._X_HOST, project_name=project_name))
+ self._stub = service_pb2_grpc.TestServiceStub(channel)
+ return self._stub
+
+ def test_x_host(self):
+ self.set_up()
+ call = self.client_execution_pool.submit(self._stub.FakeUnaryUnary, service_pb2.FakeUnaryUnaryRequest())
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _TEST_SERVICE_DESCRIPTOR.methods_by_name['FakeUnaryUnary'])
+
+ self.assertIn((X_HOST_HEADER, self._X_HOST), invocation_metadata)
+ rpc.terminate(response=service_pb2.FakeUnaryUnaryResponse(),
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ # Waits for finish
+ call.result()
+
+ def test_x_host_unauthenticated(self):
+ self.set_up()
+
+ call = self.client_execution_pool.submit(self._stub.FakeUnaryUnary, service_pb2.FakeUnaryUnaryRequest())
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _TEST_SERVICE_DESCRIPTOR.methods_by_name['FakeUnaryUnary'])
+
+ self.assertIn((X_HOST_HEADER, self._X_HOST), invocation_metadata)
+ rpc.terminate(response=service_pb2.FakeUnaryUnaryResponse(),
+ code=grpc.StatusCode.UNAUTHENTICATED,
+ trailing_metadata=(),
+ details=None)
+ with self.assertRaises(grpc.RpcError) as cm:
+ call.result()
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+
+ def test_project_name(self):
+ self.set_up(project_name='test-project-113')
+
+ call = self.client_execution_pool.submit(self._stub.FakeUnaryUnary, service_pb2.FakeUnaryUnaryRequest())
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _TEST_SERVICE_DESCRIPTOR.methods_by_name['FakeUnaryUnary'])
+
+ self.assertIn((X_HOST_HEADER, self._X_HOST), invocation_metadata)
+ self.assertIn((PROJECT_NAME_HEADER, 'test-project-113'), invocation_metadata)
+ rpc.terminate(response=service_pb2.FakeUnaryUnaryResponse(),
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ # Waits for finish
+ call.result()
+
+ def test_project_name_unicode(self):
+ self.set_up(project_name='test中文')
+
+ call = self.client_execution_pool.submit(self._stub.FakeUnaryUnary, service_pb2.FakeUnaryUnaryRequest())
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _TEST_SERVICE_DESCRIPTOR.methods_by_name['FakeUnaryUnary'])
+
+ self.assertIn((X_HOST_HEADER, self._X_HOST), invocation_metadata)
+ self.assertIn((PROJECT_NAME_HEADER, 'dGVzdOS4reaWhw=='), invocation_metadata)
+ rpc.terminate(response=service_pb2.FakeUnaryUnaryResponse(),
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None)
+ # Waits for finish
+ call.result()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_server_interceptor.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_server_interceptor.py
new file mode 100644
index 000000000..5dbdc8cbe
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_server_interceptor.py
@@ -0,0 +1,157 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Callable, Any, Optional, Tuple
+
+import grpc
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.rpc.auth import get_common_name, SSL_CLIENT_SUBJECT_DN_HEADER, PROJECT_NAME_HEADER
+from fedlearner_webconsole.rpc.v2.utils import decode_project_name
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+
+# Which services should disable the auth interceptors.
+DISABLED_SERVICES = frozenset([
+ # Skips the old gRPC service as it has a separate way to check auth.
+ 'fedlearner_webconsole.proto.WebConsoleV2Service',
+])
+# Which services should use project-based auth interceptors.
+PROJECT_BASED_SERVICES = frozenset([
+ 'fedlearner_webconsole.proto.rpc.v2.JobService',
+])
+
+
+class AuthException(Exception):
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+
+
+def _get_handler_factory(handler: grpc.RpcMethodHandler) -> Callable:
+ if handler.unary_unary:
+ return grpc.unary_unary_rpc_method_handler
+ if handler.unary_stream:
+ return grpc.unary_stream_rpc_method_handler
+ if handler.stream_unary:
+ return grpc.stream_unary_rpc_method_handler
+ if handler.stream_stream:
+ return grpc.stream_stream_rpc_method_handler
+ raise RuntimeError(f'Unrecognized rpc handler: {handler}')
+
+
+def _parse_method_name(method_full_name: str) -> Tuple[str, str]:
+ """Parses grpc method name in service interceptor.
+
+ Arguments:
+ method_full_name: Full name of the method, e.g. /fedlearner_webconsole.proto.testing.TestService/FakeUnaryUnary
+
+ Returns:
+ A tuple of service name and method name, e.g. 'fedlearner_webconsole.proto.testing.TestService'
+ and 'FakeUnaryUnary'.
+ """
+ names = method_full_name.split('/')
+ return names[-2], names[-1]
+
+
+class AuthServerInterceptor(grpc.ServerInterceptor):
+ """Auth related stuff on server side, which will work for those service which injects this
+ interceptor.
+
+ Ref: https://github.com/grpc/grpc/blob/v1.40.x/examples/python/interceptors/headers/request_header_validator_interceptor.py # pylint:disable=line-too-long
+ """
+
+ def _build_rpc_terminator(self, message: str):
+
+ def terminate(request_or_iterator: Any, context: grpc.ServicerContext):
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
+
+ return terminate
+
+ def _verify_domain_name(self, handler_call_details: grpc.HandlerCallDetails) -> str:
+ """Verifies if the traffic is secure by checking ssl-client-subject-dn header.
+
+ Returns:
+ The pure domain name.
+
+ Raises:
+ AuthException: if the traffic is insecure.
+ """
+ ssl_client_subject_dn = None
+ for header, value in handler_call_details.invocation_metadata:
+ if header == SSL_CLIENT_SUBJECT_DN_HEADER:
+ ssl_client_subject_dn = value
+ break
+ if not ssl_client_subject_dn:
+ raise AuthException('No client subject dn found')
+ # If this header is set, it passed the TLS verification
+ common_name = get_common_name(ssl_client_subject_dn)
+ if not common_name:
+ logging.error('[gRPC auth] invalid subject dn: %s', ssl_client_subject_dn)
+ raise AuthException('Invalid subject dn')
+ # Extracts the pure domain name, e.g. bytedance-test
+ pure_domain_name = get_pure_domain_name(common_name)
+ if not pure_domain_name:
+ logging.error('[gRPC auth] no valid domain name found in %s', ssl_client_subject_dn)
+ raise AuthException('Invalid domain name')
+ return pure_domain_name
+
+ def _verify_project_info(self, handler_call_details: grpc.HandlerCallDetails, pure_domain_name: str):
+ project_name = None
+ for header, value in handler_call_details.invocation_metadata:
+ if header == PROJECT_NAME_HEADER:
+ project_name = decode_project_name(value)
+ break
+ if not project_name:
+ raise AuthException('No project name found')
+ with db.session_scope() as session:
+ project = session.query(Project.id).filter_by(name=project_name).first()
+ if not project:
+ logging.error('[gRPC auth] invalid project: %s', project_name)
+ raise AuthException(f'Invalid project {project_name}')
+ project_id, = project
+ # Checks if the caller has the access to this project
+ service = ParticipantService(session)
+ participants = service.get_participants_by_project(project_id)
+ has_access = False
+ for p in participants:
+ if p.pure_domain_name() == pure_domain_name:
+ has_access = True
+ break
+ if not has_access:
+ raise AuthException(f'No access to {project_name}')
+
+ def intercept_service(self, continuation: Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler],
+ handler_call_details: grpc.HandlerCallDetails) -> Optional[grpc.RpcMethodHandler]:
+ next_handler = continuation(handler_call_details)
+
+ package_service_name, _ = _parse_method_name(handler_call_details.method)
+ # Skips the interceptor if the service does not intend to use it
+ if package_service_name in DISABLED_SERVICES:
+ return next_handler
+
+ try:
+ pure_domain_name = self._verify_domain_name(handler_call_details)
+ # Project based service
+ if package_service_name in PROJECT_BASED_SERVICES:
+ self._verify_project_info(handler_call_details, pure_domain_name)
+ # Go ahead!
+ return next_handler
+ except AuthException as e:
+ handler_factory = _get_handler_factory(next_handler)
+ return handler_factory(self._build_rpc_terminator(e.message))
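+
+
+# A minimal wiring sketch (the executor size and port are illustrative assumptions):
+#
+#   from concurrent import futures
+#   server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
+#                        interceptors=[AuthServerInterceptor()])
+#   # ... add servicers to the server here ...
+#   server.add_insecure_port('[::]:50051')
+#   server.start()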
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_server_interceptor_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_server_interceptor_test.py
new file mode 100644
index 000000000..61c68c9e9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/auth_server_interceptor_test.py
@@ -0,0 +1,184 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import contextlib
+import unittest
+from unittest.mock import patch
+
+import grpc
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.testing import service_pb2_grpc
+from fedlearner_webconsole.proto.testing.service_pb2 import FakeUnaryUnaryRequest, FakeStreamStreamRequest
+from fedlearner_webconsole.rpc.v2.auth_server_interceptor import AuthServerInterceptor, _parse_method_name
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.rpc.client import testing_channel
+from testing.rpc.service import TestService
+
+
+class ParseMethodNameTest(unittest.TestCase):
+
+ def test_parse_method_name(self):
+ self.assertEqual(_parse_method_name('/fedlearner_webconsole.proto.testing.TestService/FakeUnaryUnary'),
+ ('fedlearner_webconsole.proto.testing.TestService', 'FakeUnaryUnary'))
+ self.assertEqual(_parse_method_name('test-service/TestM'), ('test-service', 'TestM'))
+
+
+class AuthServerInterceptorTest(NoWebServerTestCase):
+
+ def set_up_client(self, is_project_based=False, skip=False) -> service_pb2_grpc.TestServiceStub:
+ if is_project_based:
+ project_based_patcher = patch(
+ 'fedlearner_webconsole.rpc.v2.auth_server_interceptor.PROJECT_BASED_SERVICES',
+ frozenset(['fedlearner_webconsole.proto.testing.TestService']),
+ )
+ project_based_patcher.start()
+ if skip:
+ skip_patcher = patch(
+ 'fedlearner_webconsole.rpc.v2.auth_server_interceptor.DISABLED_SERVICES',
+ frozenset(['fedlearner_webconsole.proto.testing.TestService']),
+ )
+ skip_patcher.start()
+
+ def stop_patchers():
+ if is_project_based:
+ project_based_patcher.stop()
+ if skip:
+ skip_patcher.stop()
+
+ def register_service(server: grpc.Server):
+ service_pb2_grpc.add_TestServiceServicer_to_server(TestService(), server)
+
+ with contextlib.ExitStack() as stack:
+ channel = stack.enter_context(
+ testing_channel(
+ register_service,
+ server_interceptors=[AuthServerInterceptor()],
+ ))
+ stub = service_pb2_grpc.TestServiceStub(channel)
+ # Cleans up for the server
+ self.addCleanup(stack.pop_all().close)
+ self.addCleanup(stop_patchers)
+ return stub
+
+ def test_verify_domain_name(self):
+ stub = self.set_up_client()
+ valid_subject_dn = 'CN=aaa.fedlearner.net,OU=security,O=security,L=beijing,ST=beijing,C=CN'
+ # Normal unary-unary
+ resp = stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(),
+ metadata=[('ssl-client-subject-dn', valid_subject_dn)])
+ self.assertIsNotNone(resp)
+
+ # Normal stream-stream
+ def generate_request():
+ yield FakeStreamStreamRequest()
+
+ # Makes sure the stream-stream request is executed
+ self.assertEqual(
+ len(list(stub.FakeStreamStream(generate_request(),
+ metadata=[('ssl-client-subject-dn', valid_subject_dn)]))), 1)
+
+ with self.assertRaisesRegex(grpc.RpcError, 'No client subject dn found') as cm:
+ # No ssl header
+ stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest())
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+ with self.assertRaisesRegex(grpc.RpcError, 'No client subject dn found') as cm:
+ # No ssl header
+ list(stub.FakeStreamStream(generate_request()))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+ with self.assertRaisesRegex(grpc.RpcError, 'Invalid subject dn') as cm:
+ stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(),
+ metadata=[('ssl-client-subject-dn', 'invalid subject dn')])
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+ with self.assertRaisesRegex(grpc.RpcError, 'Invalid domain name') as cm:
+ stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(),
+ metadata=[('ssl-client-subject-dn',
+ 'CN=test.net,OU=security,O=security,L=beijing,ST=beijing,C=CN')])
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+
+ def test_verify_project(self):
+ valid_subject_dn = 'CN=test.fedlearner.net,OU=security,O=security,L=beijing,ST=beijing,C=CN'
+ with db.session_scope() as session:
+ project = Project(id=123, name='test-project')
+ participant = Participant(
+ id=666,
+ name='test-participant',
+ domain_name='fl-test.com',
+ host='127.0.0.1',
+ port=32443,
+ )
+ relationship = ProjectParticipant(project_id=project.id, participant_id=participant.id)
+ session.add_all([project, participant, relationship])
+ session.commit()
+ stub = self.set_up_client(is_project_based=True)
+
+ # Valid request
+ resp = stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(),
+ metadata=[('ssl-client-subject-dn', valid_subject_dn),
+ ('project-name', 'test-project')])
+ self.assertIsNotNone(resp)
+
+ # No project name
+ with self.assertRaisesRegex(grpc.RpcError, 'No project name found') as cm:
+ stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(), metadata=[('ssl-client-subject-dn', valid_subject_dn)])
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+ # Invalid project
+ with self.assertRaisesRegex(grpc.RpcError, 'Invalid project hhh-project') as cm:
+ stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(),
+ metadata=[('ssl-client-subject-dn', valid_subject_dn), ('project-name', 'hhh-project')])
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+ # No access
+ with self.assertRaisesRegex(grpc.RpcError, 'No access to test-project') as cm:
+ stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(),
+ metadata=[
+ ('ssl-client-subject-dn',
+ 'CN=another.fedlearner.net,OU=security,O=security,L=beijing,ST=beijing,C=CN'),
+ ('project-name', 'test-project')
+ ])
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.UNAUTHENTICATED)
+
+ def test_verify_project_unicode(self):
+ valid_subject_dn = 'CN=test.fedlearner.net,OU=security,O=security,L=beijing,ST=beijing,C=CN'
+ with db.session_scope() as session:
+ project = Project(id=123, name='测试工作区')
+ participant = Participant(
+ id=666,
+ name='test-participant',
+ domain_name='fl-test.com',
+ host='127.0.0.1',
+ port=32443,
+ )
+ relationship = ProjectParticipant(project_id=project.id, participant_id=participant.id)
+ session.add_all([project, participant, relationship])
+ session.commit()
+ stub = self.set_up_client(is_project_based=True)
+
+ # Valid request
+ resp = stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest(),
+ metadata=[('ssl-client-subject-dn', valid_subject_dn),
+ ('project-name', '5rWL6K+V5bel5L2c5Yy6')])
+ self.assertIsNotNone(resp)
+
+ def test_skip(self):
+ stub = self.set_up_client(skip=True)
+ # No auth related info
+ resp = stub.FakeUnaryUnary(request=FakeUnaryUnaryRequest())
+ self.assertIsNotNone(resp)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/client_base.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/client_base.py
new file mode 100644
index 000000000..942f628c1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/client_base.py
@@ -0,0 +1,100 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC
+from typing import Optional
+
+import grpc
+from envs import Envs
+
+from fedlearner_webconsole.rpc.v2.auth_client_interceptor import AuthClientInterceptor
+from fedlearner_webconsole.utils.decorators.lru_cache import lru_cache
+
+
+@lru_cache(timeout=60, maxsize=100)
+def build_grpc_channel(nginx_controller_url: str,
+ peer_domain_name: str,
+ project_name: Optional[str] = None) -> grpc.Channel:
+ """A helper function to build gRPC channel with cache.
+
+ Notice that as we cache the channel, if nginx controller gets restarted, the channel may break.
+ This practice is following official best practice: https://grpc.io/docs/guides/performance/
+
+ Args:
+ nginx_controller_url: Nginx controller url in current cluster,
+ e.g. fedlearner-stack-ingress-nginx-controller.default.svc:80
+ peer_domain_name: Domain name of the peer which we want to connect to, e.g. fl-test.com
+ project_name: Project name which the client works on.
+
+ Returns:
+ A grpc service channel to construct grpc clients.
+ """
+ # Authority is used to route the traffic out of the cluster; specifically, it will look like fl-test-client-auth.com
+ domain_name_prefix = peer_domain_name.rpartition('.')[0]
+ authority = f'{domain_name_prefix}-client-auth.com'
+
+ channel = grpc.insecure_channel(
+ target=nginx_controller_url,
+ # options defined at
+ # https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h
+ options=[('grpc.default_authority', authority)])
+
+ x_host = f'fedlearner-webconsole-v2.{peer_domain_name}'
+ # Adds auth client interceptor to auto-populate auth related headers
+ channel = grpc.intercept_channel(channel, AuthClientInterceptor(x_host=x_host, project_name=project_name))
+ return channel
+
+
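+# Illustrative usage of build_grpc_channel (the domain and project names are examples only):
+#
+#   channel = build_grpc_channel('fedlearner-stack-ingress-nginx-controller.default.svc:80',
+#                                'fl-test.com', project_name='test-project')
+#   # A second call with the same arguments within the cache window returns the
+#   # same channel instance, thanks to the lru_cache decorator above.
+
+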
+def get_nginx_controller_url() -> str:
+ """Generates nginx controller url in current cluster.
+
+ Basically our gRPC client talks to the nginx controller.
+ """
+ if Envs.DEBUG and Envs.GRPC_SERVER_URL is not None:
+ return Envs.GRPC_SERVER_URL
+ return 'fedlearner-stack-ingress-nginx-controller.default.svc:80'
+
+
+class ParticipantRpcClient(ABC):
+ """Abstract class for clients which only work on participant system level, e.g. system service to check health.
+ """
+
+ def __init__(self, channel: grpc.Channel):
+ pass
+
+ @classmethod
+ def from_participant(cls, domain_name: str):
+ channel = build_grpc_channel(
+ get_nginx_controller_url(),
+ domain_name,
+ )
+ return cls(channel)
+
+
+class ParticipantProjectRpcClient(ABC):
+ """Abstract class for clients which work on participant's project level, e.g. model service to train/eval.
+ """
+
+ def __init__(self, channel: grpc.Channel):
+ pass
+
+ @classmethod
+ def from_project_and_participant(cls, domain_name: str, project_name: str):
+ channel = build_grpc_channel(
+ get_nginx_controller_url(),
+ domain_name,
+ project_name,
+ )
+ return cls(channel)
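+
+
+# A hypothetical subclass, sketched for illustration only (FooServiceClient and its stub are assumed names):
+#
+#   class FooServiceClient(ParticipantRpcClient):
+#       def __init__(self, channel: grpc.Channel):
+#           super().__init__(channel)
+#           self._stub = FooServiceStub(channel)
+#
+#   client = FooServiceClient.from_participant(domain_name='fl-test.com')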
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/client_base_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/client_base_test.py
new file mode 100644
index 000000000..894792eaa
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/client_base_test.py
@@ -0,0 +1,155 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, Mock, call, patch
+import grpc
+import grpc_testing
+
+from fedlearner_webconsole.proto.testing import service_pb2
+from fedlearner_webconsole.rpc.v2.client_base import (ParticipantProjectRpcClient, ParticipantRpcClient,
+ build_grpc_channel, get_nginx_controller_url)
+from testing.fake_time_patcher import FakeTimePatcher
+from testing.rpc.client import RpcClientTestCase
+
+
+class GetNginxControllerUrlTest(unittest.TestCase):
+
+ def test_prod(self):
+ self.assertEqual(
+ get_nginx_controller_url(),
+ 'fedlearner-stack-ingress-nginx-controller.default.svc:80',
+ )
+
+ @patch('envs.Envs.DEBUG', 'True')
+ @patch('envs.Envs.GRPC_SERVER_URL', 'xxx.default.svc:443')
+ def test_custom_url(self):
+ self.assertEqual(
+ get_nginx_controller_url(),
+ 'xxx.default.svc:443',
+ )
+
+
+class BuildGrpcChannelTest(RpcClientTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._insecure_channel_patcher = patch('fedlearner_webconsole.rpc.v2.client_base.grpc.insecure_channel')
+ self._mock_insecure_channel: Mock = self._insecure_channel_patcher.start()
+ self._mock_insecure_channel.return_value = grpc_testing.channel(
+ service_pb2.DESCRIPTOR.services_by_name.values(), grpc_testing.strict_real_time())
+
+ def tearDown(self):
+ self._insecure_channel_patcher.stop()
+ super().tearDown()
+
+ @patch('fedlearner_webconsole.rpc.v2.client_base.AuthClientInterceptor', spec=grpc.UnaryUnaryClientInterceptor)
+ def test_build_same_channel(self, mock_auth_client_interceptor: Mock):
+ fake_timer = FakeTimePatcher()
+ fake_timer.start()
+ nginx_controller_url = 'test-nginx.default.svc:80'
+ channel1 = build_grpc_channel(nginx_controller_url, 'fl-test1.com')
+ # Within 60s
+ channel2 = build_grpc_channel(nginx_controller_url, 'fl-test1.com')
+ # Checks if it is the same instance
+ self.assertTrue(channel1 is channel2)
+ self._mock_insecure_channel.assert_called_once_with(
+ target=nginx_controller_url,
+ options=[('grpc.default_authority', 'fl-test1-client-auth.com')],
+ )
+ mock_auth_client_interceptor.assert_called_once_with(
+ x_host='fedlearner-webconsole-v2.fl-test1.com',
+ project_name=None,
+ )
+
+ # Ticks 62 seconds to timeout
+ fake_timer.interrupt(62)
+ channel3 = build_grpc_channel(nginx_controller_url, 'fl-test1.com')
+ self.assertTrue(channel3 is not channel1)
+ self.assertEqual(self._mock_insecure_channel.call_count, 2)
+ self.assertEqual(mock_auth_client_interceptor.call_count, 2)
+
+ @patch('fedlearner_webconsole.rpc.v2.client_base.AuthClientInterceptor', spec=grpc.UnaryUnaryClientInterceptor)
+ def test_build_different_channels(self, mock_auth_client_interceptor: Mock):
+ nginx_controller_url = 'test.default.svc:80'
+ channel1 = build_grpc_channel(nginx_controller_url, 'fl-test1.com')
+ channel2 = build_grpc_channel(nginx_controller_url, 'fl-test1.com', project_name='test-project')
+ self.assertTrue(channel1 is not channel2)
+
+ self.assertEqual(self._mock_insecure_channel.call_args_list, [
+ call(
+ target=nginx_controller_url,
+ options=[('grpc.default_authority', 'fl-test1-client-auth.com')],
+ ),
+ call(
+ target=nginx_controller_url,
+ options=[('grpc.default_authority', 'fl-test1-client-auth.com')],
+ ),
+ ])
+ self.assertEqual(mock_auth_client_interceptor.call_args_list, [
+ call(
+ x_host='fedlearner-webconsole-v2.fl-test1.com',
+ project_name=None,
+ ),
+ call(
+ x_host='fedlearner-webconsole-v2.fl-test1.com',
+ project_name='test-project',
+ ),
+ ])
+
+
+class _FakeRpcClient(ParticipantRpcClient, ParticipantProjectRpcClient):
+
+ def __init__(self, channel):
+ super().__init__(channel)
+ self.channel = channel
+
+
+class ParticipantRpcClientTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.rpc.v2.client_base.build_grpc_channel')
+ def test_from_participant(self, mock_build_grpc_channel: Mock):
+ fake_channel = MagicMock()
+ mock_build_grpc_channel.return_value = fake_channel
+
+ domain_name = 'fl-test.com'
+ fake_client = _FakeRpcClient.from_participant(domain_name=domain_name)
+ self.assertTrue(fake_client.channel is fake_channel)
+ mock_build_grpc_channel.assert_called_once_with(
+ 'fedlearner-stack-ingress-nginx-controller.default.svc:80',
+ domain_name,
+ )
+
+
+class ParticipantProjectRpcClientTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.rpc.v2.client_base.build_grpc_channel')
+ def test_from_project_and_participant(self, mock_build_grpc_channel: Mock):
+ fake_channel = MagicMock()
+ mock_build_grpc_channel.return_value = fake_channel
+
+ domain_name = 'fl-test.com'
+ project_name = 'test-prrrr'
+ fake_client = _FakeRpcClient.from_project_and_participant(domain_name, project_name)
+ self.assertTrue(fake_client.channel is fake_channel)
+ mock_build_grpc_channel.assert_called_once_with(
+ 'fedlearner-stack-ingress-nginx-controller.default.svc:80',
+ domain_name,
+ project_name,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_client.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_client.py
new file mode 100644
index 000000000..64e414add
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_client.py
@@ -0,0 +1,174 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from datetime import datetime
+from google.protobuf import empty_pb2
+from typing import Optional
+
+from envs import Envs
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2_grpc import JobServiceStub
+from fedlearner_webconsole.rpc.v2.client_base import ParticipantProjectRpcClient
+from fedlearner_webconsole.mmgr.models import ModelJobType, AlgorithmType, GroupAutoUpdateStatus
+from fedlearner_webconsole.dataset.models import DatasetJobSchedulerState
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, AlgorithmProjectList, ModelJobPb, ModelJobGroupPb
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import CreateModelJobRequest, InformTrustedJobGroupRequest, \
+ UpdateTrustedJobGroupRequest, DeleteTrustedJobGroupRequest, GetTrustedJobGroupRequest, \
+ GetTrustedJobGroupResponse, CreateDatasetJobStageRequest, GetDatasetJobStageRequest, GetDatasetJobStageResponse, \
+ CreateModelJobGroupRequest, GetModelJobRequest, GetModelJobGroupRequest, InformModelJobGroupRequest, \
+ InformTrustedJobRequest, GetTrustedJobRequest, GetTrustedJobResponse, CreateTrustedExportJobRequest, \
+ UpdateDatasetJobSchedulerStateRequest, UpdateModelJobGroupRequest, InformModelJobRequest
+
+
+def _need_retry_for_get(err: Exception) -> bool:
+ if not isinstance(err, grpc.RpcError):
+ return False
+ # No need to retry for NOT_FOUND
+ return err.code() != grpc.StatusCode.NOT_FOUND
+
+
+def _need_retry_for_create(err: Exception) -> bool:
+ if not isinstance(err, grpc.RpcError):
+ return False
+ # No need to retry for INVALID_ARGUMENT
+ return err.code() != grpc.StatusCode.INVALID_ARGUMENT
+
+
+def _default_need_retry(err: Exception) -> bool:
+ return isinstance(err, grpc.RpcError)
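+
+# Illustrative effect of these predicates: a call failing with UNAVAILABLE is retried
+# (up to the retry_times configured on the methods below), while a Get* call failing with
+# NOT_FOUND or a Create* call failing with INVALID_ARGUMENT fails fast without retries.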
+
+
+class JobServiceClient(ParticipantProjectRpcClient):
+
+ def __init__(self, channel: grpc.Channel):
+ super().__init__(channel)
+ self._stub: JobServiceStub = JobServiceStub(channel)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def inform_trusted_job_group(self, uuid: str, auth_status: AuthStatus) -> empty_pb2.Empty:
+ msg = InformTrustedJobGroupRequest(uuid=uuid, auth_status=auth_status.name)
+ return self._stub.InformTrustedJobGroup(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def update_trusted_job_group(self, uuid: str, algorithm_uuid: str) -> empty_pb2.Empty:
+ msg = UpdateTrustedJobGroupRequest(uuid=uuid, algorithm_uuid=algorithm_uuid)
+ return self._stub.UpdateTrustedJobGroup(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def delete_trusted_job_group(self, uuid: str) -> empty_pb2.Empty:
+ msg = DeleteTrustedJobGroupRequest(uuid=uuid)
+ return self._stub.DeleteTrustedJobGroup(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_trusted_job_group(self, uuid: str) -> GetTrustedJobGroupResponse:
+ msg = GetTrustedJobGroupRequest(uuid=uuid)
+ return self._stub.GetTrustedJobGroup(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_create)
+ def create_trusted_export_job(self, uuid: str, name: str, export_count: int, parent_uuid: str,
+ ticket_uuid: str) -> empty_pb2.Empty:
+ msg = CreateTrustedExportJobRequest(uuid=uuid,
+ name=name,
+ export_count=export_count,
+ parent_uuid=parent_uuid,
+ ticket_uuid=ticket_uuid)
+ return self._stub.CreateTrustedExportJob(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_model_job(self, uuid: str) -> ModelJobPb:
+ return self._stub.GetModelJob(request=GetModelJobRequest(uuid=uuid), timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def create_model_job(self, name: str, uuid: str, group_uuid: str, model_job_type: ModelJobType,
+ algorithm_type: AlgorithmType, global_config: ModelJobGlobalConfig,
+ version: int) -> empty_pb2.Empty:
+ request = CreateModelJobRequest(name=name,
+ uuid=uuid,
+ group_uuid=group_uuid,
+ model_job_type=model_job_type.name,
+ algorithm_type=algorithm_type.name,
+ global_config=global_config,
+ version=version)
+ return self._stub.CreateModelJob(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def inform_model_job(self, uuid: str, auth_status: AuthStatus) -> empty_pb2.Empty:
+ msg = InformModelJobRequest(uuid=uuid, auth_status=auth_status.name)
+ return self._stub.InformModelJob(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_model_job_group(self, uuid: str) -> ModelJobGroupPb:
+ return self._stub.GetModelJobGroup(request=GetModelJobGroupRequest(uuid=uuid), timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def inform_model_job_group(self, uuid: str, auth_status: AuthStatus) -> empty_pb2.Empty:
+ msg = InformModelJobGroupRequest(uuid=uuid, auth_status=auth_status.name)
+ return self._stub.InformModelJobGroup(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def update_model_job_group(self,
+ uuid: str,
+ auto_update_status: Optional[GroupAutoUpdateStatus] = None,
+ start_dataset_job_stage_uuid: Optional[str] = None) -> empty_pb2.Empty:
+ msg = UpdateModelJobGroupRequest(uuid=uuid,
+ auto_update_status=auto_update_status.name if auto_update_status else None,
+ start_dataset_job_stage_uuid=start_dataset_job_stage_uuid)
+ return self._stub.UpdateModelJobGroup(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_create)
+ def create_dataset_job_stage(self,
+ dataset_job_uuid: str,
+ dataset_job_stage_uuid: str,
+ name: str,
+ event_time: Optional[datetime] = None) -> empty_pb2.Empty:
+ request = CreateDatasetJobStageRequest(dataset_job_uuid=dataset_job_uuid,
+ dataset_job_stage_uuid=dataset_job_stage_uuid,
+ name=name,
+ event_time=to_timestamp(event_time) if event_time else None)
+ return self._stub.CreateDatasetJobStage(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_dataset_job_stage(self, dataset_job_stage_uuid: str) -> GetDatasetJobStageResponse:
+ msg = GetDatasetJobStageRequest(dataset_job_stage_uuid=dataset_job_stage_uuid)
+ return self._stub.GetDatasetJobStage(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_create)
+ def update_dataset_job_scheduler_state(self, uuid: str,
+ scheduler_state: DatasetJobSchedulerState) -> empty_pb2.Empty:
+ request = UpdateDatasetJobSchedulerStateRequest(uuid=uuid, scheduler_state=scheduler_state.name)
+ return self._stub.UpdateDatasetJobSchedulerState(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_create)
+ def create_model_job_group(self, name: str, uuid: str, algorithm_type: AlgorithmType, dataset_uuid: str,
+ algorithm_project_list: AlgorithmProjectList):
+ request = CreateModelJobGroupRequest(name=name,
+ uuid=uuid,
+ algorithm_type=algorithm_type.name,
+ dataset_uuid=dataset_uuid,
+ algorithm_project_list=algorithm_project_list)
+ return self._stub.CreateModelJobGroup(request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def inform_trusted_job(self, uuid: str, auth_status: AuthStatus) -> empty_pb2.Empty:
+ msg = InformTrustedJobRequest(uuid=uuid, auth_status=auth_status.name)
+ return self._stub.InformTrustedJob(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_trusted_job(self, uuid: str) -> GetTrustedJobResponse:
+ msg = GetTrustedJobRequest(uuid=uuid)
+ return self._stub.GetTrustedJob(request=msg, timeout=Envs.GRPC_CLIENT_TIMEOUT)
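+
+
+# Illustrative usage (the domain and project names are examples only):
+#
+#   client = JobServiceClient.from_project_and_participant(domain_name='fl-test.com',
+#                                                          project_name='test-project')
+#   group = client.get_model_job_group(uuid='uuid')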
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_client_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_client_test.py
new file mode 100644
index 000000000..d077dccfa
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_client_test.py
@@ -0,0 +1,428 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+import unittest
+import grpc
+import grpc_testing
+from google.protobuf.empty_pb2 import Empty
+from google.protobuf.descriptor import ServiceDescriptor
+from testing.rpc.client import RpcClientTestCase
+from fedlearner_webconsole.proto.dataset_pb2 import DatasetJobConfig, DatasetJobGlobalConfigs, DatasetJobStage
+from fedlearner_webconsole.proto.rpc.v2 import job_service_pb2
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.rpc.v2.job_service_client import JobServiceClient
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.mmgr.models import ModelJobType, GroupAutoUpdateStatus
+from fedlearner_webconsole.dataset.models import DatasetJobSchedulerState
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, AlgorithmProjectList, ModelJobPb, ModelJobGroupPb
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import CreateModelJobRequest, InformTrustedJobGroupRequest, \
+ UpdateTrustedJobGroupRequest, DeleteTrustedJobGroupRequest, GetTrustedJobGroupRequest, \
+ GetTrustedJobGroupResponse, CreateDatasetJobStageRequest, GetDatasetJobStageRequest, GetDatasetJobStageResponse, \
+ CreateModelJobGroupRequest, GetModelJobRequest, GetModelJobGroupRequest, InformModelJobGroupRequest, \
+ InformTrustedJobRequest, GetTrustedJobRequest, GetTrustedJobResponse, CreateTrustedExportJobRequest, \
+ UpdateDatasetJobSchedulerStateRequest, UpdateModelJobGroupRequest, InformModelJobRequest
+
+_SERVICE_DESCRIPTOR: ServiceDescriptor = job_service_pb2.DESCRIPTOR.services_by_name['JobService']
+
+
+class JobServiceClientTest(RpcClientTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._fake_channel: grpc_testing.Channel = grpc_testing.channel([_SERVICE_DESCRIPTOR],
+ grpc_testing.strict_real_time())
+ self._client = JobServiceClient(self._fake_channel)
+
+ def test_inform_trusted_job_group(self):
+ call = self.client_execution_pool.submit(self._client.inform_trusted_job_group,
+ uuid='uuid',
+ auth_status=AuthStatus.AUTHORIZED)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['InformTrustedJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, InformTrustedJobGroupRequest(uuid='uuid', auth_status='AUTHORIZED'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_update_trusted_job_group(self):
+ call = self.client_execution_pool.submit(self._client.update_trusted_job_group,
+ uuid='uuid',
+ algorithm_uuid='algorithm-uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['UpdateTrustedJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, UpdateTrustedJobGroupRequest(uuid='uuid', algorithm_uuid='algorithm-uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_delete_trusted_job_group(self):
+ call = self.client_execution_pool.submit(self._client.delete_trusted_job_group, uuid='uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['DeleteTrustedJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, DeleteTrustedJobGroupRequest(uuid='uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_trusted_job_group(self):
+ call = self.client_execution_pool.submit(self._client.get_trusted_job_group, uuid='uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['GetTrustedJobGroup'])
+ expected_response = GetTrustedJobGroupResponse(auth_status='AUTHORIZED')
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetTrustedJobGroupRequest(uuid='uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_trusted_job(self):
+ call = self.client_execution_pool.submit(self._client.get_trusted_job, uuid='uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['GetTrustedJob'])
+ expected_response = GetTrustedJobResponse(auth_status=AuthStatus.WITHDRAW.name)
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetTrustedJobRequest(uuid='uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_model_job(self):
+ call = self.client_execution_pool.submit(self._client.get_model_job, uuid='uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['GetModelJob'])
+ expected_response = ModelJobPb(name='name')
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetModelJobRequest(uuid='uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_create_model_job(self):
+ call = self.client_execution_pool.submit(self._client.create_model_job,
+ name='name',
+ uuid='uuid',
+ group_uuid='group_uuid',
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ global_config=ModelJobGlobalConfig(),
+ version=3)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CreateModelJob'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ CreateModelJobRequest(name='name',
+ uuid='uuid',
+ group_uuid='group_uuid',
+ model_job_type='TRAINING',
+ algorithm_type='NN_VERTICAL',
+ global_config=ModelJobGlobalConfig(),
+ version=3))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_inform_model_job(self):
+ call = self.client_execution_pool.submit(self._client.inform_model_job,
+ uuid='uuid',
+ auth_status=AuthStatus.AUTHORIZED)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['InformModelJob'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, InformModelJobRequest(uuid='uuid', auth_status=AuthStatus.AUTHORIZED.name))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_create_dataset_job_stage(self):
+ event_time = datetime(2022, 1, 1)
+ call = self.client_execution_pool.submit(self._client.create_dataset_job_stage,
+ dataset_job_uuid='dataset_job_uuid',
+ dataset_job_stage_uuid='dataset_job_stage_uuid',
+ name='20220101',
+ event_time=event_time)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CreateDatasetJobStage'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ CreateDatasetJobStageRequest(
+ dataset_job_uuid='dataset_job_uuid',
+ dataset_job_stage_uuid='dataset_job_stage_uuid',
+ name='20220101',
+ event_time=to_timestamp(event_time),
+ ))
+ self.assertEqual(call.result(), expected_response)
+
+ # test event_time is None
+ call = self.client_execution_pool.submit(self._client.create_dataset_job_stage,
+ dataset_job_uuid='dataset_job_uuid',
+ dataset_job_stage_uuid='dataset_job_stage_uuid',
+ name='20220101',
+ event_time=None)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CreateDatasetJobStage'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ CreateDatasetJobStageRequest(dataset_job_uuid='dataset_job_uuid',
+ dataset_job_stage_uuid='dataset_job_stage_uuid',
+ name='20220101'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_dataset_job_stage(self):
+ call = self.client_execution_pool.submit(self._client.get_dataset_job_stage,
+ dataset_job_stage_uuid='dataset_job_stage_uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['GetDatasetJobStage'])
+ dataset_job_stage = DatasetJobStage(
+ id=1,
+ uuid='fake stage uuid',
+ name='test_dataset_job_stage',
+ dataset_job_uuid='fake job uuid',
+ global_configs=DatasetJobGlobalConfigs(
+ global_configs={'test_domain': DatasetJobConfig(dataset_uuid='dataset uuid', variables=[])}),
+ workflow_definition=WorkflowDefinition(group_alias='fake template', variables=[], job_definitions=[]))
+ expected_response = GetDatasetJobStageResponse(dataset_job_stage=dataset_job_stage)
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetDatasetJobStageRequest(dataset_job_stage_uuid='dataset_job_stage_uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_update_dataset_job_scheduler_state(self):
+ call = self.client_execution_pool.submit(self._client.update_dataset_job_scheduler_state,
+ uuid='dataset_job_uuid',
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['UpdateDatasetJobSchedulerState'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request,
+ UpdateDatasetJobSchedulerStateRequest(
+ uuid='dataset_job_uuid',
+ scheduler_state='RUNNABLE',
+ ))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_model_job_group(self):
+ call = self.client_execution_pool.submit(self._client.get_model_job_group, uuid='uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['GetModelJobGroup'])
+ expected_response = ModelJobGroupPb(name='12')
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetModelJobGroupRequest(uuid='uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_inform_model_job_group(self):
+ call = self.client_execution_pool.submit(self._client.inform_model_job_group,
+ uuid='uuid',
+ auth_status=AuthStatus.AUTHORIZED)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['InformModelJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, InformModelJobGroupRequest(uuid='uuid', auth_status=AuthStatus.AUTHORIZED.name))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_update_model_job_group(self):
+ call = self.client_execution_pool.submit(self._client.update_model_job_group,
+ uuid='uuid',
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE,
+ start_dataset_job_stage_uuid='stage_uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['UpdateModelJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ UpdateModelJobGroupRequest(uuid='uuid',
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE.name,
+ start_dataset_job_stage_uuid='stage_uuid'))
+ self.assertEqual(call.result(), expected_response)
+ call = self.client_execution_pool.submit(self._client.update_model_job_group,
+ uuid='uuid',
+ auto_update_status=GroupAutoUpdateStatus.STOPPED)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['UpdateModelJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ UpdateModelJobGroupRequest(uuid='uuid',
+ auto_update_status=GroupAutoUpdateStatus.STOPPED.name,
+ start_dataset_job_stage_uuid=None))
+ self.assertEqual(call.result(), expected_response)
+ call = self.client_execution_pool.submit(self._client.update_model_job_group, uuid='uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['UpdateModelJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request, UpdateModelJobGroupRequest(uuid='uuid', auto_update_status=None,
+ start_dataset_job_stage_uuid=None))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_create_model_job_group(self):
+ call = self.client_execution_pool.submit(
+ self._client.create_model_job_group,
+ name='name',
+ uuid='uuid',
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ dataset_uuid='uuid',
+ algorithm_project_list=AlgorithmProjectList(algorithm_projects={'test': 'uuid'}))
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CreateModelJobGroup'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ CreateModelJobGroupRequest(
+ name='name',
+ uuid='uuid',
+ algorithm_type='NN_VERTICAL',
+ dataset_uuid='uuid',
+ algorithm_project_list=AlgorithmProjectList(algorithm_projects={'test': 'uuid'})))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_inform_trusted_job(self):
+ call = self.client_execution_pool.submit(self._client.inform_trusted_job,
+ uuid='uuid',
+ auth_status=AuthStatus.AUTHORIZED)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['InformTrustedJob'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, InformTrustedJobRequest(uuid='uuid', auth_status=AuthStatus.AUTHORIZED.name))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_create_trusted_export_job(self):
+ call = self.client_execution_pool.submit(self._client.create_trusted_export_job,
+ uuid='uuid1',
+ name='V1-domain1-1',
+ export_count=1,
+ parent_uuid='uuid2',
+ ticket_uuid='ticket uuid')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CreateTrustedExportJob'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ CreateTrustedExportJobRequest(uuid='uuid1',
+ name='V1-domain1-1',
+ export_count=1,
+ parent_uuid='uuid2',
+ ticket_uuid='ticket uuid'))
+ self.assertEqual(call.result(), expected_response)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_server.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_server.py
new file mode 100644
index 000000000..096034549
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_server.py
@@ -0,0 +1,395 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import grpc
+from grpc import ServicerContext
+from google.protobuf import empty_pb2
+import sqlalchemy
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.rpc.v2 import job_service_pb2_grpc
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import CreateModelJobRequest, InformTrustedJobGroupRequest, \
+ UpdateTrustedJobGroupRequest, DeleteTrustedJobGroupRequest, GetTrustedJobGroupRequest, \
+ GetTrustedJobGroupResponse, CreateDatasetJobStageRequest, GetDatasetJobStageRequest, GetDatasetJobStageResponse, \
+ CreateModelJobGroupRequest, GetModelJobGroupRequest, GetModelJobRequest, InformModelJobGroupRequest, \
+ InformTrustedJobRequest, GetTrustedJobRequest, GetTrustedJobResponse, CreateTrustedExportJobRequest, \
+ UpdateDatasetJobSchedulerStateRequest, UpdateModelJobGroupRequest, InformModelJobRequest
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobPb, ModelJobGroupPb
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.tee.services import TrustedJobGroupService, TrustedJobService
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob, TrustedJobStatus, TrustedJobType
+from fedlearner_webconsole.mmgr.models import ModelJobRole, ModelJobType, AlgorithmType, ModelJobGroup, ModelJob, \
+ GroupCreateStatus, GroupAutoUpdateStatus
+from fedlearner_webconsole.mmgr.service import ModelJobService, ModelJobGroupService
+from fedlearner_webconsole.rpc.v2.utils import get_grpc_context_info
+from fedlearner_webconsole.dataset.job_configer.dataset_job_configer import DatasetJobConfiger
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobSchedulerState, DatasetJobStage, Dataset
+from fedlearner_webconsole.dataset.services import DatasetService
+from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController
+from fedlearner_webconsole.dataset.services import DatasetJobService
+from fedlearner_webconsole.utils.pp_datetime import from_timestamp
+from fedlearner_webconsole.utils.proto import remove_secrets
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.exceptions import NotFoundException
+from fedlearner_webconsole.audit.decorators import emits_rpc_event
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+
+
+class JobServiceServicer(job_service_pb2_grpc.JobServiceServicer):
+
+ @emits_rpc_event(resource_type=Event.ResourceType.TRUSTED_JOB_GROUP,
+ op_type=Event.OperationType.INFORM,
+ resource_name_fn=lambda request: request.uuid)
+ def InformTrustedJobGroup(self, request: InformTrustedJobGroupRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
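+            # project_id and client_id of the calling participant are resolved from the request metadata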
+ project_id, client_id = get_grpc_context_info(session, context)
+ group: TrustedJobGroup = session.query(TrustedJobGroup).populate_existing().with_for_update().filter_by(
+ project_id=project_id, uuid=request.uuid).first()
+ if group is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'trusted job group {request.uuid} not found')
+ try:
+ auth_status = AuthStatus[request.auth_status]
+ except KeyError:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'auth_status {request.auth_status} is invalid')
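+            # maintain the set of participants that have not yet authorized this group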
+ unauth_set = set(group.get_unauth_participant_ids())
+ if auth_status == AuthStatus.AUTHORIZED:
+ unauth_set.discard(client_id)
+ else:
+ unauth_set.add(client_id)
+ group.set_unauth_participant_ids(list(unauth_set))
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.TRUSTED_JOB_GROUP,
+ op_type=Event.OperationType.UPDATE,
+ resource_name_fn=lambda request: request.uuid)
+ def UpdateTrustedJobGroup(self, request: UpdateTrustedJobGroupRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
+ group: TrustedJobGroup = session.query(TrustedJobGroup).filter_by(project_id=project_id,
+ uuid=request.uuid).first()
+ if group is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'trusted job group {request.uuid} not found')
+ if client_id != group.coordinator_id:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, 'only coordinator can update algorithm')
+ try:
+ algorithm = AlgorithmFetcher(project_id).get_algorithm(request.algorithm_uuid)
+ old_algorithm = AlgorithmFetcher(project_id).get_algorithm(group.algorithm_uuid)
+ if algorithm.algorithm_project_uuid != old_algorithm.algorithm_project_uuid:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, 'upstream algorithm project mismatch')
+ except NotFoundException as e:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, e.message)
+ group.algorithm_uuid = request.algorithm_uuid
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.TRUSTED_JOB_GROUP,
+ op_type=Event.OperationType.DELETE,
+ resource_name_fn=lambda request: request.uuid)
+ def DeleteTrustedJobGroup(self, request: DeleteTrustedJobGroupRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
+ group: TrustedJobGroup = session.query(TrustedJobGroup).filter_by(project_id=project_id,
+ uuid=request.uuid).first()
+ if group is None:
+ return empty_pb2.Empty()
+ if client_id != group.coordinator_id:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, 'only coordinator can delete the trusted job group')
+ if not group.is_deletable():
+                context.abort(grpc.StatusCode.FAILED_PRECONDITION, 'trusted job group is not deletable')
+ TrustedJobGroupService(session).delete(group)
+ session.commit()
+ return empty_pb2.Empty()
+
+ def GetTrustedJobGroup(self, request: GetTrustedJobGroupRequest,
+ context: ServicerContext) -> GetTrustedJobGroupResponse:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ group: TrustedJobGroup = session.query(TrustedJobGroup).filter_by(project_id=project_id,
+ uuid=request.uuid).first()
+ if group is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'trusted job group {request.uuid} not found')
+ return GetTrustedJobGroupResponse(auth_status=group.auth_status.name)
+
+ def GetModelJob(self, request: GetModelJobRequest, context: ServicerContext) -> ModelJobPb:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ model_job: ModelJob = session.query(ModelJob).filter_by(uuid=request.uuid).first()
+ if model_job is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'model job with uuid {request.uuid} is not found')
+ return remove_secrets(model_job.to_proto())
+
+ @emits_rpc_event(resource_type=Event.ResourceType.MODEL_JOB,
+ op_type=Event.OperationType.CREATE,
+ resource_name_fn=lambda request: request.uuid)
+ def CreateModelJob(self, request: CreateModelJobRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
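+            # idempotency: if a model job with this uuid already exists, the request was handled before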
+ if session.query(ModelJob).filter_by(uuid=request.uuid).first() is not None:
+ return empty_pb2.Empty()
+ if session.query(ModelJob).filter_by(name=request.name).first() is not None:
+                context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'model job {request.name} already exists')
+ group = session.query(ModelJobGroup).filter_by(uuid=request.group_uuid).first()
+ if group is None:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'model job group {request.group_uuid} not found')
+ model_job_type = ModelJobType[request.model_job_type]
+ if model_job_type in [ModelJobType.TRAINING] and group.latest_version >= request.version:
+ context.abort(
+ grpc.StatusCode.INVALID_ARGUMENT, f'the latest version of model group {group.name} '
+ f'is larger than or equal to the given version')
+ service = ModelJobService(session)
+ algorithm_type = AlgorithmType[request.algorithm_type]
+ data_batch_id = None
+ dataset_job_stage_uuid = request.global_config.dataset_job_stage_uuid
+ if dataset_job_stage_uuid != '':
+ dataset_job_stage = session.query(DatasetJobStage).filter_by(uuid=dataset_job_stage_uuid).first()
+ data_batch_id = dataset_job_stage.data_batch_id
+ model_job = service.create_model_job(name=request.name,
+ uuid=request.uuid,
+ group_id=group.id,
+ project_id=project_id,
+ role=ModelJobRole.PARTICIPANT,
+ model_job_type=model_job_type,
+ algorithm_type=algorithm_type,
+ coordinator_id=client_id,
+ data_batch_id=data_batch_id,
+ global_config=request.global_config,
+ version=request.version)
+ if model_job_type in [ModelJobType.TRAINING]:
+ group.latest_version = model_job.version
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.MODEL_JOB,
+ op_type=Event.OperationType.INFORM,
+ resource_name_fn=lambda request: request.uuid)
+ def InformModelJob(self, request: InformModelJobRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
+ model_job: ModelJob = session.query(ModelJob).populate_existing().with_for_update().filter_by(
+ project_id=project_id, uuid=request.uuid).first()
+ if model_job is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'model job {request.uuid} is not found')
+ try:
+ auth_status = AuthStatus[request.auth_status]
+ except KeyError:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'auth_status {request.auth_status} is invalid')
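+            # record the caller's auth status under its pure domain name in the participants info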
+ pure_domain_name = session.query(Participant).get(client_id).pure_domain_name()
+ participants_info = model_job.get_participants_info()
+ participants_info.participants_map[pure_domain_name].auth_status = auth_status.name
+ model_job.set_participants_info(participants_info)
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.DATASET_JOB_STAGE,
+ op_type=Event.OperationType.CREATE,
+ resource_name_fn=lambda request: request.dataset_job_uuid)
+ def CreateDatasetJobStage(self, request: CreateDatasetJobStageRequest, context: ServicerContext) -> empty_pb2.Empty:
+ try:
+ with db.session_scope() as session:
+                # set isolation_level to SERIALIZABLE so the state cannot be changed by concurrent transactions during this session
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ _, client_id = get_grpc_context_info(session, context)
+ dataset_job: DatasetJob = session.query(DatasetJob).filter(
+ DatasetJob.uuid == request.dataset_job_uuid).first()
+ if dataset_job is None:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT,
+ f'dataset_job {request.dataset_job_uuid} is not found')
+ if dataset_job.output_dataset is None:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT,
+ f'output dataset is not found, dataset_job uuid: {request.dataset_job_uuid}')
+ # check authorization
+ if not AuthService(session=session, dataset_job=dataset_job).check_local_authorized():
+                    message = '[CreateDatasetJobStage] still waiting for authorization, ' \
+ f'dataset_job_uuid: {request.dataset_job_uuid}'
+ logging.warning(message)
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, message)
+ event_time = from_timestamp(request.event_time) if request.event_time else None
+ # check data_batch ready
+ data_batch = DatasetService(session).get_data_batch(dataset=dataset_job.input_dataset,
+ event_time=event_time)
+ if data_batch is None or not data_batch.is_available():
+ message = '[CreateDatasetJobStage] input_dataset data_batch is not ready, ' \
+                              f'dataset_job uuid: {request.dataset_job_uuid}'
+ logging.warning(message)
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, message)
+ DatasetJobStageLocalController(session=session).create_data_batch_and_job_stage_as_participant(
+ dataset_job_id=dataset_job.id,
+ coordinator_id=client_id,
+ uuid=request.dataset_job_stage_uuid,
+ name=request.name,
+ event_time=event_time)
+ session.commit()
+        except sqlalchemy.exc.OperationalError:
+            # deadlocks under the SERIALIZABLE isolation level surface as OperationalError; skip this request
+            logging.warning('[create dataset job stage rpc]: [SKIP] caught OperationalError in session', exc_info=True)
+ return empty_pb2.Empty()
+
+ def GetDatasetJobStage(self, request: GetDatasetJobStageRequest,
+ context: ServicerContext) -> GetDatasetJobStageResponse:
+ with db.session_scope() as session:
+ dataset_job_stage: DatasetJobStage = session.query(DatasetJobStage).filter(
+ DatasetJobStage.uuid == request.dataset_job_stage_uuid).first()
+ if dataset_job_stage is None:
+ context.abort(code=grpc.StatusCode.NOT_FOUND,
+ details=f'could not find dataset_job_stage {request.dataset_job_stage_uuid}')
+ dataset_job_stage_proto = dataset_job_stage.to_proto()
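+            # attach the workflow definition derived from the dataset job kind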
+ dataset_job_stage_proto.workflow_definition.MergeFrom(
+ DatasetJobConfiger.from_kind(dataset_job_stage.dataset_job.kind, session).get_config())
+ return GetDatasetJobStageResponse(dataset_job_stage=dataset_job_stage_proto)
+
+ def UpdateDatasetJobSchedulerState(self, request: UpdateDatasetJobSchedulerStateRequest,
+ context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ dataset_job: DatasetJob = session.query(DatasetJob).filter(DatasetJob.project_id == project_id).filter(
+ DatasetJob.uuid == request.uuid).first()
+ if dataset_job is None:
+ context.abort(code=grpc.StatusCode.NOT_FOUND, details=f'could not find dataset_job {request.uuid}')
+ if request.scheduler_state == DatasetJobSchedulerState.RUNNABLE.name:
+ DatasetJobService(session=session).start_cron_scheduler(dataset_job=dataset_job)
+ elif request.scheduler_state == DatasetJobSchedulerState.STOPPED.name:
+ DatasetJobService(session=session).stop_cron_scheduler(dataset_job=dataset_job)
+ else:
+ context.abort(code=grpc.StatusCode.INVALID_ARGUMENT,
+                                  details='scheduler state must be in [RUNNABLE, STOPPED]')
+ session.commit()
+ return empty_pb2.Empty()
+
+ def GetModelJobGroup(self, request: GetModelJobGroupRequest, context: ServicerContext) -> ModelJobGroupPb:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ group = session.query(ModelJobGroup).filter_by(uuid=request.uuid).first()
+ if group is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'model job group with uuid {request.uuid} is not found')
+ return remove_secrets(group.to_proto())
+
+ @emits_rpc_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP,
+ op_type=Event.OperationType.INFORM,
+ resource_name_fn=lambda request: request.uuid)
+ def InformModelJobGroup(self, request: InformModelJobGroupRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
+ group: ModelJobGroup = session.query(ModelJobGroup).populate_existing().with_for_update().filter_by(
+ project_id=project_id, uuid=request.uuid).first()
+ if group is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'model job group {request.uuid} is not found')
+ try:
+ auth_status = AuthStatus[request.auth_status]
+ except KeyError:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'auth_status {request.auth_status} is invalid')
+ pure_domain_name = session.query(Participant).get(client_id).pure_domain_name()
+ participants_info = group.get_participants_info()
+ participants_info.participants_map[pure_domain_name].auth_status = auth_status.name
+ group.set_participants_info(participants_info)
+ session.commit()
+ return empty_pb2.Empty()
+
+ def UpdateModelJobGroup(self, request: UpdateModelJobGroupRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ group: ModelJobGroup = session.query(ModelJobGroup).populate_existing().with_for_update().filter_by(
+ project_id=project_id, uuid=request.uuid).first()
+ if group is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'model job group {request.uuid} is not found')
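+            # proto3 leaves unset string fields as '', so only fields that were explicitly provided are updated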
+ if request.auto_update_status != '':
+ try:
+ auto_update_status = GroupAutoUpdateStatus[request.auto_update_status]
+ except KeyError:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT,
+ f'auto_update_status {request.auto_update_status} is invalid')
+ group.auto_update_status = auto_update_status
+ if request.start_dataset_job_stage_uuid != '':
+ dataset_job_stage = session.query(DatasetJobStage).filter_by(
+ uuid=request.start_dataset_job_stage_uuid).first()
+ group.start_data_batch_id = dataset_job_stage.data_batch_id
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.MODEL_JOB_GROUP,
+                     op_type=Event.OperationType.CREATE,
+ resource_name_fn=lambda request: request.uuid)
+ def CreateModelJobGroup(self, request: CreateModelJobGroupRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
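+            # idempotency: if the group was already created by a previous call, return success directly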
+ group = session.query(ModelJobGroup).filter_by(uuid=request.uuid).first()
+ if group is not None:
+ return empty_pb2.Empty()
+ dataset = session.query(Dataset).filter_by(uuid=request.dataset_uuid).first()
+ if dataset is None:
+                context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'dataset with uuid {request.dataset_uuid} is not found')
+ service = ModelJobGroupService(session)
+ algorithm_type = AlgorithmType[request.algorithm_type]
+ group = service.create_group(name=request.name,
+ uuid=request.uuid,
+ project_id=project_id,
+ role=ModelJobRole.PARTICIPANT,
+ dataset_id=dataset.id,
+ algorithm_type=algorithm_type,
+ algorithm_project_list=request.algorithm_project_list,
+ coordinator_id=client_id)
+ group.status = GroupCreateStatus.SUCCEEDED
+ session.add(group)
+ session.commit()
+ return empty_pb2.Empty()
+
+ def InformTrustedJob(self, request: InformTrustedJobRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
+ trusted_job: TrustedJob = session.query(TrustedJob).filter_by(project_id=project_id,
+ uuid=request.uuid).first()
+ if trusted_job is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'trusted job {request.uuid} not found')
+ try:
+ auth_status = AuthStatus[request.auth_status]
+ except KeyError:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'auth_status {request.auth_status} is invalid')
+ pure_domain_name = session.query(Participant).get(client_id).pure_domain_name()
+ participants_info = trusted_job.get_participants_info()
+ participants_info.participants_map[pure_domain_name].auth_status = auth_status.name
+ trusted_job.set_participants_info(participants_info)
+ session.commit()
+ return empty_pb2.Empty()
+
+ def GetTrustedJob(self, request: GetTrustedJobRequest, context: ServicerContext) -> GetTrustedJobResponse:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ trusted_job: TrustedJob = session.query(TrustedJob).filter_by(project_id=project_id,
+ uuid=request.uuid).first()
+ if trusted_job is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'trusted job {request.uuid} not found')
+ return GetTrustedJobResponse(auth_status=trusted_job.auth_status.name)
+
+ def CreateTrustedExportJob(self, request: CreateTrustedExportJobRequest,
+ context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, context)
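+            # the review ticket must reference the same uuid as the export job being created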
+ validate = get_ticket_helper(session).validate_ticket(request.ticket_uuid,
+ lambda ticket: ticket.details.uuid == request.uuid)
+ if not validate:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, f'ticket {request.ticket_uuid} is not validated')
+ tee_analyze_job = session.query(TrustedJob).filter_by(project_id=project_id,
+ type=TrustedJobType.ANALYZE,
+ uuid=request.parent_uuid).first()
+ if tee_analyze_job is None or tee_analyze_job.get_status() != TrustedJobStatus.SUCCEEDED:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'tee_analyze_job {request.parent_uuid} invalid')
+ TrustedJobService(session).create_external_export(request.uuid, request.name, client_id,
+ request.export_count, request.ticket_uuid,
+ tee_analyze_job)
+ session.commit()
+ return empty_pb2.Empty()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_server_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_server_test.py
new file mode 100644
index 000000000..0984fb3da
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/job_service_server_test.py
@@ -0,0 +1,881 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime, timedelta
+import unittest
+from unittest.mock import patch, MagicMock
+import grpc
+from google.protobuf.empty_pb2 import Empty
+from concurrent import futures
+from testing.dataset import FakeDatasetJobConfiger
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.rpc.client import FakeRpcError
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.rpc.auth import SSL_CLIENT_SUBJECT_DN_HEADER, PROJECT_NAME_HEADER
+from fedlearner_webconsole.proto.rpc.v2 import job_service_pb2_grpc
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.rpc.v2.job_service_server import JobServiceServicer
+from fedlearner_webconsole.rpc.v2.utils import get_grpc_context_info
+from fedlearner_webconsole.dataset.models import DataBatch, Dataset, DatasetJob, DatasetJobKind, DatasetJobStage, \
+ DatasetJobState, DatasetKindV2, DatasetType, DatasetJobSchedulerState
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob, TrustedJobStatus, TrustedJobType
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, AlgorithmType
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobGroup, ModelJobType, ModelJobRole, ModelJobStatus, \
+ AuthStatus as ModelAuthStatus, GroupAutoUpdateStatus
+from fedlearner_webconsole.job.models import Job, JobType
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.mmgr_pb2 import ModelJobGlobalConfig, ModelJobConfig, AlgorithmProjectList, \
+ ModelJobPb, ModelJobGroupPb
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import CreateModelJobRequest, InformTrustedJobGroupRequest, \
+ UpdateTrustedJobGroupRequest, DeleteTrustedJobGroupRequest, GetTrustedJobGroupRequest, \
+ GetTrustedJobGroupResponse, CreateDatasetJobStageRequest, GetDatasetJobStageRequest, CreateModelJobGroupRequest, \
+ GetModelJobRequest, GetModelJobGroupRequest, InformModelJobGroupRequest, InformTrustedJobRequest, \
+ GetTrustedJobRequest, GetTrustedJobResponse, CreateTrustedExportJobRequest, UpdateDatasetJobSchedulerStateRequest, \
+ UpdateModelJobGroupRequest, InformModelJobRequest
+from fedlearner_webconsole.review.common import NO_CENTRAL_SERVER_UUID
+
+
+class FakeContext:
+
+ def __init__(self, metadata):
+ self._metadata = metadata
+
+ def invocation_metadata(self):
+ return self._metadata
+
+
+class GetGrpcContextInfoTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='proj-name')
+ participant = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ session.add_all([project, participant])
+ session.commit()
+
+ def test_get_grpc_context_info(self):
+ metadata = ((SSL_CLIENT_SUBJECT_DN_HEADER,
+ 'CN=domain2.fedlearner.net,OU=security,O=security,L=beijing,ST=beijing,C=CN'),
+ (PROJECT_NAME_HEADER, 'proj-name'))
+        # the interceptor has already validated the metadata, so only the happy path is tested here
+ with db.session_scope() as session:
+ project_id, client_id = get_grpc_context_info(session, FakeContext(metadata))
+ self.assertEqual(project_id, 1)
+ self.assertEqual(client_id, 1)
+
+
+class JobServiceServicerTest(NoWebServerTestCase):
+ LISTEN_PORT = 2000
+
+ def setUp(self):
+ super().setUp()
+ self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=20))
+ job_service_pb2_grpc.add_JobServiceServicer_to_server(JobServiceServicer(), self._server)
+ self._server.add_insecure_port(f'[::]:{self.LISTEN_PORT}')
+ self._server.start()
+ self._channel = grpc.insecure_channel(target=f'localhost:{self.LISTEN_PORT}')
+ self._stub = job_service_pb2_grpc.JobServiceStub(self._channel)
+
+ def tearDown(self):
+ self._channel.close()
+ self._server.stop(5)
+ return super().tearDown()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_inform_trusted_job_group(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
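+        # the mocked context resolves to project_id=1 and client_id=1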
+ with db.session_scope() as session:
+ project = Project(id=1, name='proj-name')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ participant2 = Participant(id=2, name='part3', domain_name='fl-domain3.com')
+ group = TrustedJobGroup(id=1, name='group', uuid='uuid', project_id=1, coordinator_id=0)
+ group.set_unauth_participant_ids([1, 2])
+ session.add_all([project, participant1, participant2, group])
+ session.commit()
+ # authorize
+ self._stub.InformTrustedJobGroup(InformTrustedJobGroupRequest(uuid='uuid', auth_status='AUTHORIZED'))
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).filter_by(uuid='uuid').first()
+ self.assertCountEqual(group.get_unauth_participant_ids(), [2])
+ # pend
+ self._stub.InformTrustedJobGroup(InformTrustedJobGroupRequest(uuid='uuid', auth_status='PENDING'))
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).filter_by(uuid='uuid').first()
+ self.assertCountEqual(group.get_unauth_participant_ids(), [1, 2])
+ # fail due to group uuid not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformTrustedJobGroup(InformTrustedJobGroupRequest(uuid='not-exist', auth_status='AUTHORIZED'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ # fail due to invalid auth status
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformTrustedJobGroup(InformTrustedJobGroupRequest(uuid='uuid', auth_status='AUTHORIZE'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+
+ @patch('fedlearner_webconsole.algorithm.fetcher.AlgorithmFetcher.get_algorithm_from_participant')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_update_trusted_job_group(self, mock_get_grpc_context_info: MagicMock, mock_get_algorithm: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ mock_get_algorithm.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'not found')
+ with db.session_scope() as session:
+ project = Project(id=1, name='proj-name')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ algorithm_proj1 = AlgorithmProject(id=1, uuid='algorithm-proj-uuid1')
+ algorithm_proj2 = AlgorithmProject(id=2, uuid='algorithm-proj-uuid2')
+ algorithm1 = Algorithm(id=1, algorithm_project_id=1, uuid='algorithm-uuid1')
+ algorithm2 = Algorithm(id=2, algorithm_project_id=1, uuid='algorithm-uuid2')
+ algorithm3 = Algorithm(id=3, algorithm_project_id=2, uuid='algorithm-uuid3')
+ group1 = TrustedJobGroup(id=1,
+ name='group1',
+ uuid='uuid1',
+ project_id=1,
+ algorithm_uuid='algorithm-uuid1',
+ coordinator_id=1)
+ group2 = TrustedJobGroup(id=2,
+ name='group2',
+ uuid='uuid2',
+ project_id=1,
+ algorithm_uuid='algorithm-uuid1',
+ coordinator_id=0)
+ session.add_all([
+ project, participant1, algorithm_proj1, algorithm_proj2, algorithm1, algorithm2, algorithm3, group1,
+ group2
+ ])
+ session.commit()
+ self._stub.UpdateTrustedJobGroup(UpdateTrustedJobGroupRequest(uuid='uuid1', algorithm_uuid='algorithm-uuid2'))
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).filter_by(uuid='uuid1').first()
+ self.assertEqual(group.algorithm_uuid, 'algorithm-uuid2')
+ # fail due to group uuid not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.UpdateTrustedJobGroup(
+ UpdateTrustedJobGroupRequest(uuid='not-exist', algorithm_uuid='algorithm-uuid2'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ # fail due to client not coordinator
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.UpdateTrustedJobGroup(
+ UpdateTrustedJobGroupRequest(uuid='uuid2', algorithm_uuid='algorithm-uuid2'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+ # fail due to algorithm not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.UpdateTrustedJobGroup(
+ UpdateTrustedJobGroupRequest(uuid='uuid1', algorithm_uuid='algorithm-not-exist'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ # fail due to mismatched algorithm project
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.UpdateTrustedJobGroup(
+ UpdateTrustedJobGroupRequest(uuid='uuid1', algorithm_uuid='algorithm-uuid3'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_delete_trusted_job_group(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ project = Project(id=1, name='proj-name')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ group1 = TrustedJobGroup(id=1, uuid='uuid1', project_id=1, coordinator_id=1)
+ group2 = TrustedJobGroup(id=2, uuid='uuid2', project_id=1, coordinator_id=0)
+ trusted_job1 = TrustedJob(id=1,
+ name='V1',
+ trusted_job_group_id=1,
+ job_id=1,
+ status=TrustedJobStatus.RUNNING)
+ job1 = Job(id=1, name='job-name1', job_type=JobType.CUSTOMIZED, workflow_id=0, project_id=1)
+ trusted_job2 = TrustedJob(id=2,
+ name='V2',
+ trusted_job_group_id=1,
+ job_id=2,
+ status=TrustedJobStatus.SUCCEEDED)
+ job2 = Job(id=2, name='job-name2', job_type=JobType.CUSTOMIZED, workflow_id=0, project_id=1)
+ session.add_all([project, participant1, group1, group2, trusted_job1, job1, trusted_job2, job2])
+ session.commit()
+ # fail due to client is not coordinator
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.DeleteTrustedJobGroup(DeleteTrustedJobGroupRequest(uuid='uuid2'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+ # delete group not exist
+ resp = self._stub.DeleteTrustedJobGroup(DeleteTrustedJobGroupRequest(uuid='not-exist'))
+ self.assertEqual(resp, Empty())
+ # fail due to trusted job is still running
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.DeleteTrustedJobGroup(DeleteTrustedJobGroupRequest(uuid='uuid1'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.FAILED_PRECONDITION)
+ # successful
+ with db.session_scope() as session:
+ trusted_job1 = session.query(TrustedJob).get(1)
+ trusted_job1.status = TrustedJobStatus.FAILED
+ session.commit()
+ self._stub.DeleteTrustedJobGroup(DeleteTrustedJobGroupRequest(uuid='uuid1'))
+ with db.session_scope() as session:
+ self.assertIsNone(session.query(TrustedJobGroup).get(1))
+ self.assertIsNone(session.query(TrustedJob).get(1))
+ self.assertIsNone(session.query(TrustedJob).get(2))
+ self.assertIsNone(session.query(Job).get(1))
+ self.assertIsNone(session.query(Job).get(2))
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_get_trusted_job_group(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ group = TrustedJobGroup(id=1, name='group', uuid='uuid', project_id=1, auth_status=AuthStatus.AUTHORIZED)
+ session.add_all([group])
+ session.commit()
+ resp = self._stub.GetTrustedJobGroup(GetTrustedJobGroupRequest(uuid='uuid'))
+ self.assertEqual(resp, GetTrustedJobGroupResponse(auth_status='AUTHORIZED'))
+ # fail due to not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.GetTrustedJobGroup(GetTrustedJobGroupRequest(uuid='uuid-not-exist'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_create_trusted_export_job(self, mock_get_grpc_context_info: MagicMock, mock_get_system_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ participant2 = Participant(id=2, name='part3', domain_name='fl-domain3.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ proj_part2 = ProjectParticipant(project_id=1, participant_id=2)
+ tee_analyze_job = TrustedJob(id=1,
+ uuid='uuid1',
+ type=TrustedJobType.ANALYZE,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=2,
+ status=TrustedJobStatus.SUCCEEDED)
+ session.add_all([project, participant1, participant2, proj_part1, proj_part2, tee_analyze_job])
+ session.commit()
+ # successful
+ req = CreateTrustedExportJobRequest(uuid='uuid2',
+ name='V1-domain2-1',
+ export_count=1,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID,
+ parent_uuid='uuid1')
+ self._stub.CreateTrustedExportJob(req)
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).filter_by(uuid='uuid2').first()
+ self.assertEqual(tee_export_job.name, 'V1-domain2-1')
+ self.assertEqual(tee_export_job.type, TrustedJobType.EXPORT)
+ self.assertEqual(tee_export_job.export_count, 1)
+ self.assertEqual(tee_export_job.project_id, 1)
+ self.assertEqual(tee_export_job.trusted_job_group_id, 1)
+ self.assertEqual(tee_export_job.status, TrustedJobStatus.CREATED)
+ self.assertEqual(tee_export_job.auth_status, AuthStatus.PENDING)
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['domain1'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['domain2'].auth_status = AuthStatus.AUTHORIZED.name
+ participants_info.participants_map['domain3'].auth_status = AuthStatus.PENDING.name
+ self.assertEqual(tee_export_job.get_participants_info(), participants_info)
+ # failed due to tee_analyze_job not valid
+ with self.assertRaises(grpc.RpcError) as cm:
+ req.parent_uuid = 'not-exist'
+ self._stub.CreateTrustedExportJob(req)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ # failed due to ticket invalid
+ with self.assertRaises(grpc.RpcError) as cm:
+ req.ticket_uuid = 'invalid ticket'
+ self._stub.CreateTrustedExportJob(req)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+
+ @patch('fedlearner_webconsole.mmgr.service.ModelJobService.create_model_job')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_create_model_job(self, mock_get_grpc_context_info: MagicMock, mock_create_model_job: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ # fail due to group not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.CreateModelJob(CreateModelJobRequest(group_uuid='uuid-not-exist'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ with db.session_scope() as session:
+ session.add(ModelJobGroup(id=2, name='name', uuid='group_uuid', project_id=1, latest_version=2))
+ session.add(ModelJob(id=1, name='model-job', project_id=1))
+ session.commit()
+ # fail due to model job name already exists
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.CreateModelJob(CreateModelJobRequest(group_uuid='group_uuid', uuid='uuid', name='model-job'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ # fail due to the model job version not larger than group's latest version
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.CreateModelJob(
+ CreateModelJobRequest(group_uuid='group_uuid',
+ name='name',
+ version=2,
+ model_job_type=ModelJobType.TRAINING.name,
+ algorithm_type=AlgorithmType.NN_VERTICAL.name))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ # create training model job successfully
+ mock_create_model_job.return_value = ModelJob(name='haha', uuid='uuid', version=3)
+ global_config = ModelJobGlobalConfig(global_config={'test': ModelJobConfig()})
+ self._stub.CreateModelJob(
+ CreateModelJobRequest(name='name',
+ uuid='uuid',
+ group_uuid='group_uuid',
+ model_job_type='TRAINING',
+ algorithm_type='NN_VERTICAL',
+ global_config=global_config,
+ version=3))
+ mock_create_model_job.assert_called_with(name='name',
+ uuid='uuid',
+ group_id=2,
+ project_id=1,
+ role=ModelJobRole.PARTICIPANT,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ coordinator_id=1,
+ data_batch_id=None,
+ global_config=global_config,
+ version=3)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).filter_by(name='name').first()
+ self.assertEqual(group.latest_version, 3)
+ # create evaluation model job successfully
+ mock_create_model_job.return_value = ModelJob(name='haha', uuid='uuid', version=None)
+ global_config = ModelJobGlobalConfig(global_config={'test': ModelJobConfig()})
+ self._stub.CreateModelJob(
+ CreateModelJobRequest(name='name',
+ uuid='uuid',
+ group_uuid='group_uuid',
+ model_job_type='EVALUATION',
+ algorithm_type='NN_VERTICAL',
+ global_config=global_config,
+ version=0))
+ mock_create_model_job.assert_called_with(name='name',
+ uuid='uuid',
+ group_id=2,
+ project_id=1,
+ role=ModelJobRole.PARTICIPANT,
+ model_job_type=ModelJobType.EVALUATION,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ coordinator_id=1,
+ data_batch_id=None,
+ global_config=global_config,
+ version=0)
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).filter_by(name='name').first()
+ self.assertEqual(group.latest_version, 3)
+ # create auto update model job
+ with db.session_scope() as session:
+ data_batch = DataBatch(id=1,
+ name='0',
+ dataset_id=1,
+ path='/test_dataset/1/batch/0',
+ event_time=datetime(2021, 10, 28, 16, 37, 37))
+ dataset_job_stage = DatasetJobStage(id=1,
+ name='data-join',
+ uuid='dataset-job-stage-uuid',
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ dataset_job_id=1,
+ data_batch_id=1)
+ session.add_all([data_batch, dataset_job_stage])
+ session.commit()
+ global_config = ModelJobGlobalConfig(dataset_job_stage_uuid='dataset-job-stage-uuid')
+ self._stub.CreateModelJob(
+ CreateModelJobRequest(name='name',
+ uuid='uuid',
+ group_uuid='group_uuid',
+ model_job_type='TRAINING',
+ algorithm_type='NN_VERTICAL',
+ global_config=global_config,
+ version=4))
+ mock_create_model_job.assert_called_with(name='name',
+ uuid='uuid',
+ group_id=2,
+ project_id=1,
+ role=ModelJobRole.PARTICIPANT,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ coordinator_id=1,
+ data_batch_id=1,
+ global_config=global_config,
+ version=4)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_inform_model_job(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant = Participant(id=1, name='part1', domain_name='fl-demo1.com')
+ pro_part = ProjectParticipant(id=1, project_id=1, participant_id=1)
+ model_job = ModelJob(id=1,
+ name='model_job',
+ uuid='uuid',
+ project_id=1,
+ auth_status=ModelAuthStatus.AUTHORIZED)
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['demo1'].auth_status = AuthStatus.PENDING.name
+ model_job.set_participants_info(participants_info)
+ session.add_all([project, participant, pro_part, model_job])
+ session.commit()
+ self._stub.InformModelJob(InformModelJobRequest(uuid='uuid', auth_status=AuthStatus.AUTHORIZED.name))
+ # authorized
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ participants_info = model_job.get_participants_info()
+ self.assertEqual(participants_info.participants_map['demo1'].auth_status, AuthStatus.AUTHORIZED.name)
+ # pending
+ self._stub.InformModelJob(InformModelJobRequest(uuid='uuid', auth_status=AuthStatus.PENDING.name))
+ with db.session_scope() as session:
+ model_job = session.query(ModelJob).get(1)
+ participants_info = model_job.get_participants_info()
+ self.assertEqual(participants_info.participants_map['demo1'].auth_status, AuthStatus.PENDING.name)
+ # fail due to model job not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformModelJob(InformModelJobRequest(uuid='uuid1', auth_status=AuthStatus.PENDING.name))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ # fail due to auth_status invalid
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformModelJob(InformModelJobRequest(uuid='uuid', auth_status='aaaaa'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+
+ @patch('fedlearner_webconsole.dataset.models.DataBatch.is_available')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_create_dataset_job_stage(self, mock_get_grpc_context_info: MagicMock, mock_is_available: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ mock_is_available.return_value = True
+ # test streaming
+ event_time = datetime(2022, 1, 1)
+ request = CreateDatasetJobStageRequest(dataset_job_uuid='dataset_job_123',
+ dataset_job_stage_uuid='dataset_job_stage_123',
+ name='test_stage',
+ event_time=to_timestamp(event_time))
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.CreateDatasetJobStage(request)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ self.assertEqual(cm.exception.details(), 'dataset_job dataset_job_123 is not found')
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job_123',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=10,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=1)
+ session.add(dataset_job)
+ dataset = Dataset(id=1,
+ uuid='dataset input',
+ name='default dataset input',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(dataset)
+ data_batch = DataBatch(id=1,
+ name='test_batch',
+ dataset_id=1,
+ latest_parent_dataset_job_stage_id=100,
+ latest_analyzer_dataset_job_stage_id=100)
+ session.add(data_batch)
+ session.commit()
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.CreateDatasetJobStage(request)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ self.assertEqual(cm.exception.details(), 'output dataset is not found, dataset_job uuid: dataset_job_123')
+ with db.session_scope() as session:
+ default_dataset = Dataset(id=10,
+ uuid='dataset_123',
+ name='default dataset',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(default_dataset)
+ session.commit()
+ resp = self._stub.CreateDatasetJobStage(request)
+ self.assertEqual(resp, Empty())
+ with db.session_scope() as session:
+ data_batch: DataBatch = session.query(DataBatch).filter(DataBatch.event_time == event_time).first()
+ self.assertEqual(data_batch.dataset_id, 10)
+ self.assertEqual(data_batch.name, '20220101')
+ self.assertEqual(data_batch.path, '/data/dataset/123/batch/20220101')
+ dataset_job_stage: DatasetJobStage = session.query(DatasetJobStage).filter(
+ DatasetJobStage.uuid == 'dataset_job_stage_123').first()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(dataset_job_stage.data_batch_id, data_batch.id)
+ self.assertEqual(dataset_job_stage.event_time, event_time)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+ # idempotency: calling again must not create another batch
+ resp = self._stub.CreateDatasetJobStage(request)
+ self.assertEqual(resp, Empty())
+ with db.session_scope() as session:
+ data_batches = session.query(DataBatch).filter(DataBatch.event_time == event_time).all()
+ self.assertEqual(len(data_batches), 1)
+ # test psi
+ with db.session_scope() as session:
+ dataset_job_2 = DatasetJob(id=2,
+ uuid='dataset_job_2',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=11,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=1)
+ session.add(dataset_job_2)
+ dataset_2 = Dataset(id=11,
+ uuid='dataset_2',
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(dataset_2)
+ session.commit()
+ request = CreateDatasetJobStageRequest(dataset_job_uuid='dataset_job_2',
+ dataset_job_stage_uuid='dataset_job_stage_2',
+ name='test_stage')
+ resp = self._stub.CreateDatasetJobStage(request)
+ self.assertEqual(resp, Empty())
+ with db.session_scope() as session:
+ data_batch: DataBatch = session.query(DataBatch).filter(DataBatch.dataset_id == 11).first()
+ self.assertEqual(data_batch.name, '0')
+ self.assertEqual(data_batch.path, '/data/dataset/123/batch/0')
+ dataset_job_stage: DatasetJobStage = session.query(DatasetJobStage).filter(
+ DatasetJobStage.uuid == 'dataset_job_stage_2').first()
+ self.assertEqual(dataset_job_stage.dataset_job_id, 2)
+ self.assertEqual(dataset_job_stage.data_batch_id, data_batch.id)
+ self.assertIsNone(dataset_job_stage.event_time)
+ self.assertEqual(dataset_job_stage.project_id, 1)
+ self.assertEqual(dataset_job_stage.coordinator_id, 1)
+ # idempotency: calling again must not create another batch
+ resp = self._stub.CreateDatasetJobStage(request)
+ self.assertEqual(resp, Empty())
+ with db.session_scope() as session:
+ data_batches = session.query(DataBatch).filter(DataBatch.dataset_id == 11).all()
+ self.assertEqual(len(data_batches), 1)
+
+ # test batch not ready
+ mock_is_available.reset_mock()
+ mock_is_available.return_value = False
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.CreateDatasetJobStage(request)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.FAILED_PRECONDITION)
+
+ @patch('fedlearner_webconsole.dataset.services.DatasetJobConfiger.from_kind',
+ lambda *args: FakeDatasetJobConfiger(None))
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_get_dataset_job_stage(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job_uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=0,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ state=DatasetJobState.RUNNING,
+ coordinator_id=0,
+ workflow_id=0)
+ session.add(dataset_job)
+ job_stage = DatasetJobStage(id=1,
+ uuid='job_stage_uuid',
+ name='default dataset job stage',
+ project_id=1,
+ workflow_id=1,
+ dataset_job_id=1,
+ data_batch_id=1,
+ event_time=datetime(2012, 1, 15),
+ state=DatasetJobState.PENDING)
+ session.add(job_stage)
+ session.commit()
+ request = GetDatasetJobStageRequest(dataset_job_stage_uuid='job_stage_uuid')
+ resp = self._stub.GetDatasetJobStage(request)
+ self.assertEqual(resp.dataset_job_stage.uuid, 'job_stage_uuid')
+ self.assertEqual(resp.dataset_job_stage.name, 'default dataset job stage')
+ self.assertEqual(resp.dataset_job_stage.dataset_job_id, 1)
+ self.assertEqual(resp.dataset_job_stage.event_time, to_timestamp(datetime(2012, 1, 15)))
+ self.assertEqual(resp.dataset_job_stage.is_ready, False)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_update_dataset_job_scheduler_state(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job_uuid',
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=0,
+ kind=DatasetJobKind.IMPORT_SOURCE,
+ state=DatasetJobState.RUNNING,
+ coordinator_id=0,
+ workflow_id=0,
+ scheduler_state=DatasetJobSchedulerState.PENDING,
+ time_range=timedelta(days=1))
+ session.add(dataset_job)
+ session.commit()
+ request = UpdateDatasetJobSchedulerStateRequest(uuid='dataset_job_uuid',
+ scheduler_state=DatasetJobSchedulerState.RUNNABLE.name)
+ resp = self._stub.UpdateDatasetJobSchedulerState(request)
+ self.assertEqual(resp, Empty())
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.RUNNABLE)
+ request = UpdateDatasetJobSchedulerStateRequest(uuid='dataset_job_uuid',
+ scheduler_state=DatasetJobSchedulerState.STOPPED.name)
+ resp = self._stub.UpdateDatasetJobSchedulerState(request)
+ self.assertEqual(resp, Empty())
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(1)
+ self.assertEqual(dataset_job.scheduler_state, DatasetJobSchedulerState.STOPPED)
+ request = UpdateDatasetJobSchedulerStateRequest(uuid='dataset_job_uuid',
+ scheduler_state=DatasetJobSchedulerState.PENDING.name)
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.UpdateDatasetJobSchedulerState(request)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_get_model_job(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ request = GetModelJobRequest(uuid='uuid')
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.GetModelJob(request=request)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ self.assertEqual(cm.exception.details(), 'model job uuid is not found')
+ with db.session_scope() as session:
+ model_job = ModelJob(uuid='uuid',
+ role=ModelJobRole.PARTICIPANT,
+ model_job_type=ModelJobType.TRAINING,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ auth_status=ModelAuthStatus.AUTHORIZED,
+ status=ModelJobStatus.CONFIGURED,
+ created_at=datetime(2022, 8, 16, 0, 0),
+ updated_at=datetime(2022, 8, 16, 0, 0))
+ session.add(model_job)
+ session.commit()
+ resp = self._stub.GetModelJob(request=request)
+ self.assertEqual(
+ resp,
+ ModelJobPb(id=1,
+ uuid='uuid',
+ role='PARTICIPANT',
+ model_job_type='TRAINING',
+ algorithm_type='NN_VERTICAL',
+ state='PENDING_ACCEPT',
+ auth_status='AUTHORIZED',
+ status='CONFIGURED',
+ auth_frontend_status='ALL_AUTHORIZED',
+ participants_info=ParticipantsInfo(),
+ created_at=1660608000,
+ updated_at=1660608000))
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_get_model_job_group(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ request = GetModelJobGroupRequest(uuid='uuid')
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.GetModelJobGroup(request=request)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ self.assertEqual(cm.exception.details(), 'model job group with uuid uuid is not found')
+ with db.session_scope() as session:
+ group = ModelJobGroup(id=1,
+ role=ModelJobRole.PARTICIPANT,
+ uuid='uuid',
+ authorized=True,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ created_at=datetime(2022, 8, 16, 0, 0),
+ updated_at=datetime(2022, 8, 16, 0, 0))
+ session.add(group)
+ session.commit()
+ resp = self._stub.GetModelJobGroup(request=request)
+ self.assertEqual(
+ resp,
+ ModelJobGroupPb(id=1,
+ uuid='uuid',
+ role='PARTICIPANT',
+ algorithm_type='NN_VERTICAL',
+ authorized=True,
+ auth_frontend_status='ALL_AUTHORIZED',
+ auth_status='PENDING',
+ auto_update_status='INITIAL',
+ participants_info=ParticipantsInfo(),
+ algorithm_project_uuid_list=AlgorithmProjectList(),
+ created_at=1660608000,
+ updated_at=1660608000))
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_inform_model_job_group(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part1', domain_name='fl-demo1.com')
+ participant2 = Participant(id=2, name='part2', domain_name='fl-demo2.com')
+ group = ModelJobGroup(id=1, uuid='uuid', project_id=1)
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['demo1'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['demo2'].auth_status = AuthStatus.PENDING.name
+ group.set_participants_info(participants_info)
+ session.add_all([project, participant1, participant2, group])
+ session.commit()
+ # authorized
+ self._stub.InformModelJobGroup(InformModelJobGroupRequest(uuid='uuid', auth_status=AuthStatus.AUTHORIZED.name))
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map['demo1'].auth_status, AuthStatus.AUTHORIZED.name)
+ # pending
+ self._stub.InformModelJobGroup(InformModelJobGroupRequest(uuid='uuid', auth_status=AuthStatus.PENDING.name))
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).get(1)
+ participants_info = group.get_participants_info()
+ self.assertEqual(participants_info.participants_map['demo1'].auth_status, AuthStatus.PENDING.name)
+ # fail due to group not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformModelJobGroup(
+ InformModelJobGroupRequest(uuid='uuid-1', auth_status=AuthStatus.PENDING.name))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ # fail due to auth_status invalid
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformModelJobGroup(InformModelJobGroupRequest(uuid='uuid', auth_status='aaaaa'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_update_model_job_group(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ group = ModelJobGroup(id=1,
+ name='group',
+ uuid='group_uuid',
+ project_id=1,
+ auto_update_status=GroupAutoUpdateStatus.INITIAL)
+ dataset_job_stage = DatasetJobStage(id=1,
+ name='data_join',
+ uuid='stage_uuid',
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ dataset_job_id=1,
+ data_batch_id=1)
+ session.add_all([project, group, dataset_job_stage])
+ session.commit()
+ # fail due to group not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.UpdateModelJobGroup(
+ UpdateModelJobGroupRequest(uuid='uuid',
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE.name,
+ start_dataset_job_stage_uuid='stage_uuid'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ # fail due to auto_update_status invalid
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.UpdateModelJobGroup(
+ UpdateModelJobGroupRequest(uuid='group_uuid',
+ auto_update_status='aaa',
+ start_dataset_job_stage_uuid='stage_uuid'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ # update auto_update_status and start_data_batch_id
+ self._stub.UpdateModelJobGroup(
+ UpdateModelJobGroupRequest(uuid='group_uuid',
+ auto_update_status=GroupAutoUpdateStatus.ACTIVE.name,
+ start_dataset_job_stage_uuid='stage_uuid'))
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).filter_by(uuid='group_uuid', project_id=1).first()
+ self.assertEqual(group.auto_update_status, GroupAutoUpdateStatus.ACTIVE)
+ self.assertEqual(group.start_data_batch_id, 1)
+ # only update auto_update_status
+ self._stub.UpdateModelJobGroup(
+ UpdateModelJobGroupRequest(uuid='group_uuid', auto_update_status=GroupAutoUpdateStatus.STOPPED.name))
+ with db.session_scope() as session:
+ group = session.query(ModelJobGroup).filter_by(uuid='group_uuid', project_id=1).first()
+ self.assertEqual(group.auto_update_status, GroupAutoUpdateStatus.STOPPED)
+
+ @patch('fedlearner_webconsole.mmgr.service.ModelJobGroupService.create_group')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_create_model_job_group(self, mock_get_grpc_context_info: MagicMock, mock_create_group: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ request = CreateModelJobGroupRequest(
+ name='name',
+ uuid='uuid',
+ algorithm_type=AlgorithmType.NN_VERTICAL.name,
+ dataset_uuid='uuid',
+ algorithm_project_list=AlgorithmProjectList(algorithm_projects={'test': 'uuid'}))
+ with self.assertRaises(grpc.RpcError) as cm:
+ # test dataset not found
+ resp = self._stub.CreateModelJobGroup(request)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+ self.assertEqual(cm.exception.details(), 'dataset with uuid uuid is not found')
+ with db.session_scope() as session:
+ session.add(Dataset(id=1, name='name', uuid='uuid'))
+ session.commit()
+ mock_create_group.return_value = ModelJobGroup(name='name', uuid='uuid')
+ resp = self._stub.CreateModelJobGroup(request)
+ # create group
+ mock_create_group.assert_called_with(
+ name='name',
+ uuid='uuid',
+ project_id=1,
+ role=ModelJobRole.PARTICIPANT,
+ dataset_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_project_list=AlgorithmProjectList(algorithm_projects={'test': 'uuid'}),
+ coordinator_id=1)
+ mock_create_group.reset_mock()
+ resp = self._stub.CreateModelJobGroup(request)
+ # create_group should not be called again since the group already exists
+ mock_create_group.assert_not_called()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_inform_trusted_job(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part1', domain_name='fl-domain2.com')
+ trusted_job = TrustedJob(id=1, uuid='uuid', project_id=1)
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['domain2'].auth_status = AuthStatus.PENDING.name
+ trusted_job.set_participants_info(participants_info)
+ session.add_all([project, participant1, trusted_job])
+ session.commit()
+ self._stub.InformTrustedJob(InformTrustedJobRequest(uuid='uuid', auth_status=AuthStatus.AUTHORIZED.name))
+ with db.session_scope() as session:
+ trusted_job = session.query(TrustedJob).get(1)
+ participants_info = trusted_job.get_participants_info()
+ self.assertEqual(participants_info.participants_map['domain2'].auth_status, AuthStatus.AUTHORIZED.name)
+ # fail due to trusted job not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformTrustedJob(InformTrustedJobRequest(uuid='not-exist', auth_status=AuthStatus.WITHDRAW.name))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ # fail due to auth_status invalid
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformTrustedJob(InformTrustedJobRequest(uuid='uuid', auth_status='AUTHORIZE'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_server.get_grpc_context_info')
+ def test_get_trusted_job(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ trusted_job = TrustedJob(id=1, name='name', uuid='uuid', project_id=1, auth_status=AuthStatus.AUTHORIZED)
+ session.add_all([trusted_job])
+ session.commit()
+ resp = self._stub.GetTrustedJob(GetTrustedJobRequest(uuid='uuid'))
+ self.assertEqual(resp, GetTrustedJobResponse(auth_status='AUTHORIZED'))
+ # fail due to not found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.GetTrustedJob(GetTrustedJobRequest(uuid='uuid-not-exist'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_client.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_client.py
new file mode 100644
index 000000000..3dd0a99c1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_client.py
@@ -0,0 +1,86 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict
+
+import grpc
+from google.protobuf import empty_pb2
+
+from envs import Envs
+from fedlearner_webconsole.project.models import PendingProject, PendingProjectState
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplateKind
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo
+from fedlearner_webconsole.proto.rpc.v2.project_service_pb2 import CreatePendingProjectRequest, \
+ UpdatePendingProjectRequest, SyncPendingProjectStateRequest, CreateProjectRequest, DeletePendingProjectRequest, \
+ SendTemplateRevisionRequest
+from fedlearner_webconsole.proto.rpc.v2.project_service_pb2_grpc import ProjectServiceStub
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.rpc.v2.client_base import ParticipantRpcClient
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+
+
+# Only retry on UNAVAILABLE, which is typically caused by transient network jitter.
+def _retry_unavailable(err: Exception) -> bool:
+ if not isinstance(err, grpc.RpcError):
+ return False
+ return err.code() == grpc.StatusCode.UNAVAILABLE
+
+
+class ProjectServiceClient(ParticipantRpcClient):
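+ """Client for pending-project RPCs sent to a peer participant; each call is retried up to 3 times on UNAVAILABLE."""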
+
+ def __init__(self, channel: grpc.Channel):
+ super().__init__(channel)
+ self._stub: ProjectServiceStub = ProjectServiceStub(channel)
+
+ @retry_fn(retry_times=3, need_retry=_retry_unavailable)
+ def create_pending_project(self, pending_project: PendingProject) -> empty_pb2.Empty:
+ request = CreatePendingProjectRequest(uuid=pending_project.uuid,
+ name=pending_project.name,
+ participants_info=pending_project.get_participants_info(),
+ comment=pending_project.comment,
+ creator_username=pending_project.creator_username,
+ config=pending_project.get_config(),
+ ticket_uuid=pending_project.ticket_uuid)
+ return self._stub.CreatePendingProject(request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_retry_unavailable)
+ def update_pending_project(self, uuid: str, participants_map: Dict[str, ParticipantInfo]) -> empty_pb2.Empty:
+ request = UpdatePendingProjectRequest(uuid=uuid, participants_map=participants_map)
+ return self._stub.UpdatePendingProject(request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_retry_unavailable)
+ def create_project(self, uuid: str):
+ request = CreateProjectRequest(uuid=uuid)
+ return self._stub.CreateProject(request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_retry_unavailable)
+ def sync_pending_project_state(self, uuid: str, state: PendingProjectState) -> empty_pb2.Empty:
+ request = SyncPendingProjectStateRequest(uuid=uuid, state=state.name)
+ return self._stub.SyncPendingProjectState(request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_retry_unavailable)
+ def delete_pending_project(self, uuid: str):
+ request = DeletePendingProjectRequest(uuid=uuid)
+ return self._stub.DeletePendingProject(request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_retry_unavailable)
+ def send_template_revision(self, config: WorkflowDefinition, name: str, comment: str, kind: WorkflowTemplateKind,
+ revision_index: int):
+ request = SendTemplateRevisionRequest(config=config,
+ name=name,
+ comment=comment,
+ kind=kind.name,
+ revision_index=revision_index)
+ return self._stub.SendTemplateRevision(request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_client_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_client_test.py
new file mode 100644
index 000000000..5724c1b2c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_client_test.py
@@ -0,0 +1,169 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import grpc
+import grpc_testing
+from google.protobuf.empty_pb2 import Empty
+from google.protobuf.descriptor import ServiceDescriptor
+
+from fedlearner_webconsole.project.models import PendingProject, PendingProjectState, ProjectRole
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.proto.rpc.v2.project_service_pb2 import CreatePendingProjectRequest, \
+ UpdatePendingProjectRequest, SyncPendingProjectStateRequest, CreateProjectRequest, DeletePendingProjectRequest, \
+ SendTemplateRevisionRequest
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.proto.rpc.v2 import project_service_pb2
+from fedlearner_webconsole.rpc.v2.project_service_client import ProjectServiceClient
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplateKind
+from testing.rpc.client import RpcClientTestCase
+
+_SERVER_DESCRIPTOR: ServiceDescriptor = project_service_pb2.DESCRIPTOR.services_by_name['ProjectService']
+
+
+class ProjectServiceClientTest(RpcClientTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._fake_channel: grpc_testing.Channel = grpc_testing.channel([_SERVER_DESCRIPTOR],
+ grpc_testing.strict_real_time())
+ self._client = ProjectServiceClient(self._fake_channel)
+
+ def test_create_pending_project(self):
+ pending_project = PendingProject(uuid='test', name='test-project', ticket_uuid='test')
+
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test':
+ ParticipantInfo(
+ name='test', state=PendingProjectState.ACCEPTED.name, role=ProjectRole.COORDINATOR.name),
+ 'part':
+ ParticipantInfo(
+ name='part', state=PendingProjectState.PENDING.name, role=ProjectRole.PARTICIPANT.name)
+ })
+ pending_project.set_participants_info(participants_info)
+ call = self.client_execution_pool.submit(self._client.create_pending_project, pending_project=pending_project)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['CreatePendingProject'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ CreatePendingProjectRequest(uuid='test',
+ participants_info=participants_info,
+ name='test-project',
+ config=pending_project.get_config(),
+ ticket_uuid='test'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_update_pending_project(self):
+ participants_map = {
+ 'part':
+ ParticipantInfo(name='part', state=PendingProjectState.ACCEPTED.name, role=ProjectRole.PARTICIPANT.name)
+ }
+
+ call = self.client_execution_pool.submit(self._client.update_pending_project,
+ uuid='test',
+ participants_map=participants_map)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['UpdatePendingProject'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, UpdatePendingProjectRequest(uuid='test', participants_map=participants_map))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_sync_pending_project_state(self):
+ call = self.client_execution_pool.submit(self._client.sync_pending_project_state,
+ uuid='test',
+ state=PendingProjectState.ACCEPTED)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['SyncPendingProjectState'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, SyncPendingProjectStateRequest(uuid='test', state='ACCEPTED'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_create_project(self):
+ call = self.client_execution_pool.submit(self._client.create_project, uuid='test')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['CreateProject'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, CreateProjectRequest(uuid='test'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_delete_pending_project(self):
+ call = self.client_execution_pool.submit(self._client.delete_pending_project, uuid='test')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['DeletePendingProject'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, DeletePendingProjectRequest(uuid='test'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_send_template_revision(self):
+ call = self.client_execution_pool.submit(self._client.send_template_revision,
+ config=WorkflowDefinition(),
+ name='test',
+ comment='test',
+ kind=WorkflowTemplateKind.PEER,
+ revision_index=1)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['SendTemplateRevision'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ SendTemplateRevisionRequest(config=WorkflowDefinition(),
+ name='test',
+ comment='test',
+ kind=WorkflowTemplateKind.PEER.name,
+ revision_index=1))
+ self.assertEqual(call.result(), expected_response)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_server.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_server.py
new file mode 100644
index 000000000..e2a73efd0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_server.py
@@ -0,0 +1,181 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Dict
+
+import grpc
+from google.protobuf import empty_pb2
+from grpc import ServicerContext
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import PendingProjectState, ProjectRole, PendingProject, Project
+from fedlearner_webconsole.project.services import PendingProjectService
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo, ParticipantsInfo
+from fedlearner_webconsole.proto.rpc.v2.project_service_pb2_grpc import ProjectServiceServicer
+from fedlearner_webconsole.proto.rpc.v2.project_service_pb2 import CreatePendingProjectRequest, \
+ UpdatePendingProjectRequest, SyncPendingProjectStateRequest, CreateProjectRequest, DeletePendingProjectRequest, \
+ SendTemplateRevisionRequest
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+from fedlearner_webconsole.rpc.v2.utils import get_pure_domain_from_context
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.workflow_template.service import WorkflowTemplateRevisionService
+from fedlearner_webconsole.audit.decorators import emits_rpc_event
+from fedlearner_webconsole.proto.audit_pb2 import Event
+
+
+def _is_same_participants(participants_map: Dict[str, ParticipantInfo],
+ new_participants_map: Dict[str, ParticipantInfo]) -> bool:
+ # NOTE: despite the name, this returns True when the two maps have different key sets;
+ # callers use it to reject any attempt to change the participant list.
+ return set(participants_map) != set(new_participants_map)
+
+
+class ProjectGrpcService(ProjectServiceServicer):
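+ """Server-side handlers for the pending-project lifecycle (create, update, state sync, delete) and for template revision sync."""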
+
+ @emits_rpc_event(resource_type=Event.ResourceType.PENDING_PROJECT,
+ op_type=Event.OperationType.CREATE,
+ resource_name_fn=lambda request: request.uuid)
+ def CreatePendingProject(self, request: CreatePendingProjectRequest, context: ServicerContext):
+
+ with db.session_scope() as session:
+ existed = session.query(PendingProject).filter_by(uuid=request.uuid).first()
+ # make CreatePendingProject idempotent
+ if existed is not None:
+ return empty_pb2.Empty()
+ validate = get_ticket_helper(session).validate_ticket(request.ticket_uuid,
+ lambda ticket: ticket.details.uuid == request.uuid)
+ if not validate:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, f'ticket {request.ticket_uuid} is not validated')
+ pending_project = PendingProjectService(session).create_pending_project(
+ name=request.name,
+ config=request.config,
+ participants_info=request.participants_info,
+ comment=request.comment,
+ uuid=request.uuid,
+ creator_username=request.creator_username,
+ state=PendingProjectState.PENDING,
+ role=ProjectRole.PARTICIPANT)
+ pending_project.ticket_uuid = request.ticket_uuid
+ pending_project.ticket_status = TicketStatus.APPROVED
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.PENDING_PROJECT,
+ op_type=Event.OperationType.UPDATE,
+ resource_name_fn=lambda request: request.uuid)
+ def UpdatePendingProject(self, request: UpdatePendingProjectRequest, context: ServicerContext):
+ peer_pure_domain = get_pure_domain_from_context(context)
+ with db.session_scope() as session:
+ # use SERIALIZABLE isolation so that participants_info cannot be changed by a concurrent transaction while this session is open
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ pending_project: PendingProject = session.query(PendingProject).filter_by(uuid=request.uuid).first()
+ if pending_project is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'pending project not found, uuid: {request.uuid}')
+ participants_map = pending_project.get_participants_info().participants_map
+ peer_info = pending_project.get_participant_info(peer_pure_domain)
+ if not peer_info or peer_info.role == ProjectRole.PARTICIPANT.name:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED,
+ f'{peer_pure_domain} is not coordinator in pending project {request.uuid}')
+ if _is_same_participants(participants_map, request.participants_map):
+ context.abort(grpc.StatusCode.PERMISSION_DENIED,
+ f'can not change participants when the pending project {request.uuid} has been approved')
+ participants_map.MergeFrom(request.participants_map)
+ pending_project.set_participants_info(ParticipantsInfo(participants_map=participants_map))
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.PENDING_PROJECT,
+ op_type=Event.OperationType.CONTROL_STATE,
+ resource_name_fn=lambda request: request.uuid)
+ def SyncPendingProjectState(self, request: SyncPendingProjectStateRequest, context: ServicerContext):
+ peer_pure_domain = get_pure_domain_from_context(context)
+ if request.state not in [PendingProjectState.ACCEPTED.name, PendingProjectState.CLOSED.name]:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED,
+ f'participant can only sync ACCEPTED or CLOSED but got: {request.state}')
+ with db.session_scope() as session:
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ pending_project: PendingProject = session.query(PendingProject).filter_by(uuid=request.uuid).first()
+ if pending_project is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'pending project not found, uuid: {request.uuid}')
+ participants_info = pending_project.get_participants_info()
+ if peer_pure_domain not in participants_info.participants_map:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED,
+ f'{peer_pure_domain} is not in pending project {request.uuid}')
+ participants_info.participants_map[peer_pure_domain].state = request.state
+ pending_project.set_participants_info(participants_info)
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.PROJECT,
+ op_type=Event.OperationType.CREATE,
+ resource_name_fn=lambda request: request.uuid)
+ def CreateProject(self, request: CreateProjectRequest, context: ServicerContext):
+ with db.session_scope() as session:
+ pending_project: PendingProject = session.query(PendingProject).filter_by(uuid=request.uuid).first()
+
+ if pending_project is None:
+ message = f'failed to find pending project, uuid is {request.uuid}'
+ logging.error(message)
+ context.abort(grpc.StatusCode.NOT_FOUND, message)
+
+ if pending_project.state == PendingProjectState.CLOSED:
+ logging.info(f'{pending_project.uuid} pending project has been closed')
+ return empty_pb2.Empty()
+
+ if pending_project.state != PendingProjectState.ACCEPTED:
+ message = f'{pending_project.uuid} pending project has not been accepted'
+ logging.info(message)
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, message)
+
+ project = session.query(Project).filter_by(name=pending_project.name).first()
+ if project is not None:
+ message = f'{pending_project.name} project already exists, uuid is {pending_project.uuid}'
+ logging.error(message)
+ context.abort(grpc.StatusCode.ALREADY_EXISTS, message)
+
+ with db.session_scope() as session:
+ session.connection(execution_options={'isolation_level': 'SERIALIZABLE'})
+ PendingProjectService(session).create_project_locally(pending_project.uuid)
+ session.commit()
+ return empty_pb2.Empty()
+
+ @emits_rpc_event(resource_type=Event.ResourceType.PROJECT,
+ op_type=Event.OperationType.DELETE,
+ resource_name_fn=lambda request: request.uuid)
+ def DeletePendingProject(self, request: DeletePendingProjectRequest, context: ServicerContext):
+ peer_pure_domain = get_pure_domain_from_context(context)
+ with db.session_scope() as session:
+ pending_project: PendingProject = session.query(PendingProject).filter_by(uuid=request.uuid).first()
+ if pending_project is None:
+ return empty_pb2.Empty()
+ if peer_pure_domain != pending_project.get_coordinator_info()[0]:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED,
+ f'{peer_pure_domain} is not coordinator in pending project {request.uuid}')
+ pending_project.deleted_at = now()
+ session.commit()
+ return empty_pb2.Empty()
+
+ def SendTemplateRevision(self, request: SendTemplateRevisionRequest, context: ServicerContext):
+ peer_pure_domain = get_pure_domain_from_context(context)
+ with db.session_scope() as session:
+ WorkflowTemplateRevisionService(session).create_revision(
+ name=request.name,
+ kind=request.kind,
+ config=request.config,
+ revision_index=request.revision_index,
+ comment=request.comment,
+ peer_pure_domain=peer_pure_domain,
+ )
+ session.commit()
+ return empty_pb2.Empty()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_server_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_server_test.py
new file mode 100644
index 000000000..914d6f2e6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/project_service_server_test.py
@@ -0,0 +1,268 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from concurrent import futures
+from unittest.mock import patch, MagicMock
+
+import grpc
+from google.protobuf.empty_pb2 import Empty
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import PendingProjectState, PendingProject, ProjectRole, Project
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.proto.rpc.v2 import project_service_pb2_grpc
+from fedlearner_webconsole.proto.rpc.v2.project_service_pb2 import CreatePendingProjectRequest, \
+ SyncPendingProjectStateRequest, UpdatePendingProjectRequest, CreateProjectRequest, DeletePendingProjectRequest, \
+ SendTemplateRevisionRequest
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.review import common
+from fedlearner_webconsole.rpc.v2.project_service_server import ProjectGrpcService, _is_same_participants
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplateKind, WorkflowTemplate, \
+ WorkflowTemplateRevision
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ProjectServiceTest(NoWebServerTestCase):
+ LISTEN_PORT = 2001
+
+ def setUp(self):
+ super().setUp()
+ self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=20))
+ project_service_pb2_grpc.add_ProjectServiceServicer_to_server(ProjectGrpcService(), self._server)
+ self._server.add_insecure_port(f'[::]:{self.LISTEN_PORT}')
+ self._server.start()
+ self._channel = grpc.insecure_channel(target=f'localhost:{self.LISTEN_PORT}')
+ self._stub = project_service_pb2_grpc.ProjectServiceStub(self._channel)
+ self.participants_map = {
+ 'test':
+ ParticipantInfo(name='test', state=PendingProjectState.ACCEPTED.name,
+ role=ProjectRole.COORDINATOR.name),
+ 'part1':
+ ParticipantInfo(name='part', state=PendingProjectState.PENDING.name, role=ProjectRole.PARTICIPANT.name),
+ 'part2':
+ ParticipantInfo(name='part', state=PendingProjectState.PENDING.name, role=ProjectRole.PARTICIPANT.name)
+ }
+
+ def tearDown(self):
+ self._channel.close()
+ self._server.stop(5)
+ return super().tearDown()
+
+ @patch('fedlearner_webconsole.rpc.v2.project_service_server.get_ticket_helper')
+ def test_create_pending_project(self, mock_get_ticket_helper):
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test':
+ ParticipantInfo(
+ name='test', state=PendingProjectState.ACCEPTED.name, role=ProjectRole.COORDINATOR.name),
+ 'part':
+ ParticipantInfo(
+ name='part', state=PendingProjectState.PENDING.name, role=ProjectRole.PARTICIPANT.name)
+ })
+ request = CreatePendingProjectRequest(name='test-project',
+ uuid='test',
+ participants_info=participants_info,
+ ticket_uuid=common.NO_CENTRAL_SERVER_UUID)
+ mock_get_ticket_helper.return_value.validate_ticket.return_value = True
+ resp = self._stub.CreatePendingProject(request, None)
+ with db.session_scope() as session:
+ pending_project: PendingProject = session.query(PendingProject).filter_by(uuid='test').first()
+ self.assertEqual(resp, Empty())
+ self.assertEqual(pending_project.state, PendingProjectState.PENDING)
+ self.assertEqual(pending_project.role, ProjectRole.PARTICIPANT)
+ self.assertEqual(pending_project.get_participants_info(), participants_info)
+ self.assertEqual(pending_project.ticket_status, TicketStatus.APPROVED)
+ self.assertEqual(pending_project.ticket_uuid, common.NO_CENTRAL_SERVER_UUID)
+
+ @patch('fedlearner_webconsole.rpc.v2.project_service_server.get_ticket_helper')
+ def test_create_pending_project_ticket_wrong(self, mock_get_ticket_helper):
+ mock_get_ticket_helper.return_value.validate_ticket.return_value = False
+ participants_info = ParticipantsInfo(
+ participants_map={
+ 'test':
+ ParticipantInfo(
+ name='test', state=PendingProjectState.ACCEPTED.name, role=ProjectRole.COORDINATOR.name),
+ 'part':
+ ParticipantInfo(
+ name='part', state=PendingProjectState.PENDING.name, role=ProjectRole.PARTICIPANT.name)
+ })
+ request = CreatePendingProjectRequest(name='test-project',
+ uuid='test',
+ participants_info=participants_info,
+ ticket_uuid='wrong')
+ with self.assertRaisesRegex(grpc.RpcError, 'ticket wrong is not validated') as cm:
+ self._stub.CreatePendingProject(request, None)
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+ with db.session_scope() as session:
+ pending_project: PendingProject = session.query(PendingProject).filter_by(uuid='test').first()
+ self.assertIsNone(pending_project)
+
+ @patch('fedlearner_webconsole.rpc.v2.project_service_server.get_pure_domain_from_context')
+ def test_update_pending_project(self, mock_pure_domain):
+ participants_map = self.participants_map
+ participants_info = ParticipantsInfo(participants_map=participants_map)
+ pending_project = PendingProject(uuid='unique1',
+ name='test',
+ state=PendingProjectState.PENDING,
+ role=ProjectRole.PARTICIPANT)
+ pending_project.set_participants_info(participants_info)
+ with db.session_scope() as session:
+ session.add(pending_project)
+ session.commit()
+ mock_pure_domain.return_value = 'wrong coordinator'
+ participants_map['part2'].state = PendingProjectState.ACCEPTED.name
+ with self.assertRaisesRegex(grpc.RpcError,
+ 'wrong coordinator is not coordinator in pending project unique1') as cm:
+ self._stub.UpdatePendingProject(
+ UpdatePendingProjectRequest(uuid=pending_project.uuid, participants_map=participants_map))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+ mock_pure_domain.return_value = 'test'
+ resp = self._stub.UpdatePendingProject(
+ UpdatePendingProjectRequest(uuid=pending_project.uuid, participants_map=participants_map))
+ with db.session_scope() as session:
+ self.assertEqual(resp, Empty())
+ result: PendingProject = session.query(PendingProject).get(pending_project.id)
+ self.assertEqual(result.get_participants_info(), ParticipantsInfo(participants_map=participants_map))
+
+ @patch('fedlearner_webconsole.rpc.v2.project_service_server.get_pure_domain_from_context')
+ def test_sync_pending_project_state(self, mock_pure_domain):
+ # test participant sync to coordinator
+ participants_info = ParticipantsInfo(participants_map=self.participants_map)
+ pending_project = PendingProject(uuid='unique2',
+ name='test',
+ state=PendingProjectState.ACCEPTED,
+ role=ProjectRole.COORDINATOR)
+ pending_project.set_participants_info(participants_info)
+ with db.session_scope() as session:
+ session.add(pending_project)
+ session.commit()
+ mock_pure_domain.return_value = 'part1'
+ resp = self._stub.SyncPendingProjectState(
+ SyncPendingProjectStateRequest(uuid=pending_project.uuid, state='ACCEPTED'))
+ with db.session_scope() as session:
+ self.assertEqual(resp, Empty())
+ result: PendingProject = session.query(PendingProject).get(pending_project.id)
+ participants_info.participants_map['part1'].state = PendingProjectState.ACCEPTED.name
+ self.assertEqual(result.get_participants_info(), participants_info)
+
+ def test_is_same_participants(self):
+ self.assertFalse(
+ _is_same_participants(self.participants_map, {
+ 'test': ParticipantInfo(),
+ 'part1': ParticipantInfo(),
+ 'part2': ParticipantInfo()
+ }))
+ self.assertTrue(
+ _is_same_participants(self.participants_map, {
+ 'test': ParticipantInfo(),
+ 'part1': ParticipantInfo(),
+ 'part3': ParticipantInfo()
+ }))
+
+ @patch('fedlearner_webconsole.rpc.v2.project_service_server.PendingProjectService.create_project_locally')
+ def test_create_project(self, mock_create: MagicMock):
+ with db.session_scope() as session:
+ pending_project1 = PendingProject(uuid='test1',
+ id=1,
+ name='test project',
+ state=PendingProjectState.ACCEPTED)
+ pending_project2 = PendingProject(uuid='pending',
+ id=2,
+ name='test project 2',
+ state=PendingProjectState.PENDING)
+ pending_project3 = PendingProject(uuid='dup',
+ id=3,
+ name='test project 3',
+ state=PendingProjectState.ACCEPTED)
+ project = Project(name='test project 3')
+ session.add_all([pending_project1, pending_project2, pending_project3, project])
+ session.commit()
+ # successful
+ resp = self._stub.CreateProject(CreateProjectRequest(uuid='test1'))
+ self.assertEqual(resp, Empty())
+ mock_create.assert_called_once_with('test1')
+ mock_create.reset_mock()
+ # fail due to pending project not found
+ with self.assertRaisesRegex(grpc.RpcError, 'failed to find pending project, uuid is nothing') as cm:
+ self._stub.CreateProject(CreateProjectRequest(uuid='nothing'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+ mock_create.assert_not_called()
+ # fail due to state not valid
+ with self.assertRaisesRegex(grpc.RpcError, 'pending pending project has not been accepted') as cm:
+ self._stub.CreateProject(CreateProjectRequest(uuid='pending'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.FAILED_PRECONDITION)
+ mock_create.assert_not_called()
+ # fail due to name duplicate
+ with self.assertRaisesRegex(grpc.RpcError, 'test project 3 project already exists, uuid is dup') as cm:
+ self._stub.CreateProject(CreateProjectRequest(uuid='dup'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.ALREADY_EXISTS)
+ mock_create.assert_not_called()
+
+ @patch('fedlearner_webconsole.rpc.v2.project_service_server.get_pure_domain_from_context')
+ def test_delete_pending_project(self, mock_pure_domain: MagicMock):
+ with db.session_scope() as session:
+ pending_project1 = PendingProject(uuid='test1',
+ id=1,
+ name='test project',
+ state=PendingProjectState.ACCEPTED)
+ pending_project1.set_participants_info(ParticipantsInfo(participants_map=self.participants_map))
+ session.add(pending_project1)
+ session.commit()
+ mock_pure_domain.return_value = 'part1'
+ with self.assertRaisesRegex(grpc.RpcError, 'part1 is not coordinator in pending project test1') as cm:
+ self._stub.DeletePendingProject(DeletePendingProjectRequest(uuid='test1'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+ mock_pure_domain.return_value = 'test'
+ resp = self._stub.DeletePendingProject(DeletePendingProjectRequest(uuid='test1'))
+ self.assertEqual(resp, Empty())
+ with db.session_scope() as session:
+ self.assertIsNone(session.query(PendingProject).get(pending_project1.id))
+ resp = self._stub.DeletePendingProject(DeletePendingProjectRequest(uuid='test1'))
+ self.assertEqual(resp, Empty())
+
+ @patch('fedlearner_webconsole.rpc.v2.project_service_server.get_pure_domain_from_context')
+ def test_send_template_revision(self, mock_pure_domain: MagicMock):
+ mock_pure_domain.return_value = 'a'
+ self._stub.SendTemplateRevision(
+ SendTemplateRevisionRequest(config=WorkflowDefinition(group_alias='test'),
+ name='test',
+ revision_index=2,
+ comment='test comment',
+ kind=WorkflowTemplateKind.PEER.name))
+ self._stub.SendTemplateRevision(
+ SendTemplateRevisionRequest(config=WorkflowDefinition(group_alias='test', variables=[Variable()]),
+ name='test',
+ revision_index=3,
+ comment='test comment',
+ kind=WorkflowTemplateKind.PEER.name))
+ self._stub.SendTemplateRevision(
+ SendTemplateRevisionRequest(config=WorkflowDefinition(group_alias='test'),
+ name='test',
+ revision_index=1,
+ comment='test comment',
+ kind=WorkflowTemplateKind.PEER.name))
+ with db.session_scope() as session:
+ tpl = session.query(WorkflowTemplate).filter_by(name='test').first()
+ self.assertEqual(tpl.get_config(), WorkflowDefinition(group_alias='test', variables=[Variable()]))
+ self.assertEqual(tpl.coordinator_pure_domain_name, 'a')
+ self.assertEqual(tpl.kind, 2)
+ revisions = session.query(WorkflowTemplateRevision).filter_by(template_id=tpl.id).all()
+ self.assertEqual(sorted([r.revision_index for r in revisions]), [1, 2, 3])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_client.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_client.py
new file mode 100644
index 000000000..df4f9a848
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_client.py
@@ -0,0 +1,99 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from envs import Envs
+from typing import Iterable, Optional
+from google.protobuf import empty_pb2
+
+from fedlearner_webconsole.rpc.v2.client_base import ParticipantProjectRpcClient
+from fedlearner_webconsole.dataset.models import DatasetKindV2, ResourceState
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2_grpc import ResourceServiceStub
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import GetAlgorithmRequest, GetAlgorithmProjectRequest, \
+ InformDatasetRequest, ListAlgorithmProjectsRequest, ListAlgorithmProjectsResponse, ListAlgorithmsRequest, \
+ ListAlgorithmsResponse, GetAlgorithmFilesRequest, GetAlgorithmFilesResponse, ListDatasetsRequest, \
+ ListDatasetsResponse
+from fedlearner_webconsole.proto.dataset_pb2 import TimeRange
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmProjectPb, AlgorithmPb
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+
+
+def _need_retry_for_get(err: Exception) -> bool:
+ if not isinstance(err, grpc.RpcError):
+ return False
+ # No need to retry for NOT_FOUND and PERMISSION_DENIED
+ if err.code() == grpc.StatusCode.NOT_FOUND or err.code() == grpc.StatusCode.PERMISSION_DENIED:
+ return False
+ return True
+
+
+def _default_need_retry(err: Exception) -> bool:
+ return isinstance(err, grpc.RpcError)
+
+
+class ResourceServiceClient(ParticipantProjectRpcClient):
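+ """Client for algorithm and dataset resource RPCs; get-style calls skip retrying on NOT_FOUND and PERMISSION_DENIED."""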
+
+ def __init__(self, channel: grpc.Channel):
+ super().__init__(channel)
+ self._stub: ResourceServiceStub = ResourceServiceStub(channel)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def list_algorithm_projects(self, filter_exp: Optional[FilterExpression] = None) -> ListAlgorithmProjectsResponse:
+ request = ListAlgorithmProjectsRequest(filter_exp=filter_exp)
+ return self._stub.ListAlgorithmProjects(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_algorithm_project(self, algorithm_project_uuid: str) -> AlgorithmProjectPb:
+ request = GetAlgorithmProjectRequest(algorithm_project_uuid=algorithm_project_uuid)
+ return self._stub.GetAlgorithmProject(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def list_algorithms(self, algorithm_project_uuid: str) -> ListAlgorithmsResponse:
+ request = ListAlgorithmsRequest(algorithm_project_uuid=algorithm_project_uuid)
+ return self._stub.ListAlgorithms(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_algorithm(self, algorithm_uuid: str) -> AlgorithmPb:
+ request = GetAlgorithmRequest(algorithm_uuid=algorithm_uuid)
+ return self._stub.GetAlgorithm(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_need_retry_for_get)
+ def get_algorithm_files(self, algorithm_uuid: str) -> Iterable[GetAlgorithmFilesResponse]:
+ request = GetAlgorithmFilesRequest(algorithm_uuid=algorithm_uuid)
+ return self._stub.GetAlgorithmFiles(request=request, timeout=Envs.GRPC_STREAM_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def inform_dataset(self, dataset_uuid: str, auth_status: AuthStatus) -> empty_pb2.Empty:
+ request = InformDatasetRequest(uuid=dataset_uuid, auth_status=auth_status.name)
+ return self._stub.InformDataset(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def list_datasets(self,
+ kind: Optional[DatasetKindV2] = None,
+ uuid: Optional[str] = None,
+ state: Optional[ResourceState] = None,
+ time_range: Optional[TimeRange] = None) -> ListDatasetsResponse:
+ request = ListDatasetsRequest()
+ if kind is not None:
+ request.kind = kind.name
+ if uuid is not None:
+ request.uuid = uuid
+ if state is not None:
+ request.state = state.name
+ if time_range is not None:
+ request.time_range.MergeFrom(time_range)
+ return self._stub.ListDatasets(request=request, timeout=Envs.GRPC_CLIENT_TIMEOUT)
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_client_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_client_test.py
new file mode 100644
index 000000000..be6c725a5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_client_test.py
@@ -0,0 +1,185 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import unittest
+import grpc_testing
+from google.protobuf.descriptor import ServiceDescriptor
+from google.protobuf.empty_pb2 import Empty
+
+from testing.rpc.client import RpcClientTestCase
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.dataset.models import DatasetKindV2, ResourceState
+from fedlearner_webconsole.proto.rpc.v2 import resource_service_pb2
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmPb, AlgorithmProjectPb
+from fedlearner_webconsole.proto.dataset_pb2 import ParticipantDatasetRef, TimeRange
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import GetAlgorithmRequest, GetAlgorithmProjectRequest, \
+ InformDatasetRequest, ListAlgorithmProjectsRequest, ListAlgorithmProjectsResponse, ListAlgorithmsRequest, \
+ ListAlgorithmsResponse, GetAlgorithmFilesRequest, GetAlgorithmFilesResponse, ListDatasetsRequest, \
+ ListDatasetsResponse
+from fedlearner_webconsole.rpc.v2.resource_service_client import ResourceServiceClient
+
+_SERVER_DESCRIPTOR: ServiceDescriptor = resource_service_pb2.DESCRIPTOR.services_by_name['ResourceService']
+
+
+class ResourceServiceClientTest(RpcClientTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._fake_channel: grpc_testing.Channel = grpc_testing.channel([_SERVER_DESCRIPTOR],
+ grpc_testing.strict_real_time())
+ self._client = ResourceServiceClient(self._fake_channel)
+
+ def test_list_algorithm_projects(self):
+ call = self.client_execution_pool.submit(self._client.list_algorithm_projects)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['ListAlgorithmProjects'])
+ algorithm_projects = [
+ AlgorithmProjectPb(uuid='1', name='algo-project-1', type='NN_LOCAL', source='USER'),
+ AlgorithmProjectPb(uuid='2', name='algo-project-2', type='NN_VERTICAL', source='USER')
+ ]
+ expected_response = ListAlgorithmProjectsResponse(algorithm_projects=algorithm_projects)
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, ListAlgorithmProjectsRequest())
+ self.assertEqual(call.result(), expected_response)
+
+ def test_list_algorithms(self):
+ call = self.client_execution_pool.submit(self._client.list_algorithms, algorithm_project_uuid='1')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['ListAlgorithms'])
+ algorithms = [
+ AlgorithmPb(uuid='1', name='test-algo-1', version=1, type='NN_LOCAL', source='USER'),
+ AlgorithmPb(uuid='2', name='test-algo-2', version=2, type='NN_VERTICAL', source='USER')
+ ]
+ expected_response = ListAlgorithmsResponse(algorithms=algorithms)
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, ListAlgorithmsRequest(algorithm_project_uuid='1'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_algorithm_project(self):
+ call = self.client_execution_pool.submit(self._client.get_algorithm_project, algorithm_project_uuid='1')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['GetAlgorithmProject'])
+ expected_response = AlgorithmProjectPb(uuid='1', name='test-algo-project')
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetAlgorithmProjectRequest(algorithm_project_uuid='1'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_algorithm(self):
+ call = self.client_execution_pool.submit(self._client.get_algorithm, algorithm_uuid='1')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['GetAlgorithm'])
+ expected_response = AlgorithmPb(uuid='1', name='test-algo')
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetAlgorithmRequest(algorithm_uuid='1'))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_algorithm_files(self):
+ call = self.client_execution_pool.submit(self._client.get_algorithm_files, algorithm_uuid='1')
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_stream(
+ _SERVER_DESCRIPTOR.methods_by_name['GetAlgorithmFiles'])
+ resp = GetAlgorithmFilesResponse(hash='ac3ee699961c58ef80a78c2434efe0d0', chunk=b'')
+ rpc.send_response(resp)
+ rpc.send_response(resp)
+ rpc.terminate(
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, GetAlgorithmFilesRequest(algorithm_uuid='1'))
+ resps = list(call.result())
+ self.assertEqual(len(resps), 2)
+ for res in resps:
+ self.assertEqual(res, resp)
+
+ def test_inform_dataset(self):
+ call = self.client_execution_pool.submit(self._client.inform_dataset,
+ dataset_uuid='test dataset uuid',
+ auth_status=AuthStatus.AUTHORIZED)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['InformDataset'])
+ expected_response = Empty()
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, InformDatasetRequest(uuid='test dataset uuid',
+ auth_status=AuthStatus.AUTHORIZED.name))
+ self.assertEqual(call.result(), expected_response)
+
+ def test_list_datasets(self):
+ # test no args
+ call = self.client_execution_pool.submit(self._client.list_datasets)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['ListDatasets'])
+ participant_datasets_ref = [
+ ParticipantDatasetRef(uuid='dataset_1 uuid', project_id=1, name='dataset_1'),
+ ParticipantDatasetRef(uuid='dataset_2 uuid', project_id=1, name='dataset_2')
+ ]
+ expected_response = ListDatasetsResponse(participant_datasets=participant_datasets_ref)
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(request, ListDatasetsRequest())
+ self.assertEqual(call.result(), expected_response)
+ # test has args
+ call = self.client_execution_pool.submit(self._client.list_datasets,
+ kind=DatasetKindV2.RAW,
+ uuid='dataset_1 uuid',
+ state=ResourceState.SUCCEEDED,
+ time_range=TimeRange(days=1))
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVER_DESCRIPTOR.methods_by_name['ListDatasets'])
+ participant_datasets_ref = [ParticipantDatasetRef(uuid='dataset_1 uuid', project_id=1, name='dataset_1')]
+ expected_response = ListDatasetsResponse(participant_datasets=participant_datasets_ref)
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(
+ request,
+ ListDatasetsRequest(kind='RAW', uuid='dataset_1 uuid', state='SUCCEEDED', time_range=TimeRange(days=1)))
+ self.assertEqual(call.result(), expected_response)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_server.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_server.py
new file mode 100644
index 000000000..1f03d1056
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_server.py
@@ -0,0 +1,146 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from grpc import ServicerContext
+from google.protobuf import empty_pb2
+from typing import Iterable
+from datetime import timedelta
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.proto import remove_secrets
+from fedlearner_webconsole.proto.rpc.v2 import resource_service_pb2_grpc
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import GetAlgorithmRequest, GetAlgorithmProjectRequest, \
+ InformDatasetRequest, ListAlgorithmProjectsRequest, ListAlgorithmProjectsResponse, ListAlgorithmsRequest, \
+ ListAlgorithmsResponse, GetAlgorithmFilesRequest, GetAlgorithmFilesResponse, ListDatasetsRequest, \
+ ListDatasetsResponse
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmProjectPb, AlgorithmPb
+from fedlearner_webconsole.algorithm.service import AlgorithmProjectService, AlgorithmService
+from fedlearner_webconsole.algorithm.models import AlgorithmProject, Algorithm, PublishStatus
+from fedlearner_webconsole.algorithm.transmit.sender import AlgorithmSender
+from fedlearner_webconsole.dataset.models import Dataset, DatasetKindV2, ResourceState
+from fedlearner_webconsole.dataset.services import DatasetService
+from fedlearner_webconsole.dataset.auth_service import AuthService
+from fedlearner_webconsole.rpc.v2.utils import get_grpc_context_info, get_pure_domain_from_context
+from fedlearner_webconsole.audit.decorators import emits_rpc_event
+from fedlearner_webconsole.proto.audit_pb2 import Event
+
+
+class ResourceServiceServicer(resource_service_pb2_grpc.ResourceServiceServicer):
+
+ def ListAlgorithmProjects(self, request: ListAlgorithmProjectsRequest,
+ context: ServicerContext) -> ListAlgorithmProjectsResponse:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ algo_projects = AlgorithmProjectService(session).get_published_algorithm_projects(
+ project_id=project_id, filter_exp=request.filter_exp)
+ algorithm_projects = []
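+ # Expose the latest update time among the project's published algorithms as its updated_at.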
+ for algo_project in algo_projects:
+ algo_project.updated_at = AlgorithmProjectService(session).get_published_algorithms_latest_update_time(
+ algo_project.id)
+ algorithm_projects.append(remove_secrets(algo_project.to_proto()))
+ return ListAlgorithmProjectsResponse(algorithm_projects=algorithm_projects)
+
+ def ListAlgorithms(self, request: ListAlgorithmsRequest, context: ServicerContext) -> ListAlgorithmsResponse:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ algorithm_project: AlgorithmProject = session.query(AlgorithmProject). \
+ filter_by(project_id=project_id,
+ uuid=request.algorithm_project_uuid).first()
+ if algorithm_project is None:
+ context.abort(grpc.StatusCode.NOT_FOUND,
+ f'algorithm_project uuid: {request.algorithm_project_uuid} not found')
+ algos = AlgorithmService(session).get_published_algorithms(project_id=project_id,
+ algorithm_project_id=algorithm_project.id)
+ algorithms = []
+ for algo in algos:
+ algorithms.append(remove_secrets(algo.to_proto()))
+ return ListAlgorithmsResponse(algorithms=algorithms)
+
+ def GetAlgorithmProject(self, request: GetAlgorithmProjectRequest, context: ServicerContext) -> AlgorithmProjectPb:
+ with db.session_scope() as session:
+ algo_project: AlgorithmProject = session.query(AlgorithmProject).filter_by(
+ uuid=request.algorithm_project_uuid).first()
+ if algo_project is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'algorithm project uuid:'
+ f' {request.algorithm_project_uuid} not found')
+ if algo_project.publish_status != PublishStatus.PUBLISHED:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, f'algorithm project uuid:'
+ f' {request.algorithm_project_uuid} is not published')
+ return remove_secrets(algo_project.to_proto())
+
+ def GetAlgorithm(self, request: GetAlgorithmRequest, context: ServicerContext) -> AlgorithmPb:
+ with db.session_scope() as session:
+ algorithm: Algorithm = session.query(Algorithm).filter_by(uuid=request.algorithm_uuid).first()
+ if algorithm is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'algorithm uuid: {request.algorithm_uuid} not found')
+ if algorithm.publish_status != PublishStatus.PUBLISHED:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, f'algorithm uuid: {request.algorithm_uuid} '
+ f'is not published')
+ return remove_secrets(algorithm.to_proto())
+
+ def GetAlgorithmFiles(self, request: GetAlgorithmFilesRequest,
+ context: ServicerContext) -> Iterable[GetAlgorithmFilesResponse]:
+ with db.session_scope() as session:
+ algorithm: Algorithm = session.query(Algorithm).filter_by(uuid=request.algorithm_uuid).first()
+ if algorithm is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'algorithm uuid: {request.algorithm_uuid} not found')
+ if algorithm.publish_status != PublishStatus.PUBLISHED:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED, f'algorithm uuid: {request.algorithm_uuid} '
+ f'is not published')
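+ # Stream the packed algorithm files back to the caller chunk by chunk.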
+ yield from AlgorithmSender().make_algorithm_iterator(algorithm.path)
+
+ @emits_rpc_event(resource_type=Event.ResourceType.DATASET,
+ op_type=Event.OperationType.INFORM,
+ resource_name_fn=lambda request: request.uuid)
+ def InformDataset(self, request: InformDatasetRequest, context: ServicerContext) -> empty_pb2.Empty:
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ participant_pure_domain = get_pure_domain_from_context(context)
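+ # Lock the dataset row (SELECT ... FOR UPDATE) so concurrent auth-status updates do not overwrite each other.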
+ dataset: Dataset = session.query(Dataset).populate_existing().with_for_update().filter_by(
+ project_id=project_id, uuid=request.uuid).first()
+ if dataset is None:
+ context.abort(grpc.StatusCode.NOT_FOUND, f'dataset {request.uuid} not found')
+ try:
+ AuthStatus[request.auth_status]
+ except KeyError:
+ context.abort(grpc.StatusCode.INVALID_ARGUMENT, f'auth_status {request.auth_status} is invalid')
+
+ participants_info = dataset.get_participants_info()
+ if participant_pure_domain not in participants_info.participants_map:
+ context.abort(grpc.StatusCode.PERMISSION_DENIED,
+ f'{participant_pure_domain} is not participant of dataset {request.uuid}')
+ AuthService(session=session, dataset_job=dataset.parent_dataset_job).update_auth_status(
+ domain_name=participant_pure_domain, auth_status=AuthStatus[request.auth_status])
+ session.commit()
+ return empty_pb2.Empty()
+
+ def ListDatasets(self, request: ListDatasetsRequest, context: ServicerContext) -> ListDatasetsResponse:
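+ # Empty proto fields mean the client did not set the corresponding filter.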
+ kind = DatasetKindV2[request.kind] if request.kind else None
+ uuid = request.uuid if request.uuid else None
+ state = ResourceState[request.state] if request.state else None
+ time_range = timedelta(days=request.time_range.days, hours=request.time_range.hours)
+ # set time_range to None if time_range is empty
+ if not time_range:
+ time_range = None
+ with db.session_scope() as session:
+ project_id, _ = get_grpc_context_info(session, context)
+ datasets = DatasetService(session=session).get_published_datasets(project_id=project_id,
+ kind=kind,
+ uuid=uuid,
+ state=state,
+ time_range=time_range)
+ return ListDatasetsResponse(participant_datasets=datasets)
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_server_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_server_test.py
new file mode 100644
index 000000000..adfce842b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/resource_service_server_test.py
@@ -0,0 +1,438 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import unittest
+from unittest.mock import ANY, patch, MagicMock
+from concurrent import futures
+from datetime import datetime, timedelta
+from google.protobuf.json_format import ParseDict
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.rpc.v2.resource_service_server import ResourceServiceServicer
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.algorithm.models import AlgorithmProject, Algorithm, AlgorithmType, Source, PublishStatus,\
+ AlgorithmParameter
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobKind, DatasetJobState, DatasetKindV2,\
+ DatasetType, ResourceState
+from fedlearner_webconsole.utils.filtering import parse_expression
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.flask_utils import to_dict
+from fedlearner_webconsole.proto.rpc.v2 import resource_service_pb2_grpc
+from fedlearner_webconsole.proto.dataset_pb2 import TimeRange
+from fedlearner_webconsole.proto.rpc.v2.resource_service_pb2 import InformDatasetRequest, \
+ ListAlgorithmProjectsRequest, ListAlgorithmsRequest, GetAlgorithmRequest, GetAlgorithmFilesRequest, \
+ GetAlgorithmProjectRequest, ListDatasetsRequest
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo, ParticipantsInfo
+
+
+class ResourceServiceTest(NoWebServerTestCase):
+ LISTEN_PORT = 1989
+
+ def setUp(self):
+ super().setUp()
+ self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=20))
+ resource_service_pb2_grpc.add_ResourceServiceServicer_to_server(ResourceServiceServicer(), self._server)
+ self._server.add_insecure_port(f'[::]:{self.LISTEN_PORT}')
+ self._server.start()
+ self._channel = grpc.insecure_channel(target=f'localhost:{self.LISTEN_PORT}')
+ self._stub = resource_service_pb2_grpc.ResourceServiceStub(self._channel)
+
+ with db.session_scope() as session:
+ project1 = Project(id=1, name='project-1')
+ project2 = Project(id=2, name='project-2')
+ algo_project1 = AlgorithmProject(id=1,
+ project_id=1,
+ uuid='algo-project-uuid-1',
+ name='algo-project-1',
+ type=AlgorithmType.NN_LOCAL,
+ source=Source.USER,
+ latest_version=1,
+ publish_status=PublishStatus.PUBLISHED,
+ comment='comment-1',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ algo_project2 = AlgorithmProject(id=2,
+ project_id=1,
+ uuid='algo-project-uuid-2',
+ name='algo-project-2',
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.THIRD_PARTY,
+ latest_version=2,
+ publish_status=PublishStatus.PUBLISHED,
+ comment='comment-2',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ algo_project3 = AlgorithmProject(id=3,
+ project_id=1,
+ uuid='algo-project-uuid-3',
+ name='algo-project-3',
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ latest_version=3,
+ publish_status=PublishStatus.UNPUBLISHED,
+ comment='comment-3',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ algo_project4 = AlgorithmProject(id=4,
+ project_id=2,
+ uuid='algo-project-uuid-4',
+ name='algo-project-4',
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ latest_version=4,
+ publish_status=PublishStatus.UNPUBLISHED,
+ comment='comment-4',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ algo1 = Algorithm(id=1,
+ algorithm_project_id=1,
+ project_id=1,
+ name='algo-1',
+ uuid='algo-uuid-1',
+ version=1,
+ publish_status=PublishStatus.PUBLISHED,
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ comment='comment-1',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 15, 12, 0, 5))
+ algo2 = Algorithm(id=2,
+ algorithm_project_id=1,
+ project_id=1,
+ name='algo-2',
+ uuid='algo-uuid-2',
+ version=2,
+ publish_status=PublishStatus.PUBLISHED,
+ type=AlgorithmType.NN_LOCAL,
+ source=Source.THIRD_PARTY,
+ comment='comment-2',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 16, 12, 0, 5))
+ algo3 = Algorithm(id=3,
+ algorithm_project_id=1,
+ project_id=1,
+ name='algo-3',
+ uuid='algo-uuid-3',
+ version=3,
+ publish_status=PublishStatus.UNPUBLISHED,
+ type=AlgorithmType.TREE_VERTICAL,
+ source=Source.UNSPECIFIED,
+ comment='comment-3',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ algo4 = Algorithm(id=4,
+ algorithm_project_id=2,
+ project_id=1,
+ name='algo-4',
+ uuid='algo-uuid-4',
+ version=4,
+ publish_status=PublishStatus.PUBLISHED,
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ comment='comment-4',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ algo5 = Algorithm(id=5,
+ algorithm_project_id=3,
+ project_id=1,
+ name='algo-5',
+ uuid='algo-uuid-5',
+ version=5,
+ publish_status=PublishStatus.UNPUBLISHED,
+ type=AlgorithmType.NN_VERTICAL,
+ source=Source.USER,
+ comment='comment-5',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ algo6 = Algorithm(id=6,
+ algorithm_project_id=4,
+ project_id=2,
+ name='algo-6',
+ uuid='algo-uuid-6',
+ version=4,
+ publish_status=PublishStatus.PUBLISHED,
+ type=AlgorithmType.NN_LOCAL,
+ source=Source.THIRD_PARTY,
+ comment='comment-6',
+ created_at=datetime(2012, 1, 14, 12, 0, 5),
+ updated_at=datetime(2012, 1, 14, 12, 0, 5))
+ parameter = ParseDict({'variables': [{'name': 'BATCH_SIZE', 'value': '128'}]}, AlgorithmParameter())
+ algo1.set_parameter(parameter)
+ session.add_all([
+ project1, project2, algo_project1, algo_project2, algo_project3, algo_project4, algo1, algo2, algo3,
+ algo4, algo5, algo6
+ ])
+ session.commit()
+
+ def tearDown(self):
+ self._channel.close()
+ self._server.stop(5)
+ return super().tearDown()
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_list_algorithm_projects(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ resp = self._stub.ListAlgorithmProjects(ListAlgorithmProjectsRequest())
+ algorithm_projects = sorted(resp.algorithm_projects, key=lambda x: x.uuid)
+ self.assertEqual(len(algorithm_projects), 2)
+ self.assertEqual(algorithm_projects[0].uuid, 'algo-project-uuid-1')
+ self.assertEqual(algorithm_projects[0].type, 'NN_LOCAL')
+ self.assertEqual(algorithm_projects[0].source, 'USER')
+ self.assertEqual(algorithm_projects[0].updated_at, to_timestamp(datetime(2012, 1, 16, 12, 0, 5)))
+ self.assertEqual(algorithm_projects[1].uuid, 'algo-project-uuid-2')
+ self.assertEqual(algorithm_projects[1].type, 'NN_VERTICAL')
+ self.assertEqual(algorithm_projects[1].source, 'THIRD_PARTY')
+ self.assertEqual(algorithm_projects[1].updated_at, to_timestamp(datetime(2012, 1, 14, 12, 0, 5)))
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_list_algorithm_project_with_filter_exp(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ filter_exp = parse_expression('(name~="1")')
+ resp = self._stub.ListAlgorithmProjects(ListAlgorithmProjectsRequest(filter_exp=filter_exp))
+ self.assertEqual(len(resp.algorithm_projects), 1)
+ algo_project = resp.algorithm_projects[0]
+ self.assertEqual(algo_project.uuid, 'algo-project-uuid-1')
+ self.assertEqual(algo_project.type, 'NN_LOCAL')
+ self.assertEqual(algo_project.source, 'USER')
+ filter_exp = parse_expression('(type:["NN_VERTICAL"])')
+ resp = self._stub.ListAlgorithmProjects(ListAlgorithmProjectsRequest(filter_exp=filter_exp))
+ self.assertEqual(len(resp.algorithm_projects), 1)
+ algo_project = resp.algorithm_projects[0]
+ self.assertEqual(algo_project.uuid, 'algo-project-uuid-2')
+ self.assertEqual(algo_project.source, 'THIRD_PARTY')
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_list_algorithms(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ resp = self._stub.ListAlgorithms(ListAlgorithmsRequest(algorithm_project_uuid='algo-project-uuid-1'))
+ algorithms = resp.algorithms
+ self.assertEqual(len(algorithms), 2)
+ self.assertEqual(algorithms[0].uuid, 'algo-uuid-1')
+ self.assertEqual(algorithms[1].uuid, 'algo-uuid-2')
+ self.assertEqual(algorithms[0].type, 'NN_VERTICAL')
+ self.assertEqual(algorithms[1].source, 'THIRD_PARTY')
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_list_algorithms_with_wrong_algorithm_project_uuid(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.ListAlgorithms(ListAlgorithmsRequest(algorithm_project_uuid='algo-project-uuid-5'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_get_algorithm_project(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ resp = self._stub.GetAlgorithmProject(GetAlgorithmProjectRequest(algorithm_project_uuid='algo-project-uuid-1'))
+ self.assertEqual(resp.type, 'NN_LOCAL')
+ self.assertEqual(resp.source, 'USER')
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_get_algorithm_project_with_wrong_uuid(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with self.assertRaises(grpc.RpcError) as cm:
+ resp = self._stub.GetAlgorithmProject(GetAlgorithmProjectRequest(algorithm_project_uuid='1'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_get_unpublished_algorithm_project(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with self.assertRaises(grpc.RpcError) as cm:
+ resp = self._stub.GetAlgorithmProject(
+ GetAlgorithmProjectRequest(algorithm_project_uuid='algo-project-uuid-3'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+
+ def test_get_algorithm_files_with_wrong_algorithm_uuid(self):
+ with self.assertRaises(grpc.RpcError) as cm:
+ resp = self._stub.GetAlgorithmFiles(GetAlgorithmFilesRequest(algorithm_uuid='algo-uuid-7'))
+ # For a streaming RPC, the gRPC error is only raised once the response iterator is consumed
+ for _ in resp:
+ pass
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+
+ def test_get_algorithm_with_unpublished_algorithm(self):
+ with self.assertRaises(grpc.RpcError) as cm:
+ resp = self._stub.GetAlgorithm(GetAlgorithmRequest(algorithm_uuid='algo-uuid-3'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.PERMISSION_DENIED)
+
+ def test_get_algorithm(self):
+ resp = self._stub.GetAlgorithm(GetAlgorithmRequest(algorithm_uuid='algo-uuid-1'))
+ self.assertEqual(resp.uuid, 'algo-uuid-1')
+ self.assertEqual(resp.name, 'algo-1')
+ self.assertEqual(resp.type, 'NN_VERTICAL')
+ self.assertEqual(resp.source, 'USER')
+ self.assertEqual(resp.version, 1)
+
+ def test_get_algorithm_files(self):
+ data_iterator = self._stub.GetAlgorithmFiles(GetAlgorithmFilesRequest(algorithm_uuid='algo-uuid-1'))
+ resps = list(data_iterator)
+ self.assertEqual(len(resps), 1)
+ self.assertEqual(resps[0].hash, 'd41d8cd98f00b204e9800998ecf8427e')
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_pure_domain_from_context')
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_inform_dataset(self, mock_get_grpc_context_info: MagicMock, mock_get_pure_domain_from_context: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ participant_domain_name = 'test participant'
+ mock_get_pure_domain_from_context.return_value = participant_domain_name
+ # test not_found
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformDataset(InformDatasetRequest(uuid='dataset uuid', auth_status='AUTHORIZED'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.NOT_FOUND)
+
+ # test invalid participant
+ with db.session_scope() as session:
+ dataset = Dataset(id=1,
+ uuid='dataset uuid',
+ name='default dataset',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True,
+ auth_status=AuthStatus.PENDING)
+ session.add(dataset)
+ dataset_job = DatasetJob(id=1,
+ uuid='dataset_job uuid',
+ project_id=1,
+ input_dataset_id=0,
+ output_dataset_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=1)
+ session.add(dataset_job)
+ session.commit()
+ with self.assertRaises(grpc.RpcError) as cm:
+ self._stub.InformDataset(InformDatasetRequest(uuid='dataset uuid', auth_status='AUTHORIZED'))
+ self.assertEqual(cm.exception.code(), grpc.StatusCode.INVALID_ARGUMENT)
+
+ # test pass
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(1)
+ participants_info = ParticipantsInfo(
+ participants_map={participant_domain_name: ParticipantInfo(auth_status='PENDING')})
+ dataset.set_participants_info(participants_info=participants_info)
+ session.commit()
+ self._stub.InformDataset(InformDatasetRequest(uuid='dataset uuid', auth_status='AUTHORIZED'))
+ with db.session_scope() as session:
+ dataset: Dataset = session.query(Dataset).get(1)
+ participants_info = dataset.get_participants_info()
+ self.assertEqual(participants_info.participants_map[participant_domain_name].auth_status, 'AUTHORIZED')
+
+ @patch('fedlearner_webconsole.rpc.v2.resource_service_server.get_grpc_context_info')
+ def test_list_datasets(self, mock_get_grpc_context_info: MagicMock):
+ mock_get_grpc_context_info.return_value = 1, 1
+ with db.session_scope() as session:
+ dataset_job_1 = DatasetJob(id=1,
+ uuid='dataset_job_1',
+ project_id=1,
+ input_dataset_id=0,
+ output_dataset_id=1,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.PENDING,
+ coordinator_id=1)
+ session.add(dataset_job_1)
+ dataset_1 = Dataset(id=1,
+ uuid='dataset_1 uuid',
+ name='default dataset 1',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.RAW,
+ is_published=True)
+ session.add(dataset_1)
+ dataset_job_2 = DatasetJob(id=2,
+ uuid='dataset_job_2',
+ project_id=1,
+ input_dataset_id=0,
+ output_dataset_id=2,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN,
+ state=DatasetJobState.SUCCEEDED,
+ coordinator_id=1,
+ time_range=timedelta(days=1))
+ session.add(dataset_job_2)
+ dataset_2 = Dataset(id=2,
+ uuid='dataset_2 uuid',
+ name='default dataset 2',
+ dataset_type=DatasetType.PSI,
+ comment='test comment',
+ path='/data/dataset/123',
+ project_id=1,
+ dataset_kind=DatasetKindV2.PROCESSED,
+ is_published=True)
+ session.add(dataset_2)
+ session.commit()
+
+ # test no filter
+ expected_response = {
+ 'participant_datasets': [{
+ 'uuid': 'dataset_2 uuid',
+ 'name': 'default dataset 2',
+ 'format': 'TABULAR',
+ 'updated_at': ANY,
+ 'dataset_kind': 'PROCESSED',
+ 'dataset_type': 'PSI',
+ 'auth_status': 'PENDING',
+ 'project_id': ANY,
+ 'participant_id': ANY,
+ 'file_size': 0,
+ 'value': 0
+ }, {
+ 'uuid': 'dataset_1 uuid',
+ 'name': 'default dataset 1',
+ 'format': 'TABULAR',
+ 'updated_at': ANY,
+ 'dataset_kind': 'RAW',
+ 'dataset_type': 'PSI',
+ 'auth_status': 'PENDING',
+ 'project_id': ANY,
+ 'participant_id': ANY,
+ 'file_size': 0,
+ 'value': 0
+ }]
+ }
+ resp = self._stub.ListDatasets(ListDatasetsRequest())
+ self.assertEqual(to_dict(resp), expected_response)
+
+ # test with filter
+ expected_response = {
+ 'participant_datasets': [{
+ 'uuid': 'dataset_2 uuid',
+ 'name': 'default dataset 2',
+ 'format': 'TABULAR',
+ 'updated_at': ANY,
+ 'dataset_kind': 'PROCESSED',
+ 'dataset_type': 'PSI',
+ 'auth_status': 'PENDING',
+ 'project_id': ANY,
+ 'participant_id': ANY,
+ 'file_size': 0,
+ 'value': 0
+ }]
+ }
+ resp = self._stub.ListDatasets(
+ ListDatasetsRequest(uuid='dataset_2 uuid',
+ kind=DatasetKindV2.PROCESSED.name,
+ state=ResourceState.SUCCEEDED.name,
+ time_range=TimeRange(days=1)))
+ self.assertEqual(to_dict(resp), expected_response)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/review_service_client.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/review_service_client.py
new file mode 100644
index 000000000..9dee0dd2e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/review_service_client.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from fedlearner_webconsole.proto.rpc.v2 import review_service_pb2
+from fedlearner_webconsole.proto.rpc.v2.review_service_pb2_grpc import ReviewServiceStub
+from fedlearner_webconsole.proto import review_pb2
+from fedlearner_webconsole.rpc.v2.client_base import ParticipantRpcClient
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+
+
+def _default_need_retry(err: Exception) -> bool:
+ return isinstance(err, grpc.RpcError)
+
+
+class ReviewServiceClient(ParticipantRpcClient):
+
+ def __init__(self, channel: grpc.Channel):
+ super().__init__(channel)
+ self._stub: ReviewServiceStub = ReviewServiceStub(channel)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def create_ticket(self, ttype: review_pb2.TicketType, creator_username: str,
+ details: review_pb2.TicketDetails) -> review_pb2.Ticket:
+ return self._stub.CreateTicket(
+ review_service_pb2.CreateTicketRequest(
+ ttype=ttype,
+ creator_username=creator_username,
+ details=details,
+ ))
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def get_ticket(self, uuid: str) -> review_pb2.Ticket:
+ return self._stub.GetTicket(review_service_pb2.GetTicketRequest(uuid=uuid))
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/review_service_client_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/review_service_client_test.py
new file mode 100644
index 000000000..9604efee1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/review_service_client_test.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import grpc
+import grpc_testing
+
+from google.protobuf.descriptor import ServiceDescriptor
+from fedlearner_webconsole.proto.rpc.v2 import review_service_pb2
+from fedlearner_webconsole.proto import review_pb2
+
+from fedlearner_webconsole.rpc.v2.review_service_client import ReviewServiceClient
+from testing.rpc.client import RpcClientTestCase
+
+_SERVICE_DESCRIPTOR: ServiceDescriptor = review_service_pb2.DESCRIPTOR.services_by_name['ReviewService']
+
+
+class ReviewServiceClientTest(RpcClientTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._fake_channel: grpc_testing.Channel = grpc_testing.channel([_SERVICE_DESCRIPTOR],
+ grpc_testing.strict_real_time())
+ self._client = ReviewServiceClient(self._fake_channel)
+
+ def test_check_health(self):
+ call = self.client_execution_pool.submit(self._client.create_ticket,
+ ttype=review_pb2.TicketType.CREATE_PROJECT,
+ creator_username='fffff',
+ details=review_pb2.TicketDetails(uuid='u1234'))
+
+ _, _, rpc = self._fake_channel.take_unary_unary(_SERVICE_DESCRIPTOR.methods_by_name['CreateTicket'])
+
+ expected_response = review_pb2.Ticket(
+ type=review_pb2.TicketType.CREATE_PROJECT,
+ creator_username='fffff',
+ details=review_pb2.TicketDetails(uuid='u1234'),
+ uuid='u4321',
+ )
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(call.result(), expected_response)
+
+ def test_get_ticket(self):
+ call = self.client_execution_pool.submit(self._client.get_ticket, uuid='u4321')
+
+ _, _, rpc = self._fake_channel.take_unary_unary(_SERVICE_DESCRIPTOR.methods_by_name['GetTicket'])
+
+ expected_response = review_pb2.Ticket(
+ type=review_pb2.TicketType.CREATE_PROJECT,
+ creator_username='fffff',
+ details=review_pb2.TicketDetails(uuid='u1234'),
+ uuid='u4321',
+ )
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(call.result(), expected_response)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_client.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_client.py
new file mode 100644
index 000000000..c2b7a58ce
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_client.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+from fedlearner_webconsole.proto.rpc.v2.system_service_pb2 import (CheckHealthRequest, CheckHealthResponse,
+ ListFlagsRequest, ListFlagsResponse,
+ CheckTeeEnabledRequest, CheckTeeEnabledResponse)
+from fedlearner_webconsole.proto.rpc.v2.system_service_pb2_grpc import SystemServiceStub
+from fedlearner_webconsole.rpc.v2.client_base import ParticipantRpcClient
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+
+
+def _default_need_retry(err: Exception) -> bool:
+ return isinstance(err, grpc.RpcError)
+
+
+class SystemServiceClient(ParticipantRpcClient):
+
+ def __init__(self, channel: grpc.Channel):
+ super().__init__(channel)
+ self._stub: SystemServiceStub = SystemServiceStub(channel)
+
+ def check_health(self) -> CheckHealthResponse:
+ try:
+ return self._stub.CheckHealth(CheckHealthRequest())
+ except grpc.RpcError as e:
+ # For health checks, do not raise the gRPC error; report unhealthy with the error details instead
+ return CheckHealthResponse(
+ healthy=False,
+ message=e.details(),
+ )
+
+ def list_flags(self) -> dict:
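+ # Flags come back as a proto Struct; convert it to a plain Python dict for callers.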
+ response: ListFlagsResponse = self._stub.ListFlags(ListFlagsRequest())
+ return to_dict(response.flags)
+
+ @retry_fn(retry_times=3, need_retry=_default_need_retry)
+ def check_tee_enabled(self) -> CheckTeeEnabledResponse:
+ return self._stub.CheckTeeEnabled(CheckTeeEnabledRequest())
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_client_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_client_test.py
new file mode 100644
index 000000000..b1bf1e36d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_client_test.py
@@ -0,0 +1,118 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import grpc
+import grpc_testing
+
+from google.protobuf.descriptor import ServiceDescriptor
+from google.protobuf.struct_pb2 import Struct, Value
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.rpc.v2 import system_service_pb2
+
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+from testing.rpc.client import RpcClientTestCase
+
+_SERVICE_DESCRIPTOR: ServiceDescriptor = system_service_pb2.DESCRIPTOR.services_by_name['SystemService']
+
+
+class SystemServiceClientTest(RpcClientTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._fake_channel: grpc_testing.Channel = grpc_testing.channel([_SERVICE_DESCRIPTOR],
+ grpc_testing.strict_real_time())
+ self._client = SystemServiceClient(self._fake_channel)
+
+ def test_check_health(self):
+ call = self.client_execution_pool.submit(self._client.check_health)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CheckHealth'])
+
+ expected_response = system_service_pb2.CheckHealthResponse(
+ application_version=common_pb2.ApplicationVersion(
+ revision='test rev',
+ branch_name='test branch',
+ version='1.0.0.1',
+ pub_date='20221212',
+ ),
+ healthy=True,
+ )
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(call.result(), expected_response)
+
+ def test_check_health_rpc_error(self):
+ call = self.client_execution_pool.submit(self._client.check_health)
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CheckHealth'])
+ rpc.terminate(
+ response=None,
+ code=grpc.StatusCode.UNKNOWN,
+ trailing_metadata=(),
+ details='unknown server error',
+ )
+ self.assertEqual(call.result(),
+ system_service_pb2.CheckHealthResponse(
+ healthy=False,
+ message='unknown server error',
+ ))
+
+ def test_list_flags(self):
+ call = self.client_execution_pool.submit(self._client.list_flags)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['ListFlags'])
+
+ expected_response = system_service_pb2.ListFlagsResponse(flags=Struct(
+ fields={
+ 'flag1': Value(bool_value=True),
+ 'flag2': Value(string_value='string_value'),
+ }))
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(call.result(), {
+ 'flag1': True,
+ 'flag2': 'string_value',
+ })
+
+ def test_check_tee_enabled(self):
+ call = self.client_execution_pool.submit(self._client.check_tee_enabled)
+
+ invocation_metadata, request, rpc = self._fake_channel.take_unary_unary(
+ _SERVICE_DESCRIPTOR.methods_by_name['CheckTeeEnabled'])
+
+ expected_response = system_service_pb2.CheckTeeEnabledResponse(tee_enabled=True)
+ rpc.terminate(
+ response=expected_response,
+ code=grpc.StatusCode.OK,
+ trailing_metadata=(),
+ details=None,
+ )
+ self.assertEqual(call.result(), expected_response)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_server.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_server.py
new file mode 100644
index 000000000..387cf06d0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_server.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from grpc import ServicerContext
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.flag.models import get_flags
+from fedlearner_webconsole.proto.rpc.v2 import system_service_pb2_grpc
+from fedlearner_webconsole.proto.rpc.v2.system_service_pb2 import (CheckHealthResponse, CheckHealthRequest,
+ ListFlagsRequest, ListFlagsResponse,
+ CheckTeeEnabledRequest, CheckTeeEnabledResponse)
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.tee.services import check_tee_enabled
+
+
+ # TODO(linfan.fine): add a request id decorator
+class SystemGrpcService(system_service_pb2_grpc.SystemServiceServicer):
+
+ def CheckHealth(self, request: CheckHealthRequest, context: ServicerContext):
+ with db.session_scope() as session:
+ version = SettingService(session).get_application_version()
+ return CheckHealthResponse(application_version=version.to_proto(), healthy=True)
+
+ def ListFlags(self, request: ListFlagsRequest, context: ServicerContext):
+ resp = ListFlagsResponse()
+ resp.flags.update(get_flags())
+ return resp
+
+ def CheckTeeEnabled(self, request: CheckTeeEnabledRequest, context: ServicerContext):
+ return CheckTeeEnabledResponse(tee_enabled=check_tee_enabled())
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_server_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_server_test.py
new file mode 100644
index 000000000..83b85fedf
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/system_service_server_test.py
@@ -0,0 +1,105 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from concurrent import futures
+from unittest.mock import patch, Mock
+
+import grpc
+
+from google.protobuf.struct_pb2 import Struct, Value
+
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.rpc.v2 import system_service_pb2_grpc
+from fedlearner_webconsole.proto.rpc.v2.system_service_pb2 import CheckHealthRequest, CheckHealthResponse, \
+ ListFlagsRequest, ListFlagsResponse, CheckTeeEnabledRequest, CheckTeeEnabledResponse
+from fedlearner_webconsole.rpc.v2.system_service_server import SystemGrpcService
+from fedlearner_webconsole.utils.app_version import ApplicationVersion
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class SystemServiceTest(NoWebServerTestCase):
+ LISTEN_PORT = 1999
+
+ def setUp(self):
+ super().setUp()
+ self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=20))
+ system_service_pb2_grpc.add_SystemServiceServicer_to_server(SystemGrpcService(), self._server)
+ self._server.add_insecure_port(f'[::]:{self.LISTEN_PORT}')
+ self._server.start()
+
+ self._stub = system_service_pb2_grpc.SystemServiceStub(
+ grpc.insecure_channel(target=f'localhost:{self.LISTEN_PORT}'))
+
+ def tearDown(self):
+ self._server.stop(5)
+ return super().tearDown()
+
+ @patch('fedlearner_webconsole.rpc.v2.system_service_server.SettingService.get_application_version')
+ def test_check_health(self, mock_get_application_version: Mock):
+ mock_get_application_version.return_value = ApplicationVersion(
+ revision='test rev',
+ branch_name='test branch',
+ version='1.0.0.1',
+ pub_date='20220101',
+ )
+
+ resp = self._stub.CheckHealth(CheckHealthRequest())
+ self.assertEqual(
+ resp,
+ CheckHealthResponse(
+ application_version=common_pb2.ApplicationVersion(
+ revision='test rev',
+ branch_name='test branch',
+ version='1.0.0.1',
+ pub_date='20220101',
+ ),
+ healthy=True,
+ ))
+
+ @patch('fedlearner_webconsole.rpc.v2.system_service_server.get_flags')
+ def test_list_flags(self, mock_flags: Mock):
+ mock_flags.return_value = {
+ 'flag1': True,
+ 'flag2': 'string_value',
+ 'flag3': {
+ 'key': 'value',
+ },
+ }
+
+ resp = self._stub.ListFlags(ListFlagsRequest())
+ self.assertEqual(
+ resp,
+ ListFlagsResponse(flags=Struct(
+ fields={
+ 'flag1': Value(bool_value=True),
+ 'flag2': Value(string_value='string_value'),
+ 'flag3': Value(struct_value=Struct(fields={
+ 'key': Value(string_value='value'),
+ }))
+ })))
+
+ def test_check_tee_enabled(self):
+ Flag.TEE_MACHINE_DEPLOYED.value = True
+ resp = self._stub.CheckTeeEnabled(CheckTeeEnabledRequest())
+ self.assertEqual(resp, CheckTeeEnabledResponse(tee_enabled=True))
+ Flag.TEE_MACHINE_DEPLOYED.value = False
+ resp = self._stub.CheckTeeEnabled(CheckTeeEnabledRequest())
+ self.assertEqual(resp, CheckTeeEnabledResponse(tee_enabled=False))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/utils.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/utils.py
new file mode 100644
index 000000000..46aa498e9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/utils.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from grpc import ServicerContext
+from typing import Tuple
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.rpc.auth import get_common_name, SSL_CLIENT_SUBJECT_DN_HEADER, PROJECT_NAME_HEADER
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+from fedlearner_webconsole.utils.pp_base64 import base64decode, base64encode
+
+
+def get_grpc_context_info(session: Session, context: ServicerContext) -> Tuple[int, int]:
+ metadata = dict(context.invocation_metadata())
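+ # The project name and client identity are carried in gRPC metadata headers; the project name may be base64-encoded (see decode_project_name).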
+ project_name = decode_project_name(metadata.get(PROJECT_NAME_HEADER))
+ project_id, *_ = session.query(Project.id).filter_by(name=project_name).first()
+ cn = get_common_name(metadata.get(SSL_CLIENT_SUBJECT_DN_HEADER))
+ client_id = ParticipantService(session).get_participant_by_pure_domain_name(get_pure_domain_name(cn)).id
+ return project_id, client_id
+
+
+def get_pure_domain_from_context(context: ServicerContext) -> str:
+ metadata = dict(context.invocation_metadata())
+ cn = get_common_name(metadata.get(SSL_CLIENT_SUBJECT_DN_HEADER))
+ return get_pure_domain_name(cn)
+
+
+def _is_ascii(s: str) -> bool:
+ return all(ord(c) < 128 for c in s)
+
+
+def encode_project_name(project_name: str) -> str:
+ """Encodes project name to grpc-acceptable format.
+
+ gRPC does not recognize unicode in headers, and for historical reasons
+ we have to stay compatible with plain ASCII strings, otherwise old gRPC
+ servers cannot read the project name correctly."""
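+ # e.g. 'hello world' is sent as-is, while '中文 & en' becomes '5Lit5paHICYgZW4=' (see utils_test.py).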
+ if _is_ascii(project_name):
+ return project_name
+ return base64encode(project_name)
+
+
+def decode_project_name(encoded: str) -> str:
+ try:
+ return base64decode(encoded)
+ except Exception: # pylint: disable=broad-except
+ # Not a base64 encoded string
+ pass
+ # Encoded as raw, see details in `encode_project_name`
+ return encoded
diff --git a/web_console_v2/api/fedlearner_webconsole/rpc/v2/utils_test.py b/web_console_v2/api/fedlearner_webconsole/rpc/v2/utils_test.py
new file mode 100644
index 000000000..02c68c426
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/rpc/v2/utils_test.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock
+
+from fedlearner_webconsole.rpc.auth import SSL_CLIENT_SUBJECT_DN_HEADER
+from fedlearner_webconsole.rpc.v2.utils import decode_project_name, encode_project_name, get_pure_domain_from_context
+
+
+class UtilsTest(unittest.TestCase):
+
+ def test_get_pure_domain_from_context(self):
+ mock_context = MagicMock(invocation_metadata=MagicMock(
+ return_value={
+ SSL_CLIENT_SUBJECT_DN_HEADER: 'CN=*.fl-xxx.com,OU=security,O=security,L=beijing,ST=beijing,C=CN'
+ }))
+ self.assertEqual(get_pure_domain_from_context(mock_context), 'xxx')
+
+ def test_encode_project_name_anscii(self):
+ self.assertEqual(encode_project_name('hello world'), 'hello world')
+ self.assertEqual(encode_project_name('-h%20w'), '-h%20w')
+
+ def test_encode_project_name_unicode(self):
+ self.assertEqual(encode_project_name('这是一个测试的名字'), '6L+Z5piv5LiA5Liq5rWL6K+V55qE5ZCN5a2X')
+ self.assertEqual(encode_project_name('中文 & en'), '5Lit5paHICYgZW4=')
+
+ def test_decode_project_name_anscii(self):
+ self.assertEqual(decode_project_name('hello world'), 'hello world')
+ self.assertEqual(decode_project_name('-h%20w'), '-h%20w')
+
+ def test_decode_project_name_unicode(self):
+ self.assertEqual(decode_project_name('6L+Z5piv5LiA5Liq5rWL6K+V55qE5ZCN5a2X'), '这是一个测试的名字')
+ self.assertEqual(decode_project_name('5Lit5paHICYgZW4='), '中文 & en')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/scheduler/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/scheduler/BUILD.bazel
new file mode 100644
index 000000000..9b2ca40ca
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/scheduler/BUILD.bazel
@@ -0,0 +1,59 @@
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "scheduler_lib",
+ srcs = [
+ "scheduler.py",
+ "transaction.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "scheduler_lib_test",
+ size = "medium",
+ srcs = [
+ "scheduler_test.py",
+ ],
+ flaky = True,
+ imports = ["../.."],
+ main = "scheduler_test.py",
+ deps = [
+ ":scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:server_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "workflow_commit_lib_test",
+ size = "small",
+ srcs = [
+ "workflow_commit_test.py",
+ ],
+ imports = ["../.."],
+ main = "workflow_commit_test.py",
+ deps = [
+ ":scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:yaml_formatter_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:service_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing/workflow_template",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/scheduler/scheduler.py b/web_console_v2/api/fedlearner_webconsole/scheduler/scheduler.py
index d3e15aa1d..d7b649585 100644
--- a/web_console_v2/api/fedlearner_webconsole/scheduler/scheduler.py
+++ b/web_console_v2/api/fedlearner_webconsole/scheduler/scheduler.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,52 +14,37 @@
# coding: utf-8
# pylint: disable=broad-except
-
-import os
import threading
import logging
import traceback
-from fedlearner_webconsole.job.yaml_formatter import generate_job_run_yaml
+from queue import Queue, Empty
+from envs import Envs
from fedlearner_webconsole.db import db
-from fedlearner_webconsole.dataset.import_handler import ImportHandler
-from fedlearner_webconsole.utils.k8s_client import k8s_client
from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
-from fedlearner_webconsole.job.models import Job, JobState
from fedlearner_webconsole.scheduler.transaction import TransactionManager
-from fedlearner_webconsole.db import get_session
-from fedlearner_webconsole.job.service import JobService
class Scheduler(object):
+
def __init__(self):
self._condition = threading.Condition(threading.RLock())
self._running = False
self._terminate = False
self._thread = None
- self._pending_workflows = []
- self._pending_jobs = []
- #TODO: remove app
- self._app = None
- self._db_engine = None
- self._import_handler = ImportHandler()
-
- def start(self, app, force=False):
+ self.workflow_queue = Queue()
+
+ def start(self, force=False):
if self._running:
if not force:
raise RuntimeError('Scheduler is already started')
self.stop()
- self._app = app
- with self._app.app_context():
- self._db_engine = db.get_engine()
-
with self._condition:
self._running = True
self._terminate = False
self._thread = threading.Thread(target=self._routine)
self._thread.daemon = True
self._thread.start()
- self._import_handler.init(app)
logging.info('Scheduler started')
def stop(self):
@@ -68,62 +53,37 @@ def stop(self):
with self._condition:
self._terminate = True
- self._condition.notify_all()
+            # Interrupt the blocking workflow_queue.get() call so stop() returns immediately.
+ self.workflow_queue.put(None)
print('stopping')
self._thread.join()
self._running = False
logging.info('Scheduler stopped')
- def wakeup(self, workflow_ids=None,
- job_ids=None,
- data_batch_ids=None):
- with self._condition:
- if workflow_ids:
- if isinstance(workflow_ids, int):
- workflow_ids = [workflow_ids]
- self._pending_workflows.extend(workflow_ids)
- if job_ids:
- if isinstance(job_ids, int):
- job_ids = [job_ids]
- self._pending_jobs.extend(job_ids)
- if data_batch_ids:
- self._import_handler.schedule_to_handle(data_batch_ids)
- self._condition.notify_all()
+ def wakeup(self, workflow_id=None):
+ self.workflow_queue.put(workflow_id)
def _routine(self):
- self._app.app_context().push()
- interval = int(os.environ.get(
- 'FEDLEARNER_WEBCONSOLE_POLLING_INTERVAL', 60))
+ interval = float(Envs.SCHEDULER_POLLING_INTERVAL)
while True:
- with self._condition:
- notified = self._condition.wait(interval)
-
- # TODO(wangsen): use Sqlalchemy insdtead of flask-Sqlalchemy
- # refresh a new session to catch the update of db
- db.session.remove()
- if self._terminate:
- return
- if notified:
- workflow_ids = self._pending_workflows
- self._pending_workflows = []
- self._poll_workflows(workflow_ids)
-
- job_ids = self._pending_jobs
- self._pending_jobs = []
- job_ids.extend(_get_waiting_jobs())
- self._poll_jobs(job_ids)
-
- self._import_handler.handle(pull=False)
- continue
-
- workflows = db.session.query(Workflow.id).filter(
- Workflow.target_state != WorkflowState.INVALID).all()
+ try:
+ try:
+ pending_workflow = self.workflow_queue.get(timeout=interval)
+ except Empty:
+ pending_workflow = None
+ with self._condition:
+ if self._terminate:
+ return
+ if pending_workflow:
+ self._poll_workflows([pending_workflow])
+
+ with db.session_scope() as session:
+ workflows = session.query(Workflow.id).filter(Workflow.target_state != WorkflowState.INVALID).all()
self._poll_workflows([wid for wid, in workflows])
-
- self._poll_jobs(_get_waiting_jobs())
-
- self._import_handler.handle(pull=True)
+            # Catch everything so the scheduler routine keeps running forever.
+ except Exception as e:
+                logging.error(f'Scheduler routine error: {str(e)}')
def _poll_workflows(self, workflow_ids):
logging.info(f'Scheduler polling {len(workflow_ids)} workflows...')
@@ -131,58 +91,13 @@ def _poll_workflows(self, workflow_ids):
try:
self._schedule_workflow(workflow_id)
except Exception as e:
- logging.warning(
- 'Error while scheduling workflow '
- f'{workflow_id}:\n{traceback.format_exc()}')
-
- def _poll_jobs(self, job_ids):
- logging.info(f'Scheduler polling {len(job_ids)} jobs...')
- for job_id in job_ids:
- try:
- self._schedule_job(job_id)
- except Exception as e:
- logging.warning(
- 'Error while scheduling job '
- f'{job_id}:\n{traceback.format_exc()}')
+                logging.warning(f'Error while scheduling workflow {workflow_id}:\n{traceback.format_exc()}')
def _schedule_workflow(self, workflow_id):
logging.debug(f'Scheduling workflow {workflow_id}')
- tm = TransactionManager(workflow_id)
- return tm.process()
-
- def _schedule_job(self, job_id):
- job = Job.query.get(job_id)
- assert job is not None, f'Job {job_id} not found'
- if job.state != JobState.WAITING:
- return job.state
-
- with get_session(self._db_engine) as session:
- job_service = JobService(session)
- if not job_service.is_ready(job):
- return job.state
- config = job.get_config()
- if config.is_federated:
- if not job_service.is_peer_ready(job):
- return job.state
-
- try:
- yaml = generate_job_run_yaml(job)
- k8s_client.create_flapp(yaml)
- except Exception as e:
- logging.error(f'Start job {job_id} has error msg: {e.args}')
- job.error_message = str(e)
- db.session.commit()
- return job.state
- job.error_message = None
- job.start()
- db.session.commit()
-
- return job.state
-
-
-def _get_waiting_jobs():
- return [jid for jid, in db.session.query(
- Job.id).filter(Job.state == JobState.WAITING)]
+ with db.session_scope() as session:
+ tm = TransactionManager(workflow_id, session)
+ return tm.process()
scheduler = Scheduler()
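
The reworked scheduler above replaces the condition-variable wait with a single workflow queue: `wakeup()` enqueues a workflow id, `stop()` enqueues a `None` sentinel to unblock the pending `get()`, and a timed-out `get()` falls through to a full poll of all non-invalid workflows. A minimal, self-contained sketch of this interruptible-polling pattern (the names here are illustrative, not the project's API):

```python
import threading
from queue import Queue, Empty


class PollingLoop:
    """Sketch of a queue-driven, interruptible polling loop."""

    def __init__(self, poll_interval: float, handler):
        self._queue = Queue()
        self._interval = poll_interval
        self._handler = handler
        self._stopped = False
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        self._thread.start()

    def wakeup(self, item=None):
        # Handle one item immediately instead of waiting for the next tick.
        self._queue.put(item)

    def stop(self):
        self._stopped = True
        # The sentinel unblocks the pending queue.get() right away.
        self._queue.put(None)
        self._thread.join()

    def _run(self):
        while True:
            try:
                item = self._queue.get(timeout=self._interval)
            except Empty:
                item = None  # timed out: fall through to the periodic work
            if self._stopped:
                return
            self._handler(item)
```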
diff --git a/web_console_v2/api/fedlearner_webconsole/scheduler/scheduler_test.py b/web_console_v2/api/fedlearner_webconsole/scheduler/scheduler_test.py
new file mode 100644
index 000000000..356b5d011
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/scheduler/scheduler_test.py
@@ -0,0 +1,244 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+from unittest.mock import patch
+
+import time
+import copy
+import unittest
+import secrets
+import logging
+from http import HTTPStatus
+
+from uuid import uuid4
+from envs import Envs
+from testing.common import BaseTestCase
+from testing.common import multi_process_test
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.common_pb2 import CreateJobFlag
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.rpc.server import rpc_server
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+
+
+class LeaderConfig(object):
+ SQLALCHEMY_DATABASE_URI = f'sqlite:///{Envs.BASE_DIR}/{uuid4()}-leader.db'
+ SQLALCHEMY_TRACK_MODIFICATIONS = False
+ JWT_SECRET_KEY = secrets.token_urlsafe(64)
+ PROPAGATE_EXCEPTIONS = True
+ LOGGING_LEVEL = logging.DEBUG
+ GRPC_LISTEN_PORT = 3990
+
+
+class FollowerConfig(object):
+ SQLALCHEMY_DATABASE_URI = f'sqlite:///{Envs.BASE_DIR}/{uuid4()}-follower.db'
+ SQLALCHEMY_TRACK_MODIFICATIONS = False
+ JWT_SECRET_KEY = secrets.token_urlsafe(64)
+ PROPAGATE_EXCEPTIONS = True
+ LOGGING_LEVEL = logging.DEBUG
+ GRPC_LISTEN_PORT = 4990
+
+
+class WorkflowTest(BaseTestCase):
+
+ class Config(LeaderConfig):
+ pass
+
+ @classmethod
+ def setUpClass(cls):
+ cls._patcher = patch('envs.Envs.SCHEDULER_POLLING_INTERVAL', '0.1')
+ cls._patcher.start()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._patcher.stop()
+
+ def setUp(self):
+ super().setUp()
+ self.signin_as_admin()
+ template1 = WorkflowTemplate(name='t1', comment='comment for t1', group_alias='g1')
+ template1.set_config(WorkflowDefinition(group_alias='g1',))
+ with db.session_scope() as session:
+ session.add(template1)
+ session.commit()
+        # This is actually an integration test, so we need to start the RPC server.
+ rpc_server.stop()
+ rpc_server.start(self.Config.GRPC_LISTEN_PORT)
+ self._wf_template = {
+ 'group_alias':
+ 'test-template',
+ 'job_definitions': [{
+ 'is_federated': True,
+ 'name': 'job1',
+ 'variables': [{
+ 'name': 'x',
+ 'value': '1',
+ 'access_mode': 3
+ }]
+ }, {
+ 'is_federated': True,
+ 'name': 'job2',
+ 'variables': [{
+ 'name': 'y',
+ 'value': '2',
+ 'access_mode': 2
+ }]
+ }]
+ }
+
+ def leader_test_workflow(self):
+ self.setup_project('leader', FollowerConfig.GRPC_LISTEN_PORT)
+ cwf_resp = self.post_helper('/api/v2/projects/1/workflows',
+ data={
+ 'name': 'test-workflow',
+ 'project_id': 1,
+ 'forkable': True,
+ 'config': self._wf_template,
+ 'template_id': 1
+ })
+ self.assertEqual(cwf_resp.status_code, HTTPStatus.CREATED)
+ cwf_data = self.get_response_data(cwf_resp)
+ self.assertEqual(cwf_data['job_ids'], [1, 2])
+
+ self._check_workflow_state(1, 'READY_TO_RUN')
+
+ # test update
+ patch_config = copy.deepcopy(self._wf_template)
+ patch_config['job_definitions'][1]['variables'][0]['value'] = '4'
+ resp = self.patch_helper('/api/v2/projects/1/workflows/1', data={'config': patch_config, 'template_id': 1})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+
+ resp = self.get_helper('/api/v2/projects/1/workflows/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ ret_wf = resp.json['data']['config']
+ self.assertEqual(ret_wf['job_definitions'][1]['variables'][0]['value'], '4')
+
+ # test update remote
+ patch_config['job_definitions'][0]['variables'][0]['value'] = '5'
+ resp = self.patch_helper('/api/v2/projects/1/workflows/1/peer_workflows',
+ data={
+ 'config': patch_config,
+ 'template_id': 1
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+
+ resp = self.get_helper('/api/v2/projects/1/workflows/1/peer_workflows')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ ret_wf = list(resp.json['data'].values())[0]['config']
+ self.assertEqual(ret_wf['job_definitions'][0]['variables'][0]['value'], '5')
+
+ # test fork
+ cwf_resp = self.post_helper('/api/v2/projects/1/workflows',
+ data={
+ 'name': 'test-workflow2',
+ 'project_id': 1,
+ 'forkable': True,
+ 'forked_from': 1,
+ 'create_job_flags': [
+ CreateJobFlag.REUSE,
+ CreateJobFlag.NEW,
+ ],
+ 'peer_create_job_flags': [
+ CreateJobFlag.REUSE,
+ CreateJobFlag.NEW,
+ ],
+ 'template_id': 1,
+ 'config': self._wf_template,
+ 'fork_proposal_config': {
+ 'job_definitions': [{
+ 'variables': [{
+ 'name': 'x',
+ 'value': '2'
+ }]
+ }, {
+ 'variables': [{
+ 'name': 'y',
+ 'value': '3'
+ }]
+ }]
+ }
+ })
+ self.assertEqual(cwf_resp.status_code, HTTPStatus.CREATED)
+ cwf_data = self.get_response_data(cwf_resp)
+ self.assertEqual(cwf_data['job_ids'], [1, 3])
+ self._check_workflow_state(2, 'READY_TO_RUN')
+ time.sleep(2)
+ resp = self.post_helper('/api/v2/projects/1/workflows/2:invalidate')
+ self._check_workflow_state(2, 'INVALID')
+
+ def follower_test_workflow(self):
+ self.setup_project('follower', LeaderConfig.GRPC_LISTEN_PORT)
+ self._check_workflow_state(1, 'PENDING_ACCEPT')
+
+ cwf_resp = self.put_helper('/api/v2/projects/1/workflows/1',
+ data={
+ 'forkable': True,
+ 'config': self._wf_template,
+ 'template_id': 1
+ })
+ self.assertEqual(cwf_resp.status_code, HTTPStatus.OK)
+ cwf_data = self.get_response_data(cwf_resp)
+ self.assertEqual(cwf_data['job_ids'], [1, 2])
+ self._check_workflow_state(1, 'READY_TO_RUN')
+ with db.session_scope() as session:
+ self.assertEqual(len(session.query(Job).filter_by(workflow_id=1).all()), 2)
+
+ # test fork
+ json = self._check_workflow_state(2, 'READY_TO_RUN')
+ with db.session_scope() as session:
+ self.assertEqual(len(session.query(Job).all()), 3)
+ self.assertEqual(json['data']['create_job_flags'], [
+ CreateJobFlag.REUSE,
+ CreateJobFlag.NEW,
+ ])
+ self.assertEqual(json['data']['peer_create_job_flags'], [
+ CreateJobFlag.REUSE,
+ CreateJobFlag.NEW,
+ ])
+ jobs = json['data']['config']['job_definitions']
+ self.assertEqual(jobs[0]['variables'][0]['value'], '2')
+ self.assertEqual(jobs[1]['variables'][0]['value'], '2')
+ time.sleep(2)
+ resp = self.post_helper('/api/v2/projects/1/workflows/2:invalidate')
+ self._check_workflow_state(2, 'INVALID')
+
+ def _check_workflow_state(self, workflow_id, state, max_retries=10):
+ cnt = 0
+ while True:
+ time.sleep(0.1)
+ cnt = cnt + 1
+ if cnt > max_retries:
+ self.fail(f'workflow [{workflow_id}] state is unexpected')
+ resp = self.get_helper(f'/api/v2/projects/1/workflows/{workflow_id}')
+ if resp.status_code != HTTPStatus.OK:
+ logging.info(f'get workflow {workflow_id} failed: {resp.json}')
+ continue
+ if resp.json['data']['state'] == state:
+ return resp.json
+
+
+if __name__ == '__main__':
+ multi_process_test([{
+ 'class': WorkflowTest,
+ 'method': 'leader_test_workflow',
+ 'config': LeaderConfig
+ }, {
+ 'class': WorkflowTest,
+ 'method': 'follower_test_workflow',
+ 'config': FollowerConfig
+ }])
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/scheduler/transaction.py b/web_console_v2/api/fedlearner_webconsole/scheduler/transaction.py
index aa605e157..706e1a75a 100644
--- a/web_console_v2/api/fedlearner_webconsole/scheduler/transaction.py
+++ b/web_console_v2/api/fedlearner_webconsole/scheduler/transaction.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +13,21 @@
# limitations under the License.
# coding: utf-8
-
-from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.services import ParticipantService
from fedlearner_webconsole.rpc.client import RpcClient
-from fedlearner_webconsole.workflow.models import (
- Workflow, WorkflowState, TransactionState, VALID_TRANSITIONS
-)
+from fedlearner_webconsole.workflow.models import (Workflow, WorkflowState, TransactionState, VALID_TRANSITIONS)
from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.workflow.resource_manager import ResourceManager
+from fedlearner_webconsole.workflow.workflow_controller import invalidate_workflow_locally
+
class TransactionManager(object):
- def __init__(self, workflow_id):
+
+ def __init__(self, workflow_id, session):
self._workflow_id = workflow_id
- self._workflow = Workflow.query.get(workflow_id)
+ self._session = session
+ # TODO(hangweiqiang): remove workflow, project from __init__
+ self._workflow = session.query(Workflow).get(workflow_id)
assert self._workflow is not None
self._project = self._workflow.project
assert self._project is not None
@@ -39,15 +42,14 @@ def project(self):
def process(self):
# process local workflow
+ manager = ResourceManager(self._session, self._workflow)
if self._workflow.is_local():
- self._workflow.update_local_state()
+ manager.update_local_state()
self._reload()
return self._workflow
# reload workflow and resolve -ing states
- self._workflow.update_state(
- self._workflow.state, self._workflow.target_state,
- self._workflow.transaction_state)
+ manager.update_state(self._workflow.state, self._workflow.target_state, self._workflow.transaction_state)
self._reload()
if not self._recover_from_abort():
@@ -56,77 +58,67 @@ def process(self):
if self._workflow.target_state == WorkflowState.INVALID:
return self._workflow
- if self._workflow.state == WorkflowState.INVALID:
- raise RuntimeError(
- f'Cannot process invalid workflow {self._workflow.name}')
+ if self._workflow.is_invalid():
+ raise RuntimeError(f'Cannot process invalid workflow {self._workflow.name}')
assert (self._workflow.state, self._workflow.target_state) \
- in VALID_TRANSITIONS
+ in VALID_TRANSITIONS
if self._workflow.transaction_state == TransactionState.READY:
# prepare self as coordinator
- self._workflow.update_state(
- self._workflow.state,
- self._workflow.target_state,
- TransactionState.COORDINATOR_PREPARE)
+ manager.update_state(self._workflow.state, self._workflow.target_state,
+ TransactionState.COORDINATOR_PREPARE)
self._reload()
if self._workflow.transaction_state == \
TransactionState.COORDINATOR_COMMITTABLE:
# prepare self succeeded. Tell participants to prepare
- states = self._broadcast_state(
- self._workflow.state, self._workflow.target_state,
- TransactionState.PARTICIPANT_PREPARE)
+ states = self._broadcast_state(self._workflow.state, self._workflow.target_state,
+ TransactionState.PARTICIPANT_PREPARE)
committable = True
for state in states:
if state != TransactionState.PARTICIPANT_COMMITTABLE:
committable = False
if state == TransactionState.ABORTED:
# abort as coordinator if some participants aborted
- self._workflow.update_state(
- None, None, TransactionState.COORDINATOR_ABORTING)
+ manager.update_state(None, None, TransactionState.COORDINATOR_ABORTING)
self._reload()
break
# commit as coordinator if participants all committable
if committable:
- self._workflow.update_state(
- None, None, TransactionState.COORDINATOR_COMMITTING)
+ manager.update_state(None, None, TransactionState.COORDINATOR_COMMITTING)
self._reload()
if self._workflow.transaction_state == \
TransactionState.COORDINATOR_COMMITTING:
# committing as coordinator. tell participants to commit
- if self._broadcast_state_and_check(
- self._workflow.state, self._workflow.target_state,
- TransactionState.PARTICIPANT_COMMITTING,
- TransactionState.READY):
+ if self._broadcast_state_and_check(self._workflow.state, self._workflow.target_state,
+ TransactionState.PARTICIPANT_COMMITTING, TransactionState.READY):
# all participants committed. finish.
- self._workflow.commit()
+ manager.commit()
self._reload()
self._recover_from_abort()
return self._workflow
def _reload(self):
- db.session.commit()
- db.session.refresh(self._workflow)
+ self._session.commit()
+ self._session.refresh(self._workflow)
- def _broadcast_state(
- self, state, target_state, transaction_state):
- project_config = self._project.get_config()
+ def _broadcast_state(self, state, target_state, transaction_state):
+ service = ParticipantService(self._session)
+ participants = service.get_platform_participants_by_project(self._project.id)
states = []
- for party in project_config.participants:
- client = RpcClient(project_config, party)
- forked_from_uuid = Workflow.query.filter_by(
- id=self._workflow.forked_from
- ).first().uuid if self._workflow.forked_from else None
- resp = client.update_workflow_state(
- self._workflow.name, state, target_state, transaction_state,
- self._workflow.uuid,
- forked_from_uuid, self._workflow.extra)
+ for participant in participants:
+ client = RpcClient.from_project_and_participant(self._project.name, self._project.token,
+ participant.domain_name)
+ forked_from_uuid = self._session.query(Workflow).filter_by(
+ id=self._workflow.forked_from).first().uuid if self._workflow.forked_from else None
+ resp = client.update_workflow_state(self._workflow.name, state, target_state, transaction_state,
+ self._workflow.uuid, forked_from_uuid, self._workflow.extra)
if resp.status.code == common_pb2.STATUS_SUCCESS:
- if resp.state == WorkflowState.INVALID:
- self._workflow.invalidate()
+ if WorkflowState(resp.state) == WorkflowState.INVALID:
+ invalidate_workflow_locally(self._session, self._workflow)
self._reload()
raise RuntimeError('Peer workflow invalidated. Abort.')
states.append(TransactionState(resp.transaction_state))
@@ -134,8 +126,7 @@ def _broadcast_state(
states.append(None)
return states
- def _broadcast_state_and_check(self,
- state, target_state, transaction_state, target_transaction_state):
+ def _broadcast_state_and_check(self, state, target_state, transaction_state, target_transaction_state):
states = self._broadcast_state(state, target_state, transaction_state)
for i in states:
if i != target_transaction_state:
@@ -145,13 +136,10 @@ def _broadcast_state_and_check(self,
def _recover_from_abort(self):
if self._workflow.transaction_state == \
TransactionState.COORDINATOR_ABORTING:
- if not self._broadcast_state_and_check(
- self._workflow.state, WorkflowState.INVALID,
- TransactionState.PARTICIPANT_ABORTING,
- TransactionState.ABORTED):
+ if not self._broadcast_state_and_check(self._workflow.state, WorkflowState.INVALID,
+ TransactionState.PARTICIPANT_ABORTING, TransactionState.ABORTED):
return False
- self._workflow.update_state(
- None, WorkflowState.INVALID, TransactionState.ABORTED)
+ self._workflow.update_state(None, WorkflowState.INVALID, TransactionState.ABORTED, self._session)
self._reload()
if self._workflow.transaction_state != TransactionState.ABORTED:
@@ -159,10 +147,9 @@ def _recover_from_abort(self):
assert self._workflow.target_state == WorkflowState.INVALID
- if not self._broadcast_state_and_check(
- self._workflow.state, WorkflowState.INVALID,
- TransactionState.READY, TransactionState.READY):
+ if not self._broadcast_state_and_check(self._workflow.state, WorkflowState.INVALID, TransactionState.READY,
+ TransactionState.READY):
return False
- self._workflow.update_state(None, None, TransactionState.READY)
+ self._workflow.update_state(None, None, TransactionState.READY, self._session)
self._reload()
return True
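
`TransactionManager` now receives an explicit SQLAlchemy session and delegates state changes to `ResourceManager`, but the coordinator protocol is unchanged: it broadcasts a target transaction state to every platform participant and decides from the replies whether to commit, abort, or keep waiting. A small illustrative reduction of that decision rule (the enum and function names below are hypothetical, not the project's API):

```python
from enum import Enum
from typing import Iterable, Optional


class TxState(Enum):
    PARTICIPANT_PREPARE = 1
    PARTICIPANT_COMMITTABLE = 2
    ABORTED = 3


def next_coordinator_action(states: Iterable[Optional[TxState]]) -> str:
    """Decide the coordinator's next move from the participants' replies:
    abort as soon as any participant aborted, commit only when every
    participant is committable, otherwise keep waiting and retry later.
    A None entry stands for a failed RPC and blocks committing."""
    committable = True
    for state in states:
        if state == TxState.ABORTED:
            return 'abort'
        if state != TxState.PARTICIPANT_COMMITTABLE:
            committable = False
    return 'commit' if committable else 'wait'


# Example rounds of replies from two participants.
assert next_coordinator_action([TxState.PARTICIPANT_COMMITTABLE] * 2) == 'commit'
assert next_coordinator_action([TxState.PARTICIPANT_COMMITTABLE, TxState.ABORTED]) == 'abort'
assert next_coordinator_action([TxState.PARTICIPANT_COMMITTABLE, None]) == 'wait'
```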
diff --git a/web_console_v2/api/fedlearner_webconsole/scheduler/workflow_commit_test.py b/web_console_v2/api/fedlearner_webconsole/scheduler/workflow_commit_test.py
new file mode 100644
index 000000000..2095a8cda
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/scheduler/workflow_commit_test.py
@@ -0,0 +1,110 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import os
+import time
+import unittest
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.workflow.service import WorkflowService
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.models import JobState
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.scheduler.transaction import TransactionState
+from fedlearner_webconsole.scheduler.scheduler import \
+ scheduler
+from fedlearner_webconsole.proto import project_pb2
+from fedlearner_webconsole.job.yaml_formatter import YamlFormatterService
+from testing.workflow_template.test_template_left import make_workflow_template
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class WorkflowsCommitTest(NoWebServerTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ os.environ['FEDLEARNER_WEBCONSOLE_POLLING_INTERVAL'] = '0.1'
+
+ def setUp(self):
+ super().setUp()
+ # Inserts project
+ config = {
+ 'variables': [{
+ 'name': 'namespace',
+ 'value': 'leader'
+ }, {
+ 'name': 'basic_envs',
+ 'value': '{}'
+ }, {
+ 'name': 'storage_root_path',
+ 'value': '/'
+ }]
+ }
+
+ project = Project(name='test', config=ParseDict(config, project_pb2.ProjectConfig()).SerializeToString())
+ participant = Participant(name='party_leader', host='127.0.0.1', port=5000, domain_name='fl-leader.com')
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ with db.session_scope() as session:
+ session.add(project)
+ session.add(participant)
+ session.add(relationship)
+ session.commit()
+
+ @staticmethod
+ def _wait_until(cond, retry_times: int = 5):
+ for _ in range(retry_times):
+ time.sleep(0.1)
+ with db.session_scope() as session:
+ if cond(session):
+ return
+
+ def test_workflow_commit(self):
+        # Test the committing stage of workflow creation.
+ workflow_def = make_workflow_template()
+ workflow = Workflow(id=20,
+ name='job_test1',
+ comment='这是一个测试工作流',
+ config=workflow_def.SerializeToString(),
+ project_id=1,
+ forkable=True,
+ state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.PARTICIPANT_COMMITTING,
+ creator='test_creator')
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.commit()
+ WorkflowService(session).setup_jobs(workflow)
+ session.commit()
+
+ scheduler.wakeup(20)
+ self._wait_until(lambda session: session.query(Workflow).get(20).state == WorkflowState.READY)
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(20)
+ jobs = workflow.get_jobs(session)
+ self.assertEqual(len(jobs), 2)
+ self.assertEqual(jobs[0].state, JobState.NEW)
+ self.assertEqual(jobs[1].state, JobState.NEW)
+ # test generate job run yaml
+ job_loaded_json = YamlFormatterService(session).generate_job_run_yaml(jobs[0])
+ self.assertEqual(job_loaded_json['metadata']['name'], jobs[0].name)
+ self.assertEqual(job_loaded_json['metadata']['labels']['owner'], workflow.creator)
+ session.commit()
+
+
+if __name__ == '__main__':
+ unittest.main()
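
The commit test drives the scheduler end to end: it seeds a workflow already in PARTICIPANT_COMMITTING, calls `scheduler.wakeup(20)`, and then polls the database until the state becomes READY. Note that the `_wait_until` helper above returns silently after its retries; a hedged variant that fails loudly on timeout (illustrative only, not part of this change) could look like:

```python
import time
from typing import Callable


def wait_until(cond: Callable[[], bool], timeout: float = 5.0, interval: float = 0.1) -> None:
    """Poll `cond` until it returns True or the deadline passes.

    Raising on timeout makes the eventual test failure point at the real
    problem instead of at assertions run against a stale database snapshot.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if cond():
            return
        time.sleep(interval)
    raise TimeoutError(f'condition not met within {timeout} seconds')
```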
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/serving/BUILD.bazel
new file mode 100644
index 000000000..f8cb1f555
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/BUILD.bazel
@@ -0,0 +1,276 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "database_fetcher_lib",
+ srcs = ["database_fetcher.py"],
+ imports = ["../.."],
+)
+
+py_library(
+ name = "metrics_lib",
+ srcs = ["metrics.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = ["services.py"],
+ imports = ["../.."],
+ deps = [
+ ":database_fetcher_lib",
+ ":metrics_lib",
+ ":models_lib",
+ ":remote_lib",
+ ":serving_yaml_template_lib",
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:sorting_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_kubernetes//:pkg",
+ "@common_tensorflow//:pkg",
+ "@common_tensorflow_serving_api//:pkg",
+ ],
+)
+
+py_library(
+ name = "serving_yaml_template_lib",
+ srcs = ["serving_yaml_template.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_yaml_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "runners_lib",
+ srcs = [
+ "runners.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_time_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:process_utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = [
+ "utils.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "remote_lib",
+ srcs = [
+ "remote.py",
+ ],
+ imports = ["../.."],
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":metrics_lib",
+ ":models_lib",
+ ":participant_fetcher_lib",
+ ":remote_lib",
+ ":runners_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_sqlalchemy//:pkg",
+ "@common_tensorflow//:pkg",
+ ],
+)
+
+py_library(
+ name = "participant_fetcher_lib",
+ srcs = ["participant_fetcher.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ ],
+)
+
+py_test(
+ name = "apis_inference_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_inference_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_inference_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "apis_runner_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_runner_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_runner_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_test(
+ name = "services_lib_test",
+ size = "medium",
+ srcs = [
+ "services_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "services_test.py",
+ deps = [
+ ":services_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ ],
+)
+
+py_test(
+ name = "serving_yaml_template_lib_test",
+ size = "small",
+ srcs = [
+ "serving_yaml_template_test.py",
+ ],
+ imports = ["../.."],
+ main = "serving_yaml_template_test.py",
+ deps = [
+ ":serving_yaml_template_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "utils_lib_test",
+ size = "small",
+ srcs = [
+ "utils_test.py",
+ ],
+ imports = ["../.."],
+ main = "utils_test.py",
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/__init__.py b/web_console_v2/api/fedlearner_webconsole/serving/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/apis.py b/web_console_v2/api/fedlearner_webconsole/serving/apis.py
new file mode 100644
index 000000000..16d752eb2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/apis.py
@@ -0,0 +1,556 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from http import HTTPStatus
+from typing import Optional
+
+from flask_restful import Resource
+from google.protobuf import json_format
+from google.protobuf.text_format import Parse
+from marshmallow import Schema, fields, post_load
+from sqlalchemy.orm import joinedload
+from sqlalchemy.sql.elements import ColumnElement
+from tensorflow.core.example.example_pb2 import Example
+
+from fedlearner_webconsole.proto.serving_pb2 import ServingServiceRemotePlatform
+from fedlearner_webconsole.serving import remote
+from fedlearner_webconsole.utils.decorators.pp_flask import use_args, use_kwargs
+
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.proto import serving_pb2
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, ModelSignatureParserInput
+from fedlearner_webconsole.proto.filtering_pb2 import FilterOp, SimpleExpression, FilterExpression
+from fedlearner_webconsole.serving.metrics import serving_metrics_emit_counter
+from fedlearner_webconsole.serving.participant_fetcher import ParticipantFetcher
+from fedlearner_webconsole.serving.runners import ModelSignatureParser, start_query_participant, start_update_model
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import NotFoundException, InvalidArgumentException, \
+ InternalException
+from fedlearner_webconsole.serving.models import ServingModel, ServingDeployment, ServingNegotiator
+from fedlearner_webconsole.serving.services import TensorflowServingService, ServingDeploymentService, \
+ ServingModelService
+from fedlearner_webconsole.utils import filtering, sorting, flask_utils
+from fedlearner_webconsole.utils.decorators.pp_flask import input_validator
+from fedlearner_webconsole.utils.flask_utils import make_flask_response, FilterExpField
+from fedlearner_webconsole.utils.proto import to_dict
+
+SORT_SUPPORTED_COLUMN = ['created_at']
+
+
+class ResourceParams(Schema):
+ cpu = fields.Str(required=True)
+ memory = fields.Str(required=True)
+ replicas = fields.Integer(required=True)
+
+
+class RemotePlatformParams(Schema):
+ platform = fields.Str(required=True)
+ payload = fields.Str(required=True)
+
+
+class ServingCreateParams(Schema):
+ name = fields.Str(required=True)
+ comment = fields.Str(required=False)
+ model_id = fields.Integer(required=False)
+ model_group_id = fields.Integer(required=False)
+ is_local = fields.Boolean(required=False)
+ resource = fields.Nested(ResourceParams, required=False)
+ remote_platform = fields.Nested(RemotePlatformParams, required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ if 'resource' in data:
+ data['resource'] = json_format.ParseDict(data['resource'], serving_pb2.ServingServiceResource())
+ if 'remote_platform' in data:
+ data['remote_platform'] = json_format.ParseDict(data['remote_platform'],
+ serving_pb2.ServingServiceRemotePlatform())
+ return data
+
+
+class ServingUpdateParams(Schema):
+ comment = fields.Str(required=False)
+ model_id = fields.Integer(required=False)
+ model_group_id = fields.Integer(required=False)
+ resource = fields.Nested(ResourceParams, required=False)
+
+ @post_load
+ def make(self, data, **kwargs):
+ if 'resource' in data:
+ data['resource'] = json_format.ParseDict(data['resource'], serving_pb2.ServingServiceResource())
+ return data
+
+
+def _build_keyword_query(exp: SimpleExpression) -> ColumnElement:
+ return ServingModel.name.ilike(f'%{exp.string_value}%')
+
+
+class ServingServicesApiV2(Resource):
+
+ FILTER_FIELDS = {
+ 'name':
+ filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.EQUAL: None}),
+ 'keyword':
+ filtering.SupportedField(type=filtering.FieldType.STRING, ops={FilterOp.CONTAIN: _build_keyword_query}),
+ }
+
+ SORTER_FIELDS = ['created_at']
+
+ def __init__(self):
+ self._filter_builder = filtering.FilterBuilder(model_class=ServingModel, supported_fields=self.FILTER_FIELDS)
+ self._sorter_builder = sorting.SorterBuilder(model_class=ServingModel, supported_fields=self.SORTER_FIELDS)
+
+ @use_kwargs(
+ {
+ 'filter_exp': FilterExpField(data_key='filter', required=False, load_default=None),
+ 'sorter_exp': fields.String(data_key='order_by', required=False, load_default=None),
+ },
+ location='query')
+ @credentials_required
+ def get(self, project_id: int, filter_exp: Optional[FilterExpression], sorter_exp: Optional[str]):
+ """Get serving services list
+ ---
+ tags:
+ - serving
+ description: get serving services list
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: query
+ name: filter
+ schema:
+ type: string
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ responses:
+ 200:
+                description: list of serving service information
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ServingService'
+ """
+ service_list = []
+ with db.session_scope() as session:
+ query = session.query(ServingModel)
+ query = query.filter(ServingModel.project_id == project_id)
+ if filter_exp is not None:
+ try:
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter_exp: {str(e)}') from e
+ if sorter_exp is not None:
+ try:
+ sorter_exp = sorting.parse_expression(sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ else:
+ sorter_exp = sorting.SortExpression(field='created_at', is_asc=False)
+ query = self._sorter_builder.build_query(query, sorter_exp)
+ query = query.outerjoin(ServingDeployment,
+ ServingDeployment.id == ServingModel.serving_deployment_id).options(
+ joinedload(ServingModel.serving_deployment))
+ all_records = query.all()
+ for serving_model in all_records:
+ serving_service = serving_model.to_serving_service()
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model_service.set_resource_and_status_on_ref(serving_service, serving_model)
+ serving_model_service.set_is_local_on_ref(serving_service, serving_model)
+ service_list.append(serving_service)
+ return make_flask_response(data=service_list, status=HTTPStatus.OK)
+
+ @use_args(ServingCreateParams(), location='json_or_form')
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.SERVING_SERVICE, op_type=Event.CREATE)
+ def post(self, body: dict, project_id: int):
+ """Create one serving service
+ ---
+ tags:
+ - serving
+ description: create one serving service
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/ServingCreateParams'
+ responses:
+ 201:
+ description: detail of one serving service
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ServingServiceDetail'
+ """
+        if 'remote_platform' in body:  # need to check SSO for third-party serving
+ current_sso = flask_utils.get_current_sso()
+ if current_sso is None:
+                raise InvalidArgumentException('not an SSO user')
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(
+ project_id=project_id,
+ name=body['name'],
+ is_local=body['is_local'],
+ comment=body['comment'] if 'comment' in body else None,
+ model_id=body['model_id'] if 'model_id' in body else None,
+ model_group_id=body['model_group_id'] if 'model_group_id' in body else None,
+ resource=body['resource'] if 'resource' in body else None,
+ remote_platform=body['remote_platform'] if 'remote_platform' in body else None)
+
+ # start async query participant serving status
+ if 'is_local' in body and not body['is_local']:
+ start_query_participant(session)
+
+ # start async query signature
+ if 'remote_platform' not in body:
+ runner_item_name = ModelSignatureParser.generate_task_name(serving_model.id, serving_model.name)
+ runner_input = RunnerInput(model_signature_parser_input=ModelSignatureParserInput(
+ serving_model_id=serving_model.id))
+ ComposerService(session).collect_v2(name=runner_item_name,
+ items=[(ItemType.SERVING_SERVICE_PARSE_SIGNATURE, runner_input)])
+
+ # start auto update model runner
+ if serving_model.model_group_id is not None:
+ start_update_model(session)
+
+ session.commit()
+ serving_metrics_emit_counter('serving.create.success', serving_model)
+ return make_flask_response(data=serving_model.to_serving_service_detail(), status=HTTPStatus.CREATED)
+
+
+class ServingServiceApiV2(Resource):
+
+ @use_kwargs({
+ 'sorter_exp': fields.String(data_key='order_by', required=False, load_default=None),
+ },
+ location='query')
+ @credentials_required
+ def get(self, project_id: int, serving_model_id: int, sorter_exp: Optional[str]):
+ """Get one serving service
+ ---
+ tags:
+ - serving
+ description: get one serving service
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: serving_model_id
+ schema:
+ type: integer
+ - in: query
+ name: order_by
+ schema:
+ type: string
+ responses:
+ 200:
+ description: detail of one serving service
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ServingServiceDetail'
+ """
+ sorter = None
+ if sorter_exp is not None:
+ try:
+ sorter = sorting.parse_expression(sorter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid sorter: {str(e)}') from e
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ result = serving_model_service.get_serving_service_detail(serving_model_id, project_id, sorter)
+ return make_flask_response(data=result)
+
+ @use_args(ServingUpdateParams(), location='json_or_form')
+ @credentials_required
+ @emits_event(resource_type=Event.SERVING_SERVICE, op_type=Event.UPDATE)
+ def patch(self, body: dict, project_id: int, serving_model_id: int):
+ """Modify one serving service
+ ---
+ tags:
+ - serving
+        description: modify one serving service
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: serving_model_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/ServingUpdateParams'
+ responses:
+ 200:
+ description: detail of one serving service
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ServingServiceDetail'
+ """
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).filter_by(id=serving_model_id, project_id=project_id).options(
+ joinedload(ServingModel.serving_deployment)).one_or_none()
+ if not serving_model:
+ raise NotFoundException(f'Failed to find serving service: {serving_model_id}')
+ if 'comment' in body:
+ serving_model.comment = body['comment']
+            if serving_model.serving_deployment.is_remote_serving():  # need to check SSO for third-party serving
+ current_sso = flask_utils.get_current_sso()
+ if current_sso is None:
+                    raise InvalidArgumentException('not an SSO user')
+ need_update_model = False
+ if 'model_id' in body:
+ need_update_model = ServingModelService(session).update_model(model_id=body['model_id'],
+ model_group_id=None,
+ serving_model=serving_model)
+ elif 'model_group_id' in body:
+ need_update_model = ServingModelService(session).update_model(model_id=None,
+ model_group_id=body['model_group_id'],
+ serving_model=serving_model)
+ start_update_model(session)
+ if 'resource' in body:
+ current_resource = json.loads(serving_model.serving_deployment.resource)
+ new_resource = to_dict(body['resource'])
+ if new_resource != current_resource:
+ ServingModelService(session).update_resource(new_resource, serving_model)
+ if need_update_model and not serving_model.serving_deployment.is_remote_serving():
+ # start async query signature
+ runner_item_name = ModelSignatureParser.generate_task_name(serving_model.id, serving_model.name)
+ runner_input = RunnerInput(model_signature_parser_input=ModelSignatureParserInput(
+ serving_model_id=serving_model.id))
+ ComposerService(session).collect_v2(name=runner_item_name,
+ items=[(ItemType.SERVING_SERVICE_PARSE_SIGNATURE, runner_input)])
+ session.add(serving_model)
+ session.commit()
+ serving_metrics_emit_counter('serving.update.success', serving_model)
+ return make_flask_response(data=serving_model.to_serving_service_detail())
+
+ @credentials_required
+ @emits_event(resource_type=Event.SERVING_SERVICE, op_type=Event.DELETE)
+ def delete(self, project_id: int, serving_model_id: int):
+ """Delete one serving service
+ ---
+ tags:
+ - serving
+ description: delete one serving service
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: serving_model_id
+ schema:
+ type: integer
+ responses:
+ 204:
+                description: delete the serving service successfully
+ """
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).filter_by(id=serving_model_id,
+ project_id=project_id).one_or_none()
+ if not serving_model:
+ serving_metrics_emit_counter('serving.delete.db_error', serving_model)
+ raise NotFoundException(f'Failed to find serving model: {serving_model_id}')
+ ServingModelService(session).delete_serving_service(serving_model)
+ session.commit()
+ serving_metrics_emit_counter('serving.delete.success', serving_model)
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class ServingServiceInferenceApiV2(Resource):
+
+ @use_args({'input_data': fields.String(required=True, help='serving input data')}, location='json')
+ @credentials_required
+ def post(self, body: dict, project_id: int, serving_model_id: int):
+ """Get inference result from a serving service
+ ---
+ tags:
+ - serving
+ description: get inference result from a serving service
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: serving_model_id
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ description: input data to do inference
+ content:
+ application/json:
+ schema:
+ type: string
+ responses:
+ 200:
+ description: inference result
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.PredictResponse'
+ """
+ try:
+ input_data = Parse(body['input_data'], Example())
+ except Exception as err:
+ serving_metrics_emit_counter('serving.inference.invalid_arguments')
+ raise InvalidArgumentException(f'Failed to parse inference input: {serving_model_id}') from err
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).filter_by(id=serving_model_id,
+ project_id=project_id).one_or_none()
+ if not serving_model:
+ serving_metrics_emit_counter('serving.inference.db_error')
+ raise NotFoundException(f'Failed to find serving model: {serving_model_id}')
+ deployment_name = serving_model.serving_deployment.deployment_name
+ tf_serving_service = TensorflowServingService(deployment_name)
+ extend_input = {}
+ with db.session_scope() as session:
+ serving_negotiator = session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model_id).one_or_none()
+ if serving_negotiator is not None:
+ extend_input.update(ParticipantFetcher(session).fetch(serving_negotiator, '1'))
+ output = tf_serving_service.get_model_inference_output(input_data, extend_input)
+ if 'Error' in output:
+ serving_metrics_emit_counter('serving.inference.rpc_error')
+ raise InternalException(f'Failed to do inference: {output}')
+ serving_metrics_emit_counter('serving.inference.success')
+ return make_flask_response(data=output)
+
+
+class ServingServiceInstanceLogApiV2(Resource):
+
+ @use_args({'tail_lines': fields.Integer(required=True, help='tail lines is required')}, location='query')
+ @credentials_required
+ def get(self, body: dict, project_id: int, serving_model_id: int, instance_name: str):
+ """Get inference result from a serving service
+ ---
+ tags:
+ - serving
+        description: get logs of a serving service instance
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: serving_model_id
+ schema:
+ type: integer
+ - in: path
+ name: instance_name
+ schema:
+ type: string
+ - in: query
+ name: tail_lines
+ schema:
+ type: integer
+            description: number of log lines to return from the tail
+ responses:
+ 200:
+                description: log lines of the serving instance
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+ """
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).filter_by(id=serving_model_id,
+ project_id=project_id).one_or_none()
+ if not serving_model:
+ serving_metrics_emit_counter('serving.logs.db_error')
+ raise NotFoundException(f'Failed to find serving model: {serving_model_id}')
+ tail_lines = body['tail_lines']
+ result = ServingDeploymentService.get_pod_log(instance_name, tail_lines)
+ return make_flask_response(data=result)
+
+
+class ServingServiceRemotePlatformsApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int):
+ """Get supported third-party serving platform
+ ---
+ tags:
+ - serving
+        description: get supported third-party serving platforms
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: list of supported serving remote platform
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ServingServiceRemotePlatform'
+ """
+ result_list = []
+ current_sso = flask_utils.get_current_sso()
+ if current_sso is None:
+ return make_flask_response(data=result_list)
+        for key in remote.supported_remote_serving:
+ support_platform = ServingServiceRemotePlatform(platform=key)
+ result_list.append(support_platform)
+ return make_flask_response(data=result_list)
+
+
+def initialize_serving_services_apis(api):
+ api.add_resource(ServingServicesApiV2, '/projects//serving_services')
+ api.add_resource(ServingServiceApiV2, '/projects//serving_services/')
+ api.add_resource(ServingServiceInferenceApiV2,
+ '/projects//serving_services//inference')
+ api.add_resource(
+ ServingServiceInstanceLogApiV2, '/projects//serving_services//instances'
+ '//log')
+ api.add_resource(ServingServiceRemotePlatformsApi, '/projects//serving_services/remote_platforms')
+
+    # If a schema is used, it has to be appended to schema_manager so Swagger knows the schema is available.
+ schema_manager.append(ServingCreateParams)
+ schema_manager.append(ServingUpdateParams)
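
For reference, the create endpoint validates its body against `ServingCreateParams`, whose `@post_load` hook converts the nested `resource` and `remote_platform` objects into protobuf messages before `ServingModelService.create_from_param` is called. A hedged example payload assembled only from the schema fields above (every value is invented):

```python
# Hypothetical request body for POST /api/v2/projects/<project_id>/serving_services,
# using the fields declared in ServingCreateParams; all values are made up.
create_serving_request = {
    'name': 'demo-serving',
    'comment': 'example only',
    'model_id': 1,            # or 'model_group_id' to track a model group instead
    'is_local': True,         # False additionally starts the participant-status query runner
    'resource': {             # size of the in-cluster deployment
        'cpu': '2',
        'memory': '2',
        'replicas': 3,
    },
    # For third-party serving, a 'remote_platform' object with 'platform' and
    # 'payload' would be supplied instead, and the caller must be an SSO user.
}
```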
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/apis_inference_test.py b/web_console_v2/api/fedlearner_webconsole/serving/apis_inference_test.py
new file mode 100644
index 000000000..82f4eccd9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/apis_inference_test.py
@@ -0,0 +1,313 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import unittest
+from http import HTTPStatus
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+
+from tensorflow.core.example.example_pb2 import Example
+from tensorflow.core.example.feature_pb2 import Feature, Int64List, Features
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.mmgr.models import Model
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto import service_pb2, serving_pb2, common_pb2
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput
+from fedlearner_webconsole.serving.database_fetcher import DatabaseFetcher
+from fedlearner_webconsole.serving.models import ServingModel, ServingModelStatus
+from fedlearner_webconsole.serving.runners import QueryParticipantStatusRunner
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.serving.services import NegotiatorServingService
+from testing.common import BaseTestCase
+
+TEST_SIGNATURE = {
+ 'inputs': [{
+ 'name': 'example_id',
+ 'type': 'DT_STRING'
+ }, {
+ 'name': 'raw_id',
+ 'type': 'DT_STRING'
+ }, {
+ 'name': 'x0',
+ 'type': 'DT_FLOAT'
+ }, {
+ 'name': 'x1',
+ 'type': 'DT_INT64',
+ 'dim': [4]
+ }],
+ 'from_participants': {
+ 'act1_f': {
+ 'name': 'act1_f:0',
+ 'dtype': 'DT_FLOAT',
+ 'tensorShape': {
+ 'unknownRank': True
+ }
+ },
+ 'act2_f': {
+ 'name': 'act2_f:0',
+ 'dtype': 'DT_DOUBLE'
+ },
+ 'act3_f': {
+ 'name': 'act3_f:0',
+ 'dtype': 'DT_INT32'
+ },
+ 'act4_f': {
+ 'name': 'act4_f:0',
+ 'dtype': 'DT_INT64'
+ },
+ 'act5_f': {
+ 'name': 'act5_f:0',
+ 'dtype': 'DT_UINT32'
+ },
+ 'act6_f': {
+ 'name': 'act6_f:0',
+ 'dtype': 'DT_UINT64'
+ },
+ 'act7_f': {
+ 'name': 'act7_f:0',
+ 'dtype': 'DT_STRING'
+ },
+ 'act8_f': {
+ 'name': 'act8_f:0',
+ 'dtype': 'DT_BOOL'
+ }
+ }
+}
+
+TEST_OUTPUT = {
+ 'result': {
+ 'act1_f': {
+ 'dtype': 'DT_FLOAT',
+ 'floatVal': 0.1
+ },
+ 'act2_f': {
+ 'dtype': 'DT_DOUBLE',
+ 'doubleVal': 0.1
+ },
+ 'act3_f': {
+ 'dtype': 'DT_INT32',
+ 'intVal': -11
+ },
+ 'act4_f': {
+ 'dtype': 'DT_INT64',
+ 'int64Val': -12
+ },
+ 'act5_f': {
+ 'dtype': 'DT_UINT32',
+ 'uint32Val': 13
+ },
+ 'act6_f': {
+ 'dtype': 'DT_UINT64',
+ 'uint64Val': 14
+ },
+ 'act7_f': {
+ 'dtype': 'DT_STRING',
+ 'stringVal': 'test'
+ },
+ 'act8_f': {
+ 'dtype': 'DT_BOOL',
+ 'boolVal': False
+ },
+ }
+}
+
+
+def _get_create_serving_service_input(name, model_id: int):
+ res = {
+ 'name': name,
+ 'comment': 'test-comment-1',
+ 'model_id': model_id,
+ 'is_local': True,
+ 'resource': {
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ }
+ return res
+
+
+class ServingServicesApiInferenceTest(BaseTestCase):
+
+ def setUp(self):
+ self.maxDiff = None
+ super().setUp()
+ # insert project
+ with db.session_scope() as session:
+ project = Project()
+ project.name = 'test_project_name'
+ session.add(project)
+ session.flush([project])
+
+ participant = Participant()
+ participant.name = 'test_participant_name'
+ participant.domain_name = 'test_domain_name'
+ participant.project_id = project.id
+ session.add(participant)
+ session.flush([participant])
+
+ project_participant = ProjectParticipant()
+ project_participant.participant_id = participant.id
+ project_participant.project_id = project.id
+ session.add(project_participant)
+
+ model = Model()
+ model.name = 'test_model_name'
+ model.model_path = '/test_path/'
+ model.group_id = 1
+ model.uuid = 'test_uuid_1'
+ model.project_id = project.id
+
+ session.add(model)
+ session.commit()
+ self.project_id = project.id
+ self.model_id = model.id
+ self.model_uuid = model.uuid
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_post_serving_service_inference(self, mock_create_deployment: MagicMock):
+ # create
+ name = 'test-serving-service-1'
+ serving_service = _get_create_serving_service_input(name, model_id=self.model_id)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ data = self.get_response_data(response)
+ serving_model_id = data['id'] # get id from create response
+ # make input
+ fake_data = {
+ 'raw': Feature(int64_list=Int64List(value=np.random.randint(low=0, high=255, size=(128 * 128 * 3)))),
+ 'label': Feature(int64_list=Int64List(value=[1]))
+ }
+ fake_input = {
+ 'input_data': str(Example(features=Features(feature=fake_data))),
+ }
+ # post
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services/{serving_model_id}/inference',
+ data=fake_input)
+ self.assertEqual(HTTPStatus.INTERNAL_SERVER_ERROR, response.status_code)
+ data = self.get_response_data(response)
+ self.assertIsNone(data)
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.operate_serving_service')
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.inference_serving_service')
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_post_federal_serving_service_inference(self, mock_create_deployment: MagicMock,
+ mock_federal_inference: MagicMock,
+ mock_federal_operation: MagicMock):
+ mock_inference_response = service_pb2.ServingServiceInferenceResponse(
+ status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ mock_inference_response.data.update(TEST_OUTPUT)
+ mock_federal_inference.return_value = mock_inference_response
+ mock_federal_operation.return_value = service_pb2.ServingServiceResponse(
+ status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ # create serving service
+ name = 'test-serving-service-1'
+ serving_service = _get_create_serving_service_input(name, model_id=self.model_id)
+ serving_service['is_local'] = False
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ data = self.get_response_data(response)
+ serving_model_id = data['id'] # get id from create response
+
+ # mock signature runner
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ serving_model.signature = json.dumps(TEST_SIGNATURE)
+ session.commit()
+
+ # mock query runner
+ runner = QueryParticipantStatusRunner()
+ test_context = RunnerContext(0, RunnerInput())
+ runner_status, _ = runner.run(test_context)
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertEqual(serving_model.status, ServingModelStatus.LOADING)
+
+ # inference, make input
+ fake_data = {
+ 'raw': Feature(int64_list=Int64List(value=np.random.randint(low=0, high=255, size=(128 * 128 * 3)))),
+ 'label': Feature(int64_list=Int64List(value=[1]))
+ }
+ fake_input = {
+ 'input_data': str(Example(features=Features(feature=fake_data))),
+ }
+ # post
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services/{serving_model_id}/inference',
+ data=fake_input)
+ self.assertEqual(HTTPStatus.INTERNAL_SERVER_ERROR, response.status_code)
+ data = self.get_response_data(response)
+ self.assertIsNone(data)
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_federal_serving_service_inference_from_participant(self, mock_create_deployment: MagicMock):
+ mock_serving_uuid = 'test_uuid_1'
+ mock_serving_model_name = 'test_serving_model_name_1'
+ # create from participant
+ with db.session_scope() as session:
+ project = session.query(Project).get(self.project_id)
+ request = service_pb2.ServingServiceRequest()
+ request.operation_type = serving_pb2.ServingServiceType.SERVING_SERVICE_CREATE
+ request.serving_model_uuid = mock_serving_uuid
+ request.model_uuid = self.model_uuid
+ request.serving_model_name = mock_serving_model_name
+ NegotiatorServingService(session).handle_participant_request(request, project)
+
+ # get list
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services')
+ data = self.get_response_data(response)
+ serving_model_id = data[0]['id']
+
+ # get one
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{serving_model_id}')
+ data = self.get_response_data(response)
+ self.assertEqual('WAITING_CONFIG', data['status'])
+
+ # config
+ serving_service = {
+ 'comment': 'test-comment-1',
+ 'model_id': self.model_id,
+ 'resource': {
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ }
+ response = self.patch_helper(f'/api/v2/projects/{self.project_id}/serving_services/{serving_model_id}',
+ data=serving_service)
+ data = self.get_response_data(response)
+
+ # get one
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{serving_model_id}')
+ data = self.get_response_data(response)
+ self.assertEqual('LOADING', data['status'])
+
+ # inference from participant
+ query_key = 1
+ test_signature = json.dumps(TEST_SIGNATURE)
+ data_record = DatabaseFetcher.fetch_by_int_key(query_key, test_signature)
+ self.assertEqual(len(data_record['x0']), 1)
+ self.assertEqual(data_record['x0'][0], 0.1)
+ self.assertEqual(len(data_record['x1']), 4)
+ self.assertEqual(data_record['x1'][0], 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/apis_runner_test.py b/web_console_v2/api/fedlearner_webconsole/serving/apis_runner_test.py
new file mode 100644
index 000000000..772eaf2d8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/apis_runner_test.py
@@ -0,0 +1,307 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import unittest
+from http import HTTPStatus
+from multiprocessing import Queue
+from unittest.mock import MagicMock, patch
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.mmgr.models import Model, ModelJobGroup
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto import service_pb2, serving_pb2, common_pb2
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, ModelSignatureParserInput
+from fedlearner_webconsole.serving.models import ServingModel, ServingNegotiator, ServingModelStatus
+from fedlearner_webconsole.serving.runners import ModelSignatureParser, QueryParticipantStatusRunner, UpdateModelRunner
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from testing.common import BaseTestCase
+
+
+def _get_create_serving_service_input(name, model_id: int):
+ res = {
+ 'name': name,
+ 'comment': 'test-comment-1',
+ 'model_id': model_id,
+ 'is_local': True,
+ 'resource': {
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ }
+ return res
+
+
+def _fake_update_parsed_signature_sub_process(q: Queue, model_path: str):
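+ """Stub for runners._update_parsed_signature: puts a fixed parsed signature onto
+ the queue instead of parsing a real saved model in a sub-process."""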
+ mock_signature_type_dict = {
+ '4s_code_ctr': 'DT_FLOAT',
+ 'event_name_ctr': 'DT_FLOAT',
+ 'source_account_ctr': 'DT_FLOAT',
+ 'source_channel_ctr': 'DT_FLOAT',
+ 'example_id': 'DT_STRING',
+ 'raw_id': 'DT_INT64',
+ }
+ mock_parsed_example = serving_pb2.ServingServiceSignature()
+ for key, value in mock_signature_type_dict.items():
+ mock_example_input = serving_pb2.ServingServiceSignatureInput(name=key, type=value)
+ mock_parsed_example.inputs.append(mock_example_input)
+ q.put(mock_parsed_example)
+
+
+class ServingServicesApiRunnerTest(BaseTestCase):
+
+ def setUp(self):
+ self.maxDiff = None
+ super().setUp()
+ # insert project
+ with db.session_scope() as session:
+ project = Project()
+ project.name = 'test_project_name'
+ session.add(project)
+ session.flush([project])
+
+ participant = Participant()
+ participant.name = 'test_participant_name'
+ participant.domain_name = 'test_domain_name'
+ participant.project_id = project.id
+ session.add(participant)
+ session.flush([participant])
+
+ project_participant = ProjectParticipant()
+ project_participant.participant_id = participant.id
+ project_participant.project_id = project.id
+ session.add(project_participant)
+
+ model_job_group = ModelJobGroup()
+ session.add(model_job_group)
+ session.flush([model_job_group])
+
+ model = Model()
+ model.name = 'test_model_name'
+ model.model_path = '/test_path/'
+ model.group_id = model_job_group.id
+ model.uuid = 'test_uuid_1'
+ model.project_id = project.id
+ model.version = 1
+
+ session.add(model)
+ session.commit()
+ self.project_id = project.id
+ self.model_id = model.id
+ self.model_uuid = model.uuid
+ self.model_group_id = model_job_group.id
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.operate_serving_service')
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_query_participant_runner(self, mock_create_deployment: MagicMock, mock_federal_operation: MagicMock):
+ mock_federal_operation.return_value = service_pb2.ServingServiceResponse(
+ status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ # create federal serving service
+ name = 'test-serving-service-1'
+ serving_service = _get_create_serving_service_input(name, model_id=self.model_id)
+ serving_service['is_local'] = False
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+
+ # create another federal serving service
+ name = 'test-serving-service-2'
+ serving_service = _get_create_serving_service_input(name, model_id=self.model_id)
+ serving_service['is_local'] = False
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+
+ # create another local serving service
+ name = 'test-serving-service-3'
+ serving_service = _get_create_serving_service_input(name, model_id=self.model_id)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+
+ # check status
+ with db.session_scope() as session:
+ query = session.query(ServingNegotiator)
+ query = query.filter(ServingNegotiator.is_local.is_(False))
+ query = query.outerjoin(
+ ServingNegotiator.serving_model).filter(ServingModel.status == ServingModelStatus.PENDING_ACCEPT)
+ all_records = query.all()
+ self.assertEqual(len(all_records), 2)
+
+ # call query runner
+ runner = QueryParticipantStatusRunner()
+ test_context = RunnerContext(0, RunnerInput())
+ runner_status, _ = runner.run(test_context)
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+
+ # check status again
+ with db.session_scope() as session:
+ query = session.query(ServingNegotiator)
+ query = query.filter(ServingNegotiator.is_local.is_(False))
+ query = query.outerjoin(
+ ServingNegotiator.serving_model).filter(ServingModel.status == ServingModelStatus.PENDING_ACCEPT)
+ all_records = query.all()
+ self.assertEqual(len(all_records), 0)
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_query_participant_runner_exception_branch(self, mock_create_deployment: MagicMock):
+ serving_model = ServingModel()
+ serving_model.project_id = self.project_id
+ serving_model.name = 'test_serving_model_name'
+ serving_model.status = ServingModelStatus.PENDING_ACCEPT
+ serving_negotiator = ServingNegotiator()
+ serving_negotiator.project_id = self.project_id
+ serving_negotiator.is_local = False
+ with db.session_scope() as session:
+ session.add(serving_model)
+ session.flush([serving_model])
+ serving_negotiator.serving_model_id = serving_model.id
+ session.add(serving_negotiator)
+ session.commit()
+
+ # call query runner
+ runner = QueryParticipantStatusRunner()
+ test_context = RunnerContext(0, RunnerInput())
+ runner_status, _ = runner.run(test_context)
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.operate_serving_service')
+ @patch('fedlearner_webconsole.serving.services.TensorflowServingService.get_model_signature')
+ @patch('fedlearner_webconsole.serving.runners._update_parsed_signature', _fake_update_parsed_signature_sub_process)
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_parse_signature_runner(self, mock_create_deployment: MagicMock, get_model_signature: MagicMock,
+ mock_federal_operation: MagicMock):
+ mock_federal_operation.return_value = service_pb2.ServingServiceResponse(
+ status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ get_model_signature.return_value = {
+ 'inputs': {
+ 'act1_f': {
+ 'name': 'act1_f:0',
+ 'dtype': 'DT_FLOAT',
+ 'tensorShape': {
+ 'unknownRank': True
+ }
+ },
+ 'examples': {
+ 'name': 'examples:0',
+ 'dtype': 'DT_STRING',
+ 'tensorShape': {
+ 'unknownRank': True
+ }
+ }
+ },
+ 'outputs': {
+ 'output': {
+ 'name': 'Sigmoid:0',
+ 'dtype': 'DT_FLOAT',
+ 'tensorShape': {
+ 'unknownRank': True
+ }
+ }
+ },
+ 'methodName': 'tensorflow/serving/predict'
+ }
+ # create serving service
+ name = 'test-serving-service-1'
+ serving_service = _get_create_serving_service_input(name, model_id=self.model_id)
+ serving_service['is_local'] = False
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ data = self.get_response_data(response)
+ serving_model_id = data['id'] # get id from create response
+
+ # call signature runner
+ runner = ModelSignatureParser()
+ runner_input = RunnerInput(model_signature_parser_input=ModelSignatureParserInput(
+ serving_model_id=serving_model_id))
+ runner.run(RunnerContext(0, runner_input))
+
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertEqual(serving_model.name, name)
+ self.assertEqual(serving_model.status, ServingModelStatus.PENDING_ACCEPT)
+ signature_dict = json.loads(serving_model.signature)
+ self.assertEqual(len(signature_dict['inputs']), 6)
+ self.assertIn('from_participants', signature_dict)
+ self.assertIn('act1_f', signature_dict['from_participants'])
+ self.assertIn('outputs', signature_dict)
+ self.assertIn('output', signature_dict['outputs'])
+ serving_negotiator = session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model_id).one_or_none()
+ self.assertIsNotNone(serving_negotiator)
+ self.assertEqual(serving_negotiator.project_id, self.project_id)
+ self.assertEqual(serving_negotiator.with_label, True)
+ raw_signature_dict = json.loads(serving_negotiator.raw_signature)
+ self.assertIn('inputs', raw_signature_dict)
+ self.assertIn('outputs', raw_signature_dict)
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.operate_serving_service')
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_update_model_runner(self, mock_create_deployment: MagicMock, mock_federal_operation: MagicMock):
+ mock_federal_operation.return_value = service_pb2.ServingServiceResponse(
+ status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ # create federal serving service
+ name = 'test-auto-update-1'
+ serving_service = {
+ 'name': name,
+ 'comment': 'test-comment-1',
+ 'model_group_id': self.model_group_id,
+ 'is_local': True,
+ 'resource': {
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ }
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ data = self.get_response_data(response)
+ serving_model_id = data['id'] # get id from create response
+
+ # create another model
+ with db.session_scope() as session:
+ model = Model()
+ model.name = 'test_model_name_2'
+ model.model_path = '/test_path_2/'
+ model.group_id = self.model_group_id
+ model.uuid = 'test_uuid_2'
+ model.project_id = self.project_id
+ model.version = 2
+ session.add(model)
+ session.commit()
+ model_id_2 = model.id
+
+ # check status
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).filter_by(id=serving_model_id).one_or_none()
+ self.assertEqual(serving_model.model_group_id, self.model_group_id)
+ self.assertEqual(serving_model.model_id, self.model_id)
+
+ # call update model runner
+ runner = UpdateModelRunner()
+ test_context = RunnerContext(0, RunnerInput())
+ runner_status, _ = runner.run(test_context)
+ self.assertEqual(runner_status, RunnerStatus.DONE)
+
+ # check status again
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).filter_by(id=serving_model_id).one_or_none()
+ self.assertEqual(serving_model.model_group_id, self.model_group_id)
+ self.assertEqual(serving_model.model_id, model_id_2)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/apis_test.py b/web_console_v2/api/fedlearner_webconsole/serving/apis_test.py
new file mode 100644
index 000000000..b514ff9ae
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/apis_test.py
@@ -0,0 +1,449 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import unittest
+import urllib.parse
+from datetime import datetime, timezone
+from http import HTTPStatus
+from unittest.mock import MagicMock, patch
+
+from envs import Envs
+from fedlearner_webconsole.composer.models import SchedulerItem
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.k8s.models import Pod, PodState
+from fedlearner_webconsole.mmgr.models import Model, ModelJobGroup
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.serving.models import ServingModel, ServingNegotiator
+from fedlearner_webconsole.serving.remote import register_remote_serving
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from testing.common import BaseTestCase
+from testing.fake_remote_serving import FakeRemoteServing
+
+
+def _get_create_serving_service_input(name, project_id: int):
+ res = {
+ 'name': name,
+ 'comment': 'test-comment-1',
+ 'cpu_per_instance': '2000m',
+ 'memory_per_instance': '2Gi',
+ 'instance_num': 3,
+ 'project_id': project_id,
+ 'is_local': True
+ }
+ return res
+
+
+def _get_create_serving_service_input_v2(name, model_id: int):
+ res = {
+ 'name': name,
+ 'comment': 'test-comment-1',
+ 'model_id': model_id,
+ 'is_local': True,
+ 'resource': {
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ }
+ return res
+
+
+class ServingServicesApiV2Test(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ # insert project
+ with db.session_scope() as session:
+ project = Project()
+ project.name = 'test_project_name'
+ session.add(project)
+ session.flush([project])
+
+ participant = Participant()
+ participant.name = 'test_participant_name'
+ participant.domain_name = 'test_domain_name'
+ participant.project_id = project.id
+ session.add(participant)
+ session.flush([participant])
+
+ project_participant = ProjectParticipant()
+ project_participant.participant_id = participant.id
+ project_participant.project_id = project.id
+ session.add(project_participant)
+
+ model_job_group = ModelJobGroup()
+ session.add(model_job_group)
+ session.flush([model_job_group])
+
+ model_1 = Model()
+ model_1.name = 'test_model_name_1'
+ model_1.model_path = '/test_path_1/'
+ model_1.uuid = 'test_uuid_1'
+ model_1.project_id = project.id
+ model_1.version = 1
+ model_1.group_id = model_job_group.id
+ model_2 = Model()
+ model_2.name = 'test_model_name_2'
+ model_2.model_path = '/test_path_2/'
+ model_2.uuid = 'test_uuid_2'
+ model_2.project_id = project.id
+ model_2.version = 2
+ model_2.group_id = model_job_group.id
+ session.add_all([model_1, model_2])
+
+ session.commit()
+ self.project_id = project.id
+ self.model_id_1 = model_1.id
+ self.model_group_id = model_job_group.id
+ self.model_id_2 = model_2.id
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_create_serving_service(self, mock_create_deployment: MagicMock):
+ # create serving service
+ name = 'test-serving-service-1'
+ serving_service = _get_create_serving_service_input_v2(name, model_id=self.model_id_1)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(HTTPStatus.CREATED, response.status_code)
+ data = self.get_response_data(response)
+ self.assertEqual(0, data['model_group_id'])
+ serving_model_id = data['id'] # get id from create response
+
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertEqual(serving_model.name, name)
+ self.assertEqual('/test_path_1/exported_models', serving_model.model_path)
+ serving_deployment = serving_model.serving_deployment
+ deployment_name_substr = f'serving-{serving_model_id}-'
+ self.assertIn(deployment_name_substr, serving_deployment.deployment_name)
+ self.assertEqual(serving_deployment.resource, json.dumps({
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }))
+ serving_negotiator = session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model_id).one_or_none()
+ self.assertIsNotNone(serving_negotiator)
+ self.assertEqual(serving_negotiator.project_id, self.project_id)
+
+ mock_create_deployment.assert_called_once()
+
+ # create same name
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ data = self.get_response_data(response)
+ self.assertEqual(HTTPStatus.CONFLICT, response.status_code)
+ self.assertIsNone(data)
+
+ # resource format error
+ serving_service['resource'] = {
+ 'cpu': 2,
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, response.status_code)
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_create_auto_update_serving_service(self, mock_create_deployment: MagicMock):
+ # create serving service
+ name = 'test-auto-update-1'
+ serving_service = {
+ 'name': name,
+ 'comment': 'test-comment-1',
+ 'model_group_id': self.model_group_id,
+ 'is_local': True,
+ 'resource': {
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ }
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(HTTPStatus.CREATED, response.status_code)
+ data = self.get_response_data(response)
+ self.assertEqual(self.model_id_2, data['model_id'])
+ serving_model_id = data['id'] # get id from create response
+
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertEqual(serving_model.name, name)
+ self.assertEqual('/test_path_2/exported_models', serving_model.model_path)
+
+ mock_create_deployment.assert_called_once()
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_get_serving_services(self, mock_create_deployment: MagicMock):
+ # create
+ name1 = 'test-get-services-1'
+ serving_service = _get_create_serving_service_input_v2(name1, model_id=self.model_id_1)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(HTTPStatus.CREATED, response.status_code)
+ name2 = 'test-get-services-2'
+ serving_service = _get_create_serving_service_input_v2(name2, model_id=self.model_id_1)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(HTTPStatus.CREATED, response.status_code)
+ # get list
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services')
+ data = self.get_response_data(response)
+ self.assertEqual(2, len(data))
+ self.assertIn(data[0]['name'], [name1, name2])
+ self.assertIn(data[1]['name'], [name1, name2])
+ self.assertEqual(self.project_id, data[0]['project_id'])
+ self.assertEqual('LOADING', data[0]['status'])
+ self.assertEqual('UNKNOWN', data[0]['instance_num_status'])
+
+ # get with filter
+ filter_param = urllib.parse.quote(f'(name="{name1}")')
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services?filter={filter_param}')
+ data = self.get_response_data(response)
+ self.assertEqual(1, len(data))
+
+ filter_param = urllib.parse.quote('(name="test-get-services-3")') # test not found
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services?filter={filter_param}')
+ data = self.get_response_data(response)
+ self.assertEqual(0, len(data))
+
+ filter_param = urllib.parse.quote('(keyword~="services-1")')
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services?filter={filter_param}')
+ data = self.get_response_data(response)
+ self.assertEqual(1, len(data))
+
+ sorter_param = urllib.parse.quote('created_at asc')
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services?order_by={sorter_param}')
+ data = self.get_response_data(response)
+ self.assertEqual(2, len(data))
+ self.assertEqual([name1, name2], [data[0]['name'], data[1]['name']])
+
+ sorter_param = urllib.parse.quote('created_at desc')
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services?order_by={sorter_param}')
+ data = self.get_response_data(response)
+ self.assertEqual(2, len(data))
+ self.assertEqual([name2, name1], [data[0]['name'], data[1]['name']])
+
+ sorter_param = urllib.parse.quote('something_unsupported desc')
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services?order_by={sorter_param}')
+ data = self.get_response_data(response)
+ self.assertIsNone(data)
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.get_pods_info')
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_get_serving_service(self, mock_create_deployment: MagicMock, mock_get_pods_info):
+ test_datetime = datetime(2022, 1, 1, 8, 8, 8, tzinfo=timezone.utc)
+ fake_pods = [
+ Pod(name='pod0', state=PodState.FAILED, creation_timestamp=to_timestamp(test_datetime)),
+ Pod(name='pod1', state=PodState.RUNNING, creation_timestamp=to_timestamp(test_datetime) - 1),
+ Pod(name='pod2', state=PodState.SUCCEEDED, creation_timestamp=to_timestamp(test_datetime) + 1)
+ ]
+ mock_get_pods_info.return_value = fake_pods
+ # create
+ name1 = 'test-get-services-1'
+ serving_service = _get_create_serving_service_input_v2(name1, model_id=self.model_id_1)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(HTTPStatus.CREATED, response.status_code)
+ data = self.get_response_data(response)
+ serving_service_id = data['id']
+ # get one
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{serving_service_id}')
+ data = self.get_response_data(response)
+ self.assertEqual(name1, data['name'])
+ self.assertEqual({'cpu': '2', 'memory': '2', 'replicas': 3}, data['resource'])
+ self.assertEqual(f'/api/v2/projects/{self.project_id}/serving_services/{serving_service_id}/inference',
+ data['endpoint'])
+ self.assertEqual('UNKNOWN', data['instance_num_status'])
+ self.assertEqual(['pod0', 'pod1', 'pod2'], [x['name'] for x in data['instances']])
+ sorter_param = urllib.parse.quote('created_at asc')
+ response = self.get_helper(
+ f'/api/v2/projects/{self.project_id}/serving_services/{serving_service_id}?order_by={sorter_param}')
+ data = self.get_response_data(response)
+ self.assertEqual(3, len(data['instances']))
+ self.assertEqual(['pod1', 'pod0', 'pod2'], [x['name'] for x in data['instances']])
+ sorter_param = urllib.parse.quote('created_at desc')
+ response = self.get_helper(
+ f'/api/v2/projects/{self.project_id}/serving_services/{serving_service_id}?order_by={sorter_param}')
+ data = self.get_response_data(response)
+ self.assertEqual(3, len(data['instances']))
+ self.assertEqual(['pod2', 'pod0', 'pod1'], [x['name'] for x in data['instances']])
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_update_serving_service(self, mock_create_deployment: MagicMock):
+ # create
+ name = 'test-update-service-1'
+ serving_service = _get_create_serving_service_input_v2(name, model_id=self.model_id_1)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ data = self.get_response_data(response)
+ service_id = data['id'] # get id from create response
+ # update comments
+ new_comment = 'test-comment-2'
+ serving_service = {
+ 'comment': new_comment,
+ 'resource': {
+ 'cpu': '2',
+ 'memory': '2',
+ 'replicas': 3,
+ }
+ }
+ response = self.patch_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}',
+ data=serving_service)
+ data = self.get_response_data(response)
+ self.assertEqual(data['comment'], new_comment)
+
+ # change from model_id to model_group_id
+ serving_service = {
+ 'model_group_id': self.model_group_id,
+ }
+ response = self.patch_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}',
+ data=serving_service)
+ data = self.get_response_data(response)
+ self.assertEqual(self.model_id_2, data['model_id'])
+ self.assertEqual(self.model_group_id, data['model_group_id'])
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(service_id)
+ self.assertEqual('/test_path_2/exported_models', serving_model.model_path)
+
+ # change from model_group_id to model_id
+ serving_service = {
+ 'model_id': self.model_id_1,
+ }
+ response = self.patch_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}',
+ data=serving_service)
+ data = self.get_response_data(response)
+ self.assertEqual(0, data['model_group_id'])
+ self.assertEqual(self.model_id_1, data['model_id'])
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(service_id)
+ self.assertEqual('/test_path_1/exported_models', serving_model.model_path)
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ @patch('fedlearner_webconsole.serving.services.k8s_client')
+ def test_delete_serving_service(self, mock_k8s_client: MagicMock, mock_create_deployment: MagicMock):
+ mock_k8s_client.delete_config_map = MagicMock()
+ mock_k8s_client.delete_deployment = MagicMock()
+ mock_k8s_client.delete_service = MagicMock()
+ # create
+ name = 'test-delete-service-1'
+ serving_service = _get_create_serving_service_input_v2(name, model_id=self.model_id_1)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ data = self.get_response_data(response)
+ service_id = data['id'] # get id from create response
+ # delete
+ response = self.delete_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}')
+ self.assertEqual(response.status_code, HTTPStatus.NO_CONTENT)
+ # get
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}')
+ self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+ data = self.get_response_data(response)
+ self.assertIsNone(data)
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ @patch('fedlearner_webconsole.serving.services.k8s_client.get_pod_log')
+ def test_get_serving_service_instance_log(self, mock_query_log: MagicMock, mock_create_deployment: MagicMock):
+ # create
+ name = 'test-get-service-instance-log-1'
+ serving_service = _get_create_serving_service_input_v2(name, model_id=self.model_id_1)
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ mock_create_deployment.assert_called_once()
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ data = self.get_response_data(response)
+ service_id = data['id'] # get id from create response
+ # get
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}')
+ data = self.get_response_data(response)
+ self.assertEqual(len(data['instances']), 1)
+ instance_name = data['instances'][0]['name'] # get id from create response
+ # get log
+ mock_query_log.return_value = ['test', 'hello']
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}'
+ f'/instances/{instance_name}/log?tail_lines={500}')
+ mock_query_log.assert_called_once_with(instance_name, namespace=Envs.K8S_NAMESPACE, tail_lines=500)
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ self.assertCountEqual(self.get_response_data(response), ['test', 'hello'])
+
+ @patch('fedlearner_webconsole.utils.flask_utils.get_current_sso', MagicMock(return_value='test-sso'))
+ def test_remote_platform_serving_service(self):
+ reckon_remote_serving = FakeRemoteServing()
+ register_remote_serving(FakeRemoteServing.SERVING_PLATFORM, reckon_remote_serving)
+ # get
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/remote_platforms')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(1, len(data))
+ self.assertEqual(FakeRemoteServing.SERVING_PLATFORM, data[0]['platform'])
+ self.assertEqual('', data[0]['payload'])
+
+ # create serving service
+ name = 'test-remote-serving-1'
+ serving_service = {
+ 'name': name,
+ 'model_group_id': self.model_group_id,
+ 'is_local': True,
+ 'remote_platform': {
+ 'platform': FakeRemoteServing.SERVING_PLATFORM,
+ 'payload': 'test-payload',
+ }
+ }
+ response = self.post_helper(f'/api/v2/projects/{self.project_id}/serving_services', data=serving_service)
+ self.assertEqual(HTTPStatus.CREATED, response.status_code)
+ data = self.get_response_data(response)
+ self.assertEqual({'platform': 'unittest_mock', 'payload': 'test-payload'}, data['remote_platform'])
+ service_id = data['id']
+
+ # get list
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services')
+ data = self.get_response_data(response)
+ self.assertEqual(1, len(data))
+ self.assertEqual('AVAILABLE', data[0]['status'])
+ self.assertTrue(data[0]['support_inference'])
+
+ # get one
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}')
+ data = self.get_response_data(response)
+ self.assertEqual('test_deploy_url', data['endpoint'])
+ self.assertEqual('AVAILABLE', data['status'])
+ self.assertTrue(data['support_inference'])
+ self.assertEqual({'platform': 'unittest_mock', 'payload': 'test-payload'}, data['remote_platform'])
+
+ # change from model_group_id to model_id
+ serving_service = {
+ 'model_id': self.model_id_1,
+ }
+ response = self.patch_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}',
+ data=serving_service)
+ data = self.get_response_data(response)
+ self.assertEqual(0, data['model_group_id'])
+ self.assertEqual(self.model_id_1, data['model_id'])
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(service_id)
+ self.assertEqual('/test_path_1/exported_models', serving_model.model_path)
+ item = session.query(SchedulerItem.id).filter(SchedulerItem.name.like(f'%{name}%')).first()
+ self.assertIsNone(item)
+
+ # delete
+ response = self.delete_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}')
+ self.assertEqual(response.status_code, HTTPStatus.NO_CONTENT)
+ response = self.get_helper(f'/api/v2/projects/{self.project_id}/serving_services/{service_id}')
+ data = self.get_response_data(response)
+ self.assertIsNone(data)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/database_fetcher.py b/web_console_v2/api/fedlearner_webconsole/serving/database_fetcher.py
new file mode 100644
index 000000000..08d1c7fb8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/database_fetcher.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+import json
+
+
+class DatabaseFetcher:
+
+ @staticmethod
+ def fetch_by_int_key(query_key: int, signature: str) -> dict:
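+ """Builds the feature dict for one query key: every input declared in the
+ signature gets placeholder values of matching type and dimension (the real
+ data-source lookup is still a TODO, see below)."""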
+ result = {'raw_id': [query_key], 'example_id': [str(query_key)]}
+ signature_dict = json.loads(signature)
+ signature_input = signature_dict['inputs']
+ for item in signature_input:
+ input_name = item['name']
+ input_type = item['type']
+ result[input_name] = []
+ if 'dim' in item:
+ dim = int(item['dim'][0])
+ else:
+ dim = 1
+ for _ in range(0, dim):
+ # TODO(lixiaoguang.01) fetch from ABase, match name
+ if input_type == 'DT_STRING':
+ result[input_name].append('')
+ elif input_type in ('DT_FLOAT', 'DT_DOUBLE'):
+ result[input_name].append(0.1)
+ elif input_type in ('DT_INT64', 'DT_INT32'):
+ result[input_name].append(1)
+ return result
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/metrics.py b/web_console_v2/api/fedlearner_webconsole/serving/metrics.py
new file mode 100644
index 000000000..9ad3e2077
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/metrics.py
@@ -0,0 +1,30 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+from fedlearner_webconsole.serving.models import ServingModel
+from fedlearner_webconsole.utils import metrics
+
+
+def serving_metrics_emit_counter(name: str, serving_model: Optional[ServingModel] = None):
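+ """Emits a counter metric, tagged with project and serving model info when a model is given."""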
+ if serving_model is None:
+ metrics.emit_counter(name, 1)
+ return
+ metrics.emit_counter(name,
+ 1,
+ tags={
+ 'project_id': str(serving_model.project_id),
+ 'serving_model_id': str(serving_model.id),
+ 'serving_model_name': serving_model.name,
+ })
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/models.py b/web_console_v2/api/fedlearner_webconsole/serving/models.py
new file mode 100644
index 000000000..64a138fa9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/models.py
@@ -0,0 +1,155 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+import json
+
+from sqlalchemy.sql.schema import UniqueConstraint, Index
+
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.proto import serving_pb2
+from fedlearner_webconsole.utils.pp_datetime import now, to_timestamp
+from fedlearner_webconsole.mmgr.models import ModelType
+
+
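+# Status lifecycle used by the serving APIs and runners: federal services start in
+# PENDING_ACCEPT until participants respond; services created on a participant's
+# request stay in WAITING_CONFIG until the local side configures a model and resources.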
+class ServingModelStatus(enum.Enum):
+ UNKNOWN = 0
+ LOADING = 1
+ AVAILABLE = 2
+ UNLOADING = 3
+ PENDING_ACCEPT = 4
+ DELETED = 5
+ WAITING_CONFIG = 6
+
+
+class ServingDeploymentStatus(enum.Enum):
+ UNAVAILABLE = 0
+ AVAILABLE = 1
+
+
+class ServingModel(db.Model):
+ __tablename__ = 'serving_models_v2'
+ __table_args__ = (UniqueConstraint('name', name='uniq_name'), default_table_args('serving models'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ project_id = db.Column(db.Integer, nullable=False, comment='project id')
+ name = db.Column(db.String(255), comment='name')
+ serving_deployment_id = db.Column(db.Integer, comment='serving deployment db id')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment')
+ model_id = db.Column(db.Integer, comment='model id')
+ model_type = db.Column(db.Enum(ModelType, native_enum=False, length=64, create_constraint=False),
+ default=ModelType.NN_MODEL,
+ comment='model type')
+ model_path = db.Column(db.String(255), default=None, comment='model\'s path')
+ model_group_id = db.Column(db.Integer, comment='model group id for auto update scenario')
+ pending_model_id = db.Column(db.Integer, comment='model id when waiting for participants\' config')
+ pending_model_group_id = db.Column(db.Integer, comment='model group id when waiting for participants\' config')
+ signature = db.Column(db.Text(), default='', comment='model signature')
+ status = db.Column(db.Enum(ServingModelStatus, native_enum=False, length=64, create_constraint=False),
+ default=ServingModelStatus.UNKNOWN,
+ comment='status')
+ endpoint = db.Column(db.String(255), comment='endpoint')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created_at', default=now)
+ updated_at = db.Column(db.DateTime(timezone=True), comment='updated_at', default=now, onupdate=now)
+ extra = db.Column(db.Text(), comment='extra')
+
+ project = db.relationship('Project', primaryjoin='Project.id == foreign(ServingModel.project_id)')
+ serving_deployment = db.relationship('ServingDeployment',
+ primaryjoin='ServingDeployment.id == '
+ 'foreign(ServingModel.serving_deployment_id)')
+ model = db.relationship('Model', primaryjoin='Model.id == foreign(ServingModel.model_id)')
+ pending_model = db.relationship('Model', primaryjoin='Model.id == foreign(ServingModel.pending_model_id)')
+ model_group = db.relationship('ModelJobGroup',
+ primaryjoin='ModelJobGroup.id == foreign(ServingModel.model_group_id)')
+ pending_model_group = db.relationship('ModelJobGroup',
+ primaryjoin='ModelJobGroup.id == '
+ 'foreign(ServingModel.pending_model_group_id)')
+
+ def to_serving_service(self) -> serving_pb2.ServingService:
+ return serving_pb2.ServingService(id=self.id,
+ project_id=self.project_id,
+ name=self.name,
+ comment=self.comment,
+ is_local=True,
+ status=self.status.name,
+ support_inference=False,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at))
+
+ def to_serving_service_detail(self) -> serving_pb2.ServingServiceDetail:
+ detail = serving_pb2.ServingServiceDetail(id=self.id,
+ project_id=self.project_id,
+ name=self.name,
+ comment=self.comment,
+ model_id=self.model_id,
+ model_group_id=self.model_group_id,
+ model_type=self.model_type.name,
+ is_local=True,
+ endpoint=self.endpoint,
+ signature=self.signature,
+ status=self.status.name,
+ support_inference=False,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at))
+ if self.serving_deployment.is_remote_serving():
+ platform_config: dict = json.loads(self.serving_deployment.deploy_platform)
+ detail.remote_platform.CopyFrom(
+ serving_pb2.ServingServiceRemotePlatform(
+ platform=platform_config['platform'],
+ payload=platform_config['payload'],
+ ))
+ return detail
+
+
+class ServingDeployment(db.Model):
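+ """Deployment details for a serving model: resources, endpoint, and the optional
+ third-party platform config (deploy_platform is None for in-platform serving)."""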
+ __tablename__ = 'serving_deployments_v2'
+ __table_args__ = (default_table_args('serving deployments in webconsole'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ project_id = db.Column(db.Integer, nullable=False, comment='project id')
+ deployment_name = db.Column(db.String(255), comment='deployment name')
+ resource = db.Column('rsc', db.String(255), comment='resource')
+ endpoint = db.Column(db.String(255), comment='endpoint')
+ deploy_platform = db.Column(db.Text(), comment='deploy platform. None means inside this platform')
+ status = db.Column(db.Enum(ServingDeploymentStatus, native_enum=False, length=64, create_constraint=False),
+ default=ServingDeploymentStatus.UNAVAILABLE,
+ comment='status')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created_at', default=now)
+ extra = db.Column(db.Text(), comment='extra')
+
+ project = db.relationship('Project', primaryjoin='Project.id == foreign(ServingDeployment.project_id)')
+
+ def is_remote_serving(self) -> bool:
+ return self.deploy_platform is not None
+
+
+class ServingNegotiator(db.Model):
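+ """Cross-participant state of a serving service: the uuid shared with peers, the
+ raw signature from TF serving, and whether inference can run locally."""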
+ __tablename__ = 'serving_negotiators_v2'
+ __table_args__ = (Index('idx_serving_model_uuid',
+ 'serving_model_uuid'), default_table_args('serving negotiators in webconsole'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ project_id = db.Column(db.Integer, nullable=False, comment='project id')
+ serving_model_id = db.Column(db.Integer, nullable=False, comment='serving model id')
+ is_local = db.Column(db.Boolean, comment='can serving locally')
+ with_label = db.Column(db.Boolean, comment='federal side with label or not')
+ serving_model_uuid = db.Column(db.String(255), comment='uuid for federal model')
+ feature_dataset_id = db.Column(db.Integer, comment='feature dataset id')
+ data_source_map = db.Column(db.Text(), comment='where to get model inference arguments')
+ raw_signature = db.Column(db.Text(), comment='save raw signature from tf serving')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created_at', default=now)
+ extra = db.Column(db.Text(), comment='extra')
+
+ project = db.relationship('Project', primaryjoin='Project.id == foreign(ServingNegotiator.project_id)')
+ serving_model = db.relationship('ServingModel',
+ primaryjoin='ServingModel.id == '
+ 'foreign(ServingNegotiator.serving_model_id)')
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/models_test.py b/web_console_v2/api/fedlearner_webconsole/serving/models_test.py
new file mode 100644
index 000000000..6bcff5367
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/models_test.py
@@ -0,0 +1,100 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import unittest
+
+from google.protobuf.json_format import MessageToDict
+
+from fedlearner_webconsole.mmgr.models import ModelJobGroup, Model, ModelType
+from fedlearner_webconsole.proto import serving_pb2
+from fedlearner_webconsole.serving.models import ServingModel, ServingModelStatus, ServingDeployment
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.db import db
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ServingModelTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ model_job_group = ModelJobGroup()
+ session.add(model_job_group)
+ session.flush([model_job_group])
+
+ model = Model()
+ model.name = 'test_model_name_1'
+ model.project_id = 1
+ model.group_id = model_job_group.id
+ session.add(model)
+ session.flush([model])
+
+ deployment = ServingDeployment()
+ deployment.project_id = 1
+ deploy_config = serving_pb2.RemoteDeployConfig(platform='test-platform',
+ payload='test-payload',
+ deploy_name='privacy-platform-test-serving',
+ model_src_path='')
+ deployment.deploy_platform = json.dumps(MessageToDict(deploy_config))
+ session.add(deployment)
+ session.flush([deployment])
+
+ serving_model = ServingModel()
+ serving_model.project_id = 1
+ serving_model.name = 'test-serving-model-1'
+ serving_model.model_id = model.id
+ serving_model.serving_deployment_id = deployment.id
+ session.add(serving_model)
+ session.commit()
+
+ self.model_group_id = model_job_group.id
+ self.model_id = model.id
+ self.serving_model_id = serving_model.id
+
+ def test_to_serving_service_detail(self):
+ with db.session_scope() as session:
+ serving_model: ServingModel = session.query(ServingModel).get(self.serving_model_id)
+ expected_detail = serving_pb2.ServingServiceDetail(id=self.serving_model_id,
+ project_id=1,
+ name='test-serving-model-1',
+ model_id=self.model_id,
+ model_type=ModelType.NN_MODEL.name,
+ is_local=True,
+ status=ServingModelStatus.UNKNOWN.name,
+ support_inference=False)
+ self.assertPartiallyEqual(to_dict(expected_detail),
+ to_dict(serving_model.to_serving_service_detail()),
+ ignore_fields=['created_at', 'updated_at', 'remote_platform'])
+ self.assertEqual(0, serving_model.to_serving_service_detail().model_group_id)
+
+ serving_model.model_group_id = self.model_group_id
+ expected_detail.model_group_id = self.model_group_id
+ self.assertPartiallyEqual(to_dict(expected_detail),
+ to_dict(serving_model.to_serving_service_detail()),
+ ignore_fields=['created_at', 'updated_at', 'remote_platform'])
+
+ expected_detail.remote_platform.CopyFrom(
+ serving_pb2.ServingServiceRemotePlatform(
+ platform='test-platform',
+ payload='test-payload',
+ ))
+ self.assertPartiallyEqual(to_dict(expected_detail),
+ to_dict(serving_model.to_serving_service_detail()),
+ ignore_fields=['created_at', 'updated_at'])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/participant_fetcher.py b/web_console_v2/api/fedlearner_webconsole/serving/participant_fetcher.py
new file mode 100644
index 000000000..9a90e17db
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/participant_fetcher.py
@@ -0,0 +1,76 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+import json
+
+from google.protobuf.json_format import MessageToDict
+
+from fedlearner_webconsole.db import Session
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.proto import serving_pb2
+from fedlearner_webconsole.serving.models import ServingNegotiator
+from fedlearner_webconsole.serving.services import NegotiatorServingService
+
+
+class ParticipantFetcher:
+ _TF_DT_INT_TYPE_SET = {'DT_INT32', 'DT_INT16', 'DT_UINT16', 'DT_INT8', 'DT_UINT8'}
+
+ def __init__(self, session: Session = None):
+ self._session = session
+
+ def fetch(self, serving_negotiator: ServingNegotiator, example_id: str) -> dict:
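+ """Queries the participant's serving service for this example_id and maps the
+ returned tensors back to the input names declared in the local signature's
+ 'from_participants' section. Returns {} for local-only serving."""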
+ if serving_negotiator.is_local:
+ return {}
+ resp = NegotiatorServingService(self._session).participant_serving_service_inference(
+ serving_negotiator, example_id)
+ if resp.code != serving_pb2.SERVING_SERVICE_SUCCESS:
+ raise InternalException(resp.msg)
+ data = MessageToDict(resp.data)
+ participant_result = data['result']
+ signature = serving_negotiator.serving_model.signature
+ signature_dict = json.loads(signature)
+ signature_extend = signature_dict['from_participants']
+ assert len(signature_extend) == len(participant_result), \
+ f'Result size mismatch: expected {len(signature_extend)}, got {len(participant_result)}'
+ result = {}
+ for item_key in participant_result:
+ if item_key not in signature_extend and len(participant_result) > 1:
+ continue
+ input_key = item_key
+ if len(participant_result) == 1:
+ input_key = list(signature_extend.keys())[0]
+ dtype = participant_result[item_key]['dtype']
+ result[input_key] = self._get_value_by_dtype(dtype, participant_result[item_key])
+ return result
+
+ def _get_value_by_dtype(self, dtype: str, input_data: dict):
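+ # Select the value field in the response dict that corresponds to the tensor dtype.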
+ if dtype == 'DT_FLOAT':
+ return input_data['floatVal']
+ if dtype == 'DT_DOUBLE':
+ return input_data['doubleVal']
+ if dtype in self._TF_DT_INT_TYPE_SET:
+ return input_data['intVal']
+ if dtype == 'DT_INT64':
+ return input_data['int64Val']
+ if dtype == 'DT_UINT32':
+ return input_data['uint32Val']
+ if dtype == 'DT_UINT64':
+ return input_data['uint64Val']
+ if dtype == 'DT_STRING':
+ return input_data['stringVal']
+ if dtype == 'DT_BOOL':
+ return input_data['boolVal']
+ return ''
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/remote.py b/web_console_v2/api/fedlearner_webconsole/serving/remote.py
new file mode 100644
index 000000000..2a7252751
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/remote.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import abstractmethod, ABCMeta
+from typing import Optional
+
+from fedlearner_webconsole.proto.serving_pb2 import RemoteDeployConfig, RemoteDeployState
+
+
+class IRemoteServing(metaclass=ABCMeta):
+ """Deploy model on remote third-party serving platform
+
+ """
+
+ @abstractmethod
+ def deploy_model(self, creator: str, config: RemoteDeployConfig) -> Optional[int]:
+ pass
+
+ @abstractmethod
+ def get_deploy_url(self, config: RemoteDeployConfig) -> str:
+ pass
+
+ @abstractmethod
+ def validate_config(self, config: RemoteDeployConfig) -> bool:
+ pass
+
+ @abstractmethod
+ def get_deploy_status(self, config: RemoteDeployConfig) -> RemoteDeployState:
+ return RemoteDeployState.REMOTE_DEPLOY_READY
+
+ @abstractmethod
+ def undeploy_model(self, config: RemoteDeployConfig):
+ pass
+
+
+supported_remote_serving = {}
+
+
+def register_remote_serving(name: str, serving: IRemoteServing):
+ supported_remote_serving[name] = serving
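+
+
+# Illustrative sketch (not part of the original module): a hypothetical third-party platform
+# would implement IRemoteServing and register an instance at import time, e.g.
+#
+#   class MyPlatformServing(IRemoteServing):
+#       def deploy_model(self, creator, config): ...
+#       def get_deploy_url(self, config): ...
+#       def validate_config(self, config): ...
+#       def get_deploy_status(self, config): ...
+#       def undeploy_model(self, config): ...
+#
+#   register_remote_serving('my_platform', MyPlatformServing())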
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/runners.py b/web_console_v2/api/fedlearner_webconsole/serving/runners.py
new file mode 100644
index 000000000..424950179
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/runners.py
@@ -0,0 +1,194 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import hashlib
+import json
+import logging
+import os
+import traceback
+from typing import Tuple
+
+from google.protobuf.json_format import MessageToDict
+from multiprocessing import Queue
+from sqlalchemy.orm import Session, joinedload
+
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import ItemType, IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import NotFoundException
+from fedlearner_webconsole.proto import serving_pb2
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, RunnerOutput
+from fedlearner_webconsole.proto.serving_pb2 import ServingServiceType
+from fedlearner_webconsole.serving.models import ServingModel, ServingNegotiator, ServingModelStatus, ServingDeployment
+from fedlearner_webconsole.serving.services import NegotiatorServingService, SavedModelService, \
+ TensorflowServingService, ServingModelService
+from fedlearner_webconsole.utils import pp_datetime
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.pp_time import sleep
+from fedlearner_webconsole.utils.process_utils import get_result_by_sub_process
+from fedlearner_webconsole.mmgr.service import ModelJobGroupService
+from fedlearner_webconsole.project.models import Project
+
+
+def _update_parsed_signature(q: Queue, model_path: str):
+ file_manager = FileManager()
+ exported_dirs = file_manager.ls(model_path, include_directory=True)
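+    # Exported model versions are numeric sub-directories; parse the saved_model.pb of the newest one.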
+ newest_version = max([int(os.path.basename(v.path)) for v in exported_dirs if os.path.basename(v.path).isnumeric()])
+ pb_path = os.path.join(model_path, str(newest_version), 'saved_model.pb')
+ saved_model_bytes = file_manager.read_bytes(pb_path)
+ signature_from_saved_model = SavedModelService.get_parse_example_details(saved_model_bytes)
+ q.put(signature_from_saved_model)
+
+
+class ModelSignatureParser(IRunnerV2):
+ """ Parse example from model saved path
+ """
+
+ def __init__(self) -> None:
+ self.PARSE_TIMES = 10
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ serving_model_id = context.input.model_signature_parser_input.serving_model_id
+ try:
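+            # TF Serving may not have loaded the model yet, so retry up to PARSE_TIMES times with a short sleep in between.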
+ for num in range(0, self.PARSE_TIMES):
+ with db.session_scope() as session:
+ # update parsed signature in serving model
+ serving_model = session.query(ServingModel).filter_by(id=serving_model_id).one_or_none()
+ if not serving_model:
+ raise NotFoundException(f'Failed to find serving model: {serving_model_id}')
+ signature_from_saved_model = get_result_by_sub_process(name='serving parse signature',
+ target=_update_parsed_signature,
+ kwargs={
+ 'model_path': serving_model.model_path,
+ })
+ signature_dict = MessageToDict(signature_from_saved_model)
+ # update raw signature in serving negotiator
+ deployment_name = serving_model.serving_deployment.deployment_name
+ tf_serving_service = TensorflowServingService(deployment_name)
+ signature_from_tf = tf_serving_service.get_model_signature()
+ raw_signature = json.dumps(signature_from_tf)
+ if len(raw_signature) > 0 and raw_signature != '{}':
+ update_serving_negotiator = session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model.id).one_or_none()
+ update_serving_negotiator.raw_signature = raw_signature
+ # add outputs to parsed signature
+ signature_dict['outputs'] = signature_from_tf['outputs']
+ signature_dict['from_participants'] = signature_from_tf['inputs']
+ if 'examples' in signature_dict['from_participants']:
+ signature_dict['from_participants'].pop('examples')
+ serving_model.signature = json.dumps(signature_dict)
+ session.commit()
+ return RunnerStatus.DONE, RunnerOutput()
+ sleep(3)
+ except Exception: # pylint: disable=broad-except
+ error_message = f'[ModelSignatureParser] failed to run, serving id={serving_model_id}'
+ logging.exception(error_message)
+ return RunnerStatus.FAILED, RunnerOutput(error_message='[ModelSignatureParser] failed to get signature from tf')
+
+ @staticmethod
+ def generate_task_name(serving_model_id: int, name: str):
+ hash_value = hashlib.sha256(str(pp_datetime.now()).encode('utf8'))
+ return f'parse_signature_{serving_model_id}_{name}_{hash_value.hexdigest()[0:6]}'
+
+
+class QueryParticipantStatusRunner(IRunnerV2):
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ self._auto_run_query()
+ return RunnerStatus.DONE, RunnerOutput()
+
+ @staticmethod
+ def _auto_run_query():
+ with db.session_scope() as session:
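+            # Only federated (non-local) serving services still waiting for the participant's acceptance need polling.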
+ query = session.query(ServingNegotiator)
+ query = query.filter(ServingNegotiator.is_local.is_(False))
+ query = query.outerjoin(ServingNegotiator.serving_model).options(joinedload(
+ ServingNegotiator.serving_model)).filter(ServingModel.status == ServingModelStatus.PENDING_ACCEPT)
+ query = query.outerjoin(Project, Project.id == ServingNegotiator.project_id).options(
+ joinedload(ServingNegotiator.project))
+ all_records = query.all()
+ for serving_negotiator in all_records:
+ with db.session_scope() as session:
+ serving_model = serving_negotiator.serving_model
+ try:
+ result = NegotiatorServingService(session).operate_participant_serving_service(
+ serving_negotiator, ServingServiceType.SERVING_SERVICE_QUERY)
+ if result == serving_pb2.SERVING_SERVICE_SUCCESS:
+ serving_model.status = ServingModelStatus.LOADING
+ session.add(serving_model)
+ session.commit()
+ except Exception as e: # pylint: disable=broad-except
+                        logging.warning(f'[QueryParticipantStatusRunner] failed to query participant status'
+                                        f' for {serving_model.name}: {e}, trace: {traceback.format_exc()}')
+
+
+class UpdateModelRunner(IRunnerV2):
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ self._auto_run_update()
+ return RunnerStatus.DONE, RunnerOutput()
+
+ @staticmethod
+ def _auto_run_update():
+ with db.session_scope() as session:
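+            # Only serving models bound to a model group are auto-updated to the group's latest model.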
+ all_records = session.query(ServingModel).filter(ServingModel.model_group_id.isnot(None)).outerjoin(
+ ServingDeployment, ServingDeployment.id == ServingModel.serving_deployment_id).options(
+ joinedload(ServingModel.serving_deployment)).all()
+ for serving_model in all_records:
+ with db.session_scope() as session:
+ try:
+ model = ModelJobGroupService(session).get_latest_model_from_model_group(
+ serving_model.model_group_id)
+ if serving_model.model_id == model.id:
+ # already serving the latest model
+ continue
+ serving_model.model_id = model.id
+ serving_model.model_path = model.get_exported_model_path()
+ if serving_model.serving_deployment.is_remote_serving():
+ ServingModelService(session).update_remote_serving_model(serving_model)
+ session.add(serving_model)
+ session.commit()
+ except Exception as e: # pylint: disable=broad-except
+                        logging.warning(
+                            f'[UpdateModelRunner] failed to auto-update model for {serving_model.name}: {e}, '
+                            f'trace: {traceback.format_exc()}')
+
+
+def start_query_participant(session: Session):
+ composer_service_name = 'serving_model_query_participant_status_v2'
+ composer_service = ComposerService(session)
+ if composer_service.get_item_status(composer_service_name) is not None:
+ return
+ runner_input = RunnerInput()
+ composer_service.collect_v2(
+ name=composer_service_name,
+ items=[(ItemType.SERVING_SERVICE_QUERY_PARTICIPANT_STATUS, runner_input)],
+        # cron job runs every 10 seconds
+ cron_config='* * * * * */10')
+
+
+def start_update_model(session: Session):
+ composer_service_name = 'serving_model_update_model_v2'
+ composer_service = ComposerService(session)
+ if composer_service.get_item_status(composer_service_name) is not None:
+ return
+ runner_input = RunnerInput()
+ composer_service.collect_v2(
+ name=composer_service_name,
+ items=[(ItemType.SERVING_SERVICE_UPDATE_MODEL, runner_input)],
+        # cron job runs every 30 seconds
+ cron_config='* * * * * */30')
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/services.py b/web_console_v2/api/fedlearner_webconsole/serving/services.py
new file mode 100644
index 000000000..c4de6e47d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/services.py
@@ -0,0 +1,779 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import hashlib
+import json
+import grpc
+import logging
+
+from google.protobuf import json_format
+from google.protobuf.json_format import MessageToDict
+from tensorflow import make_tensor_proto
+from tensorflow.core.protobuf import saved_model_pb2 # pylint: disable=no-name-in-module
+from tensorflow.core.framework import graph_pb2, node_def_pb2, types_pb2 # pylint: disable=no-name-in-module
+from tensorflow.core.example.example_pb2 import Example # pylint: disable=no-name-in-module
+from tensorflow.core.example.feature_pb2 import Int64List, Feature, FloatList, BytesList, Features # pylint: disable=no-name-in-module
+from tensorflow_serving.apis import (get_model_status_pb2, model_service_pb2_grpc, get_model_metadata_pb2,
+ prediction_service_pb2_grpc, predict_pb2)
+from tensorflow_serving.apis.get_model_metadata_pb2 import SignatureDefMap
+from tensorflow_serving.apis.get_model_status_pb2 import ModelVersionStatus
+from typing import Dict, List, Tuple, Optional
+from kubernetes.client.models.v1_pod_list import V1PodList
+from envs import Envs
+from fedlearner_webconsole.exceptions import InternalException, ResourceConflictException, NotFoundException, \
+ InvalidArgumentException
+from fedlearner_webconsole.mmgr.models import Model
+from fedlearner_webconsole.mmgr.service import ModelJobGroupService
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto import common_pb2, service_pb2, serving_pb2
+from fedlearner_webconsole.proto.serving_pb2 import ServingServiceType, RemoteDeployState
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.serving import remote
+from fedlearner_webconsole.serving.database_fetcher import DatabaseFetcher
+from fedlearner_webconsole.serving.metrics import serving_metrics_emit_counter
+from fedlearner_webconsole.serving.utils import get_model, get_serving_negotiator_by_serving_model_id
+from fedlearner_webconsole.utils import pp_datetime, flask_utils
+from fedlearner_webconsole.k8s.models import Pod, PodState
+from fedlearner_webconsole.db import Session
+from fedlearner_webconsole.k8s.k8s_client import k8s_client
+from fedlearner_webconsole.serving.models import ServingModel, ServingNegotiator, ServingModelStatus, ServingDeployment
+from fedlearner_webconsole.serving.serving_yaml_template import (generate_serving_yaml, DEPLOYMENT_TEMPLATE,
+ CONFIG_MAP_TEMPLATE, SERVICE_TEMPLATE)
+from fedlearner_webconsole.utils.const import API_VERSION
+from fedlearner_webconsole.proto.serving_pb2 import (ServingServiceInstance, ServingServiceSignature,
+ ServingServiceSignatureInput)
+from fedlearner_webconsole.utils.sorting import SortExpression
+
+
+class ServingDeploymentService:
+
+ def __init__(self, session: Session = None):
+ self._session = session
+
+ @staticmethod
+ def get_base_path(serving_model_id: int):
+        # TODO(wangsen.0914): should have a dedicated serving storage filesystem. /cc @lixiaoguang.01
+ return f'test/{serving_model_id}'
+
+ def _get_serving_object_definition(self, serving_model: ServingModel) -> Tuple[Dict, Dict, Dict]:
+ """get all kubernetes definition
+
+ Returns:
+ configMap, Deployment, Service
+ """
+ resource = json.loads(serving_model.serving_deployment.resource)
+ if serving_model.model_path is not None:
+ model_path = serving_model.model_path
+ else:
+ model_path = self.get_base_path(serving_model.id)
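+        # Values consumed by the serving YAML templates (ConfigMap, Deployment, Service).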
+ serving_config = {
+ 'project': serving_model.project,
+ 'model': {
+ 'base_path': model_path
+ },
+ 'serving': {
+ 'name': serving_model.serving_deployment.deployment_name,
+ 'resource': {
+ 'resource': {
+ 'cpu': resource['cpu'],
+ 'memory': resource['memory']
+ },
+ 'replicas': resource['replicas'],
+ },
+ }
+ }
+ config_map_object = generate_serving_yaml(serving_config, CONFIG_MAP_TEMPLATE, self._session)
+ deployment_object = generate_serving_yaml(serving_config, DEPLOYMENT_TEMPLATE, self._session)
+ service_object = generate_serving_yaml(serving_config, SERVICE_TEMPLATE, self._session)
+
+ return config_map_object, deployment_object, service_object
+
+ def create_or_update_deployment(self, serving_model: ServingModel):
+ """post a bunch of k8s resources.
+
+ Raises:
+ Raises RuntimeError if k8s post ops failed. Then you should call `session.rollback()`.
+ """
+ config_map_object, deployment_object, service_object = self._get_serving_object_definition(serving_model)
+
+        # ConfigMap and Service are core-API resources, so the generic *_app helpers cannot be used for them.
+ k8s_client.create_or_update_config_map(metadata=config_map_object['metadata'],
+ data=config_map_object['data'],
+ name=config_map_object['metadata']['name'],
+ namespace=Envs.K8S_NAMESPACE)
+
+ k8s_client.create_or_update_app(app_yaml=deployment_object,
+ group='apps',
+ version='v1',
+ plural='deployments',
+ namespace=Envs.K8S_NAMESPACE)
+
+ k8s_client.create_or_update_service(metadata=service_object['metadata'],
+ spec=service_object['spec'],
+ name=service_object['metadata']['name'],
+ namespace=Envs.K8S_NAMESPACE)
+
+ def delete_deployment(self, serving_model: ServingModel):
+ """delete a bunch of k8s resources.
+ """
+ try:
+ config_map_object, deployment_object, service_object = self._get_serving_object_definition(serving_model)
+
+            # ConfigMap and Service are core-API resources, so the generic *_app helpers cannot be used for them.
+ k8s_client.delete_config_map(name=config_map_object['metadata']['name'], namespace=Envs.K8S_NAMESPACE)
+
+ k8s_client.delete_app(app_name=deployment_object['metadata']['name'],
+ group='apps',
+ version='v1',
+ plural='deployments',
+ namespace=Envs.K8S_NAMESPACE)
+
+ k8s_client.delete_service(name=service_object['metadata']['name'], namespace=Envs.K8S_NAMESPACE)
+ except RuntimeError as err:
+ logging.warning(f'Failed to delete serving k8s resources, {err}')
+
+ @staticmethod
+ def get_pods_info(deployment_name: str) -> List[Pod]:
+ pods: V1PodList = k8s_client.get_pods(Envs.K8S_NAMESPACE, label_selector=f'app={deployment_name}')
+
+ pods_info = []
+ for p in pods.items:
+ pods_info.append(Pod.from_json(p.to_dict()))
+ return pods_info
+
+ @staticmethod
+ def get_replica_status(deployment_name: str) -> str:
+ config = k8s_client.get_deployment(deployment_name)
+ if config is not None and config.status is not None:
+ if config.status.ready_replicas is None:
+ config.status.ready_replicas = 0
+ return f'{config.status.ready_replicas}/{config.spec.replicas}'
+ return 'UNKNOWN'
+
+ @classmethod
+ def get_pods_status(cls, deployment_name: str) -> List[ServingServiceInstance]:
+ pods = cls.get_pods_info(deployment_name)
+ result = []
+ for pod in pods:
+ instance = ServingServiceInstance(name=pod.name,
+ cpu='UNKNOWN',
+ memory='UNKNOWN',
+ created_at=pod.creation_timestamp)
+ if pod.state in (PodState.RUNNING, PodState.SUCCEEDED):
+ instance.status = 'AVAILABLE'
+ else:
+ instance.status = 'UNAVAILABLE'
+ result.append(instance)
+ return result
+
+ @staticmethod
+ def get_pod_log(pod_name: str, tail_lines: int) -> List[str]:
+ """get pod log
+
+ Args:
+            pod_name (str): name of the pod to query
+            tail_lines (int): number of trailing log lines to return
+
+ Returns:
+ List[str]: list of logs
+ """
+ return k8s_client.get_pod_log(pod_name, namespace=Envs.K8S_NAMESPACE, tail_lines=tail_lines)
+
+ @staticmethod
+ def generate_deployment_name(serving_model_id: int) -> str:
+ hash_value = hashlib.sha256(str(pp_datetime.now()).encode('utf8'))
+ return f'serving-{serving_model_id}-{hash_value.hexdigest()[0:6]}'
+
+
+class TensorflowServingService:
+
+ def __init__(self, deployment_name):
+ self._deployment_name = deployment_name
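+        # TF Serving is reached through the in-cluster service DNS name on its gRPC port (8500).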
+ model_server_address = f'{deployment_name}.{Envs.K8S_NAMESPACE}.svc:8500'
+ channel = grpc.insecure_channel(model_server_address)
+ self.model_service_stub = model_service_pb2_grpc.ModelServiceStub(channel)
+ self.prediction_service_stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
+
+ def get_model_status(self) -> ModelVersionStatus.State:
+ """ ref: https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/get_model_status.proto#L26
+ """
+ request = get_model_status_pb2.GetModelStatusRequest()
+ request.model_spec.name = self._deployment_name
+ try:
+ state = self.model_service_stub.GetModelStatus(request).model_version_status[0].state
+ except grpc.RpcError:
+ return ModelVersionStatus.State.UNKNOWN
+ if state == ModelVersionStatus.State.START:
+ state = ModelVersionStatus.State.LOADING
+ elif state == ModelVersionStatus.State.END:
+ state = ModelVersionStatus.State.UNKNOWN
+ return state
+
+ def get_model_signature(self) -> dict:
+ request = get_model_metadata_pb2.GetModelMetadataRequest()
+ request.model_spec.name = self._deployment_name
+ request.metadata_field.append('signature_def')
+ try:
+ metadata = self.prediction_service_stub.GetModelMetadata(request)
+ except grpc.RpcError:
+ return {}
+ signature = SignatureDefMap()
+ metadata.metadata['signature_def'].Unpack(signature)
+ return MessageToDict(signature.signature_def['serving_default'])
+
+ def get_model_inference_output(self, user_input: Example, extend_input: Optional[dict] = None) -> dict:
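+        # The serialized tf.Example is fed through the standard 'examples' input tensor;
+        # extra inputs (e.g. from participants) are attached as separate named tensors.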
+ inputs = make_tensor_proto([user_input.SerializeToString()])
+ request = predict_pb2.PredictRequest()
+ request.model_spec.name = self._deployment_name
+ request.inputs['examples'].CopyFrom(inputs)
+ if extend_input is not None:
+ for k in extend_input:
+ ext_inputs = make_tensor_proto([extend_input[k]])
+ request.inputs[k].CopyFrom(ext_inputs)
+ try:
+ output = self.prediction_service_stub.Predict(request)
+ except grpc.RpcError as err:
+ logging.error(f'Failed to inference, {err}')
+ return {'Error': str(err)}
+ return MessageToDict(output)
+
+ @staticmethod
+ def get_model_inference_endpoint(project_id: int, serving_model_id: int) -> str:
+ return f'{API_VERSION}/projects/{project_id}/serving_services/{serving_model_id}/inference'
+
+
+class NegotiatorServingService:
+
+ def __init__(self, session: Session = None):
+ self._session = session
+
+ def _handle_participant_request_create(self, request: service_pb2.ServingServiceRequest,
+ project: Project) -> service_pb2.ServingServiceResponse:
+ response = service_pb2.ServingServiceResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ # check model existence
+ model = self._session.query(Model).filter_by(uuid=request.model_uuid, project_id=project.id).first()
+ if model is None:
+ response.status.code = common_pb2.STATUS_NOT_FOUND
+ response.code = serving_pb2.SERVING_SERVICE_MODEL_NOT_FOUND
+ response.msg = 'model not found'
+ serving_metrics_emit_counter('serving.from_participant.create.model_not_found')
+ return response
+ # check serving model name
+ serving_model = self._session.query(ServingModel).filter_by(name=request.serving_model_name).first()
+ if serving_model is not None:
+ response.status.code = common_pb2.STATUS_INVALID_ARGUMENT
+ response.code = serving_pb2.SERVING_SERVICE_NAME_DUPLICATED
+ response.msg = 'serving model name is duplicated'
+ serving_metrics_emit_counter('serving.from_participant.create.duplicated', serving_model)
+ return response
+ # create db records in 3 tables
+ serving_model = ServingModel()
+ if model is not None:
+ serving_model.model_id = model.id
+ serving_model.model_path = model.get_exported_model_path()
+ serving_model.name = request.serving_model_name
+ serving_model.project_id = project.id
+ serving_model.status = ServingModelStatus.WAITING_CONFIG
+ serving_deployment = ServingDeployment()
+ serving_deployment.project_id = project.id
+ serving_deployment.resource = json.dumps({'cpu': '1000m', 'memory': '1Gi', 'replicas': 0})
+ serving_negotiator = ServingNegotiator()
+ serving_negotiator.project_id = project.id
+ serving_negotiator.is_local = False
+ try:
+ self._session.add(serving_model)
+ self._session.flush([serving_model])
+ except Exception as err:
+ serving_metrics_emit_counter('serving.from_participant.create.db_error', serving_model)
+ raise ResourceConflictException(
+ f'create serving service fail! serving model name = {serving_model.name}, err = {err}') from err
+ serving_deployment.deployment_name = ServingDeploymentService.generate_deployment_name(serving_model.id)
+ self._session.add(serving_deployment)
+ self._session.flush([serving_deployment])
+ serving_negotiator.serving_model_id = serving_model.id
+ serving_negotiator.serving_model_uuid = request.serving_model_uuid
+ serving_negotiator.with_label = False
+ self._session.add(serving_negotiator)
+ serving_model.endpoint = TensorflowServingService.get_model_inference_endpoint(project.id, serving_model.id)
+ serving_model.serving_deployment_id = serving_deployment.id
+ self._session.commit()
+ serving_metrics_emit_counter('serving.from_participant.create.success', serving_model)
+ return response
+
+ def _handle_participant_request_query(
+ self, request: service_pb2.ServingServiceRequest) -> service_pb2.ServingServiceResponse:
+ response = service_pb2.ServingServiceResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ # check serving model status
+ serving_model = self._session.query(ServingModel).filter_by(name=request.serving_model_name).first()
+ if serving_model is None:
+ response.status.code = common_pb2.STATUS_NOT_FOUND
+ response.code = serving_pb2.SERVING_SERVICE_MODEL_NOT_FOUND
+ response.msg = f'serving model not found, name = {request.serving_model_name}'
+ serving_metrics_emit_counter('serving.from_participant.query.serving_not_found')
+ return response
+ if serving_model.status == ServingModelStatus.WAITING_CONFIG:
+ response.code = serving_pb2.SERVING_SERVICE_PENDING_ACCEPT
+ response.msg = 'serving model is waiting for config'
+ serving_metrics_emit_counter('serving.from_participant.query.waiting', serving_model)
+ return response
+ serving_metrics_emit_counter('serving.from_participant.query.success', serving_model)
+ return response
+
+ def _handle_participant_request_destroy(
+ self, request: service_pb2.ServingServiceRequest) -> service_pb2.ServingServiceResponse:
+ response = service_pb2.ServingServiceResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ serving_negotiator = self._session.query(ServingNegotiator).filter_by(
+ serving_model_uuid=request.serving_model_uuid).one_or_none()
+ if serving_negotiator is None:
+ response.msg = 'serving negotiator is already deleted'
+ serving_metrics_emit_counter('serving.from_participant.delete.already_done')
+ return response
+ if serving_negotiator.serving_model.status == ServingModelStatus.WAITING_CONFIG:
+ self._session.delete(serving_negotiator.serving_model.serving_deployment)
+ self._session.delete(serving_negotiator.serving_model)
+ self._session.delete(serving_negotiator)
+ self._session.commit()
+ serving_metrics_emit_counter('serving.from_participant.delete.directly', serving_negotiator.serving_model)
+ return response
+ serving_negotiator.serving_model.status = ServingModelStatus.DELETED
+ serving_metrics_emit_counter('serving.from_participant.delete.success', serving_negotiator.serving_model)
+ self._session.commit()
+ return response
+
+ @staticmethod
+ def generate_uuid(serving_model_id: int) -> str:
+ hash_value = hashlib.sha256(str(pp_datetime.now()).encode('utf8'))
+ return f'{serving_model_id}{hash_value.hexdigest()[0:6]}'
+
+ def operate_participant_serving_service(self, serving_negotiator: ServingNegotiator, operation: ServingServiceType):
+ serving_model = self._session.query(ServingModel).filter_by(
+ id=serving_negotiator.serving_model_id).one_or_none()
+ # no need to notify participants when serving on third party platform
+ if serving_model.serving_deployment.is_remote_serving():
+ return serving_pb2.SERVING_SERVICE_SUCCESS
+ service = ParticipantService(self._session)
+ participants = service.get_platform_participants_by_project(serving_negotiator.project.id)
+ for participant in participants:
+ client = RpcClient.from_project_and_participant(serving_negotiator.project.name,
+ serving_negotiator.project.token, participant.domain_name)
+ model_uuid = ''
+ if serving_model.model is not None:
+                model_uuid = serving_model.model.uuid
+ resp = client.operate_serving_service(operation, serving_negotiator.serving_model_uuid, model_uuid,
+ serving_model.name)
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ msg = f'operate participant fail! status code = {resp.status.code}, msg = {resp.msg}'
+ logging.error(msg)
+ raise InternalException(msg)
+ if operation == serving_pb2.SERVING_SERVICE_CREATE:
+ if resp.code != serving_pb2.SERVING_SERVICE_SUCCESS:
+ return resp.code
+ elif operation == serving_pb2.SERVING_SERVICE_QUERY:
+ if resp.code != serving_pb2.SERVING_SERVICE_SUCCESS:
+ return resp.code
+ else: # SERVING_SERVICE_DESTROY
+ pass
+ return serving_pb2.SERVING_SERVICE_SUCCESS
+
+ def participant_serving_service_inference(self, serving_negotiator: ServingNegotiator,
+ example_id: str) -> service_pb2.ServingServiceInferenceResponse:
+ service = ParticipantService(self._session)
+ participants = service.get_platform_participants_by_project(serving_negotiator.project.id)
+        assert len(participants) == 1, f'only one participant is supported! num = {len(participants)}'
+ client = RpcClient.from_project_and_participant(serving_negotiator.project.name,
+ serving_negotiator.project.token, participants[0].domain_name)
+ resp = client.inference_serving_service(serving_negotiator.serving_model_uuid, example_id)
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ logging.error(resp.status.msg)
+ raise InternalException(resp.status.msg)
+ return resp
+
+ def handle_participant_request(self, request: service_pb2.ServingServiceRequest,
+ project: Project) -> service_pb2.ServingServiceResponse:
+ if request.operation_type == ServingServiceType.SERVING_SERVICE_CREATE:
+ return self._handle_participant_request_create(request, project)
+ if request.operation_type == ServingServiceType.SERVING_SERVICE_QUERY:
+ return self._handle_participant_request_query(request)
+ if request.operation_type == ServingServiceType.SERVING_SERVICE_DESTROY:
+ return self._handle_participant_request_destroy(request)
+ response = service_pb2.ServingServiceResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ return response
+
+ def handle_participant_inference_request(self, request: service_pb2.ServingServiceInferenceRequest,
+ project: Project) -> service_pb2.ServingServiceInferenceResponse:
+ response = service_pb2.ServingServiceInferenceResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ code=serving_pb2.SERVING_SERVICE_SUCCESS)
+ serving_negotiator = self._session.query(ServingNegotiator).filter_by(
+ serving_model_uuid=request.serving_model_uuid, project_id=project.id).one_or_none()
+ if serving_negotiator is None:
+ response.rpc_status.code = common_pb2.STATUS_NOT_FOUND
+ response.code = serving_pb2.SERVING_SERVICE_NEGOTIATOR_NOT_FOUND
+ response.msg = 'serving negotiator not found'
+ return response
+ deployment_name = serving_negotiator.serving_model.serving_deployment.deployment_name
+ tf_serving_service = TensorflowServingService(deployment_name)
+ query_key = int(request.example_id)
+ data_record = DatabaseFetcher.fetch_by_int_key(query_key, serving_negotiator.serving_model.signature)
+ feature_input = {}
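+        # Wrap each fetched column in the tf.train.Feature type that matches its Python value type.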
+ for k, item in data_record.items():
+ if isinstance(item[0], int):
+ int_list = Int64List(value=item)
+ feature_input[k] = Feature(int64_list=int_list)
+ if isinstance(item[0], float):
+ float_list = FloatList(value=item)
+ feature_input[k] = Feature(float_list=float_list)
+ if isinstance(item[0], str):
+ data_record_bytes = [x.encode(encoding='utf-8') for x in item]
+ bytes_list = BytesList(value=data_record_bytes)
+ feature_input[k] = Feature(bytes_list=bytes_list)
+ input_data = Example(features=Features(feature=feature_input))
+ output = tf_serving_service.get_model_inference_output(input_data)
+ response.data.update({'result': output['outputs']})
+ return response
+
+
+class SavedModelService:
+ PARSE_EXAMPLE_NAME = 'ParseExample/ParseExample'
+ INPUT_NODE_NAMES = [PARSE_EXAMPLE_NAME]
+
+ @staticmethod
+ def get_nodes_from_graph(graph: graph_pb2.GraphDef, node_list: List[str]) -> Dict[str, node_def_pb2.NodeDef]:
+ """get nodes from graph by node names
+
+ Args:
+ graph (graph_pb2.GraphDef): GraphDef
+ node_list (List[str]): node name list
+
+ Returns:
+ Dict[str, node_def_pb2.NodeDef]: a mapping from node_name to NodeDef
+
+ Raises:
+ AssertionError: when failed to get all nodes required by node_list
+ """
+ result = {}
+ for n in graph.node:
+ if n.name in node_list:
+ result[n.name] = n
+        assert set(result.keys()) == set(node_list), f'Failed to get nodes: {set(node_list) - result.keys()}'
+ return result
+
+ @classmethod
+ def get_parse_example_details(cls, saved_model_binary: bytes) -> ServingServiceSignature:
+ saved_model_message = saved_model_pb2.SavedModel()
+ saved_model_message.ParseFromString(saved_model_binary)
+ graph = saved_model_message.meta_graphs[0].graph_def
+
+ parse_example_op = cls.get_nodes_from_graph(graph, cls.INPUT_NODE_NAMES)[cls.PARSE_EXAMPLE_NAME]
+ assert parse_example_op.op == 'ParseExample', f'{parse_example_op} node is not a ParseExample op'
+
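+        # The dense_keys_* constant inputs hold the feature names; the Tdense and dense_shapes
+        # attrs describe their dtypes and shapes.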
+ dense_keys_inputs = [i for i in parse_example_op.input if 'dense_keys' in i]
+ assert len(dense_keys_inputs) == parse_example_op.attr['Ndense'].i, 'Consistency check failed'
+
+ dense_keys_nodes = cls.get_nodes_from_graph(graph, dense_keys_inputs)
+ # Keep nodes in order
+ dense_keys_list = [dense_keys_nodes[i] for i in dense_keys_inputs]
+ signature = ServingServiceSignature()
+        # For more details, see serving/examples/parse_graph.py
+ for n, t, s in zip(dense_keys_list, parse_example_op.attr['Tdense'].list.type,
+ parse_example_op.attr['dense_shapes'].list.shape):
+ signature_input = ServingServiceSignatureInput(name=n.attr['value'].tensor.string_val[0],
+ type=types_pb2.DataType.Name(t),
+ dim=[d.size for d in s.dim])
+ signature.inputs.append(signature_input)
+ return signature
+
+
+class ServingModelService(object):
+
+ def __init__(self, session: Session = None):
+ self._session = session
+
+ def create_from_param(self,
+ project_id: int,
+ name: str,
+ is_local: bool,
+ comment: Optional[str],
+ model_id: Optional[int],
+ model_group_id: Optional[int],
+ resource: serving_pb2.ServingServiceResource = None,
+ remote_platform: serving_pb2.ServingServiceRemotePlatform = None) -> ServingModel:
+ session = self._session
+ serving_model = ServingModel()
+ if model_id is not None:
+ serving_model.model_id = model_id
+ model = get_model(serving_model.model_id, self._session)
+ elif model_group_id is not None:
+ serving_model.model_group_id = model_group_id
+ model = ModelJobGroupService(self._session).get_latest_model_from_model_group(serving_model.model_group_id)
+ serving_model.model_id = model.id
+ else:
+            raise InvalidArgumentException('one of model_id and model_group_id must be provided')
+ serving_model.name = name
+ serving_model.project_id = project_id
+ serving_model.comment = comment
+ serving_model.status = ServingModelStatus.LOADING
+ serving_model.model_path = model.get_exported_model_path()
+
+ try:
+ session.add(serving_model)
+ session.flush([serving_model])
+ except Exception as err:
+ serving_metrics_emit_counter('serving.create.db_fail', serving_model)
+ raise ResourceConflictException(
+ f'create serving service fail! serving model name = {serving_model.name}, err = {err}') from err
+
+ serving_deployment = ServingDeployment()
+ serving_deployment.project_id = project_id
+ session.add(serving_deployment)
+ session.flush([serving_deployment])
+ serving_model.serving_deployment_id = serving_deployment.id
+
+ serving_negotiator = ServingNegotiator()
+ serving_negotiator.project_id = project_id
+ serving_negotiator.is_local = is_local
+
+ if remote_platform is None: # serving inside this platform
+ serving_deployment.resource = json.dumps(MessageToDict(resource))
+ serving_deployment.deployment_name = ServingDeploymentService.generate_deployment_name(serving_model.id)
+ serving_model.endpoint = TensorflowServingService.get_model_inference_endpoint(project_id, serving_model.id)
+ self._create_or_update_deployment(serving_model)
+ else: # remote serving
+ deploy_config: serving_pb2.RemoteDeployConfig = self._create_remote_serving(remote_platform, serving_model)
+ serving_deployment.deploy_platform = json.dumps(MessageToDict(deploy_config))
+ serving_model.endpoint = self._get_remote_serving_url(remote_platform)
+
+        # Notify participants only after the k8s operation succeeds,
+        # so that a failed k8s operation does not leave dirty data on the participants' side
+ serving_negotiator.serving_model_id = serving_model.id
+ serving_negotiator.serving_model_uuid = NegotiatorServingService.generate_uuid(serving_model.id)
+ serving_negotiator.with_label = True
+ session.add(serving_negotiator)
+ session.flush([serving_negotiator])
+ if not serving_negotiator.is_local:
+ serving_model.status = ServingModelStatus.PENDING_ACCEPT
+ result = NegotiatorServingService(session).operate_participant_serving_service(
+ serving_negotiator, serving_pb2.SERVING_SERVICE_CREATE)
+ if result != serving_pb2.SERVING_SERVICE_SUCCESS:
+ raise InternalException(details=f'create participant serving service fail! result code = {result}')
+ return serving_model
+
+ def get_serving_service_detail(self,
+ serving_model_id: int,
+ project_id: Optional[int] = None,
+ sorter: Optional[SortExpression] = None) -> serving_pb2.ServingServiceDetail:
+ serving_model = self._session.query(ServingModel).filter_by(id=serving_model_id).one_or_none()
+ if not serving_model:
+ raise NotFoundException(f'Failed to find serving model {serving_model_id}')
+ if project_id is not None:
+ if serving_model.project_id != project_id:
+ raise NotFoundException(f'Failed to find serving model {serving_model_id} in project {project_id}')
+ deployment_name = serving_model.serving_deployment.deployment_name
+ result = serving_model.to_serving_service_detail()
+ if serving_model.serving_deployment.is_remote_serving():
+ result.status = self._get_remote_serving_status(serving_model).name
+ result.support_inference = (result.status == ServingModelStatus.AVAILABLE.name)
+ else:
+ status = TensorflowServingService(deployment_name).get_model_status()
+ if serving_model.status == ServingModelStatus.LOADING and status != ModelVersionStatus.State.UNKNOWN:
+ result.status = ModelVersionStatus.State.Name(status)
+ result.support_inference = (result.status == ServingModelStatus.AVAILABLE.name)
+ resource = json.loads(serving_model.serving_deployment.resource)
+ resource = serving_pb2.ServingServiceResource(
+ cpu=resource['cpu'],
+ memory=resource['memory'],
+ replicas=resource['replicas'],
+ )
+ result.resource.CopyFrom(resource)
+ if result.resource.replicas > 0:
+ k8s_serving_service = ServingDeploymentService()
+ result.instance_num_status = k8s_serving_service.get_replica_status(deployment_name)
+ instances = k8s_serving_service.get_pods_status(deployment_name)
+ if sorter is not None:
+ if sorter.field == 'created_at':
+ reverse = not sorter.is_asc
+ instances = sorted(instances, key=lambda x: x.created_at, reverse=reverse)
+ result.instances.extend(instances)
+ else:
+ result.instance_num_status = 'UNKNOWN'
+ serving_negotiator = get_serving_negotiator_by_serving_model_id(serving_model_id, self._session)
+ if serving_negotiator is not None:
+ result.is_local = serving_negotiator.is_local
+ if not serving_negotiator.with_label:
+ result.support_inference = False
+ return result
+
+ def set_resource_and_status_on_ref(self, single_res: serving_pb2.ServingService, serving_model: ServingModel):
+ if serving_model.serving_deployment.is_remote_serving():
+ single_res.status = self._get_remote_serving_status(serving_model).name
+ single_res.support_inference = (single_res.status == ServingModelStatus.AVAILABLE.name)
+ return
+ deployment_name = serving_model.serving_deployment.deployment_name
+ resource = json.loads(serving_model.serving_deployment.resource)
+ tf_serving_service = TensorflowServingService(deployment_name)
+ status = tf_serving_service.get_model_status()
+ if serving_model.status == ServingModelStatus.LOADING and status != ModelVersionStatus.State.UNKNOWN:
+ single_res.status = ModelVersionStatus.State.Name(status)
+ single_res.support_inference = (single_res.status == ServingModelStatus.AVAILABLE.name)
+ if resource['replicas'] > 0:
+ single_res.instance_num_status = ServingDeploymentService.get_replica_status(deployment_name)
+ else:
+ single_res.instance_num_status = 'UNKNOWN'
+ single_res.resource.cpu = resource['cpu']
+ single_res.resource.memory = resource['memory']
+ single_res.resource.replicas = resource['replicas']
+
+ def set_is_local_on_ref(self, single_res: serving_pb2.ServingService, serving_model: ServingModel):
+ serving_negotiator = self._session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model.id).one_or_none()
+ if serving_negotiator is not None:
+ single_res.is_local = serving_negotiator.is_local
+ if not serving_negotiator.with_label:
+ single_res.support_inference = False
+
+ def update_model(self, model_id: Optional[int], model_group_id: Optional[int], serving_model: ServingModel) -> bool:
+ need_update = False
+ if model_id is not None:
+ if model_id != serving_model.model_id:
+ model = get_model(model_id, self._session)
+ serving_model.model_id = model.id
+ serving_model.model_path = model.get_exported_model_path()
+ need_update = True
+ serving_model.model_group_id = None # clear model group config
+ elif model_group_id is not None and model_group_id != serving_model.model_group_id:
+ model = ModelJobGroupService(self._session).get_latest_model_from_model_group(model_group_id)
+ if serving_model.model_id != model.id:
+ serving_model.model_id = model.id
+ serving_model.model_path = model.get_exported_model_path()
+ need_update = True
+ serving_model.model_group_id = model_group_id
+
+ if not need_update:
+ return False
+
+ if serving_model.serving_deployment.is_remote_serving():
+ self.update_remote_serving_model(serving_model)
+ return True
+
+ serving_negotiator = self._session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model.id).one_or_none()
+ if serving_negotiator is not None and not serving_negotiator.is_local:
+ # TODO(lixiaoguang.01) support update model for federal serving
+ raise InvalidArgumentException('update model is not supported for federal serving')
+
+ self._create_or_update_deployment(serving_model)
+ return True
+
+ def update_resource(self, resource: dict, serving_model: ServingModel):
+ serving_model.serving_deployment.resource = json.dumps(resource)
+ serving_model.status = ServingModelStatus.LOADING
+ self._create_or_update_deployment(serving_model)
+ self._session.add(serving_model.serving_deployment)
+
+ def delete_serving_service(self, serving_model: ServingModel):
+ serving_negotiator = self._session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model.id).one_or_none()
+ if serving_negotiator is not None:
+ if not serving_negotiator.is_local:
+ NegotiatorServingService(self._session).operate_participant_serving_service(
+ serving_negotiator, serving_pb2.ServingServiceType.SERVING_SERVICE_DESTROY)
+ self._session.delete(serving_negotiator)
+ if serving_model.serving_deployment.is_remote_serving():
+ self._undeploy_remote_serving(serving_model)
+ else:
+ try:
+ ServingDeploymentService(self._session).delete_deployment(serving_model)
+ except RuntimeError as err:
+ serving_metrics_emit_counter('serving.delete.deployment_error', serving_model)
+ raise ResourceConflictException(
+ f'delete deployment fail! serving model id = {serving_model.id}, err = {err}') from err
+ self._session.delete(serving_model.serving_deployment)
+ self._session.delete(serving_model)
+
+ def _create_or_update_deployment(self, serving_model: ServingModel):
+ try:
+ ServingDeploymentService(self._session).create_or_update_deployment(serving_model)
+ except RuntimeError as err:
+ serving_metrics_emit_counter('serving.deployment_error', serving_model)
+ raise InternalException(
+ f'create or update deployment fail! serving model id = {serving_model.id}, err = {err}') from err
+
+ @staticmethod
+ def _create_remote_serving(remote_platform: serving_pb2.ServingServiceRemotePlatform,
+ serving_model: ServingModel) -> serving_pb2.RemoteDeployConfig:
+ current_user = flask_utils.get_current_user()
+ if remote_platform.platform not in remote.supported_remote_serving:
+ raise InvalidArgumentException(f'platform {remote_platform.platform} not supported')
+ deploy_config = serving_pb2.RemoteDeployConfig(platform=remote_platform.platform,
+ payload=remote_platform.payload,
+ deploy_name=f'privacy-platform-{serving_model.name}',
+ model_src_path=serving_model.model_path)
+ remote_helper = remote.supported_remote_serving[remote_platform.platform]
+ try:
+ deploy_config.deploy_id = remote_helper.deploy_model(current_user.username, deploy_config)
+ except (FileNotFoundError, AttributeError) as err:
+ serving_metrics_emit_counter('serving.remote_deployment_error', serving_model)
+ raise InvalidArgumentException(
+ f'create remote deployment fail! serving model id = {serving_model.id}, err = {err}') from err
+        # not stored in db; fetched from serving_model at deploy time
+ deploy_config.model_src_path = ''
+ return deploy_config
+
+ def update_remote_serving_model(self, serving_model: ServingModel):
+ current_user = flask_utils.get_current_user()
+ if current_user is None:
+ username = 'robot'
+ else:
+ username = current_user.username
+ deploy_config = serving_pb2.RemoteDeployConfig()
+ json_format.Parse(serving_model.serving_deployment.deploy_platform, deploy_config)
+ if deploy_config.platform not in remote.supported_remote_serving:
+ raise InvalidArgumentException(f'platform {deploy_config.platform} not supported')
+ deploy_config.model_src_path = serving_model.model_path
+ remote_helper = remote.supported_remote_serving[deploy_config.platform]
+ remote_helper.deploy_model(username, deploy_config)
+
+ @staticmethod
+ def _get_remote_serving_url(remote_platform: serving_pb2.ServingServiceRemotePlatform) -> str:
+ if remote_platform.platform not in remote.supported_remote_serving:
+ raise InvalidArgumentException(f'platform {remote_platform.platform} not supported')
+ deploy_config = serving_pb2.RemoteDeployConfig(payload=remote_platform.payload)
+ remote_helper = remote.supported_remote_serving[remote_platform.platform]
+ return remote_helper.get_deploy_url(deploy_config)
+
+ @staticmethod
+ def _get_remote_serving_status(serving_model: ServingModel) -> ServingModelStatus:
+ deploy_config = serving_pb2.RemoteDeployConfig()
+ json_format.Parse(serving_model.serving_deployment.deploy_platform, deploy_config)
+ if deploy_config.platform not in remote.supported_remote_serving:
+ raise InvalidArgumentException(f'platform {deploy_config.platform} not supported')
+ remote_helper = remote.supported_remote_serving[deploy_config.platform]
+ deploy_status = remote_helper.get_deploy_status(deploy_config)
+ if deploy_status == RemoteDeployState.REMOTE_DEPLOY_READY:
+ return ServingModelStatus.AVAILABLE
+ return ServingModelStatus.LOADING
+
+ @staticmethod
+ def _undeploy_remote_serving(serving_model: ServingModel):
+ deploy_config = serving_pb2.RemoteDeployConfig()
+ json_format.Parse(serving_model.serving_deployment.deploy_platform, deploy_config)
+ if deploy_config.platform not in remote.supported_remote_serving:
+ raise InvalidArgumentException(f'platform {deploy_config.platform} not supported')
+ remote_helper = remote.supported_remote_serving[deploy_config.platform]
+ remote_helper.undeploy_model(deploy_config)
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/services_test.py b/web_console_v2/api/fedlearner_webconsole/serving/services_test.py
new file mode 100644
index 000000000..8ec297594
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/services_test.py
@@ -0,0 +1,511 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import os
+import unittest
+from unittest.mock import MagicMock, patch, call
+
+from google.protobuf import text_format
+from google.protobuf.json_format import MessageToDict
+from tensorflow.core.protobuf import saved_model_pb2
+from envs import Envs
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.exceptions import NotFoundException, InvalidArgumentException
+from fedlearner_webconsole.initial_db import initial_db
+from fedlearner_webconsole.mmgr.models import Model, ModelJobGroup, ModelType
+from fedlearner_webconsole.proto import serving_pb2
+from fedlearner_webconsole.serving.remote import register_remote_serving
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.serving.models import ServingDeployment, ServingModel, ServingNegotiator, ServingModelStatus
+from fedlearner_webconsole.serving.services import SavedModelService, ServingDeploymentService, ServingModelService
+
+from testing.common import BaseTestCase
+from testing.fake_remote_serving import FakeRemoteServing
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ServingDeploymentServiceTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ initial_db()
+ with db.session_scope() as session:
+
+ project = Project()
+ project.name = 'test_project_name'
+ session.add(project)
+ session.flush([project])
+
+ serving_deployment = ServingDeployment()
+ serving_deployment.deployment_name = 'test_deployment_name'
+ serving_deployment.project_id = project.id
+ serving_deployment.resource = json.dumps({'cpu': '4000m', 'memory': '8Gi', 'replicas': 3})
+ session.add(serving_deployment)
+ session.flush([serving_deployment])
+
+ serving_model = ServingModel()
+ serving_model.project_id = project.id
+ serving_model.serving_deployment_id = serving_deployment.id
+ session.add(serving_model)
+
+ session.commit()
+
+ self.serving_model_id = serving_model.id
+
+ @patch('fedlearner_webconsole.serving.services.k8s_client')
+ def test_create_deployment(self, mock_k8s_client: MagicMock):
+ mock_k8s_client.create_or_update_app = MagicMock()
+ mock_k8s_client.create_or_update_config_map = MagicMock()
+ mock_k8s_client.create_or_update_service = MagicMock()
+
+ with db.session_scope() as session:
+            # Best practice: pass serving_model_id around when working across two sessions.
+ serving_model = session.query(ServingModel).get(self.serving_model_id)
+ service = ServingDeploymentService(session)
+ service.create_or_update_deployment(serving_model)
+
+ mock_k8s_client.create_or_update_config_map.assert_called_once()
+ mock_k8s_client.create_or_update_service.assert_called_once()
+ mock_k8s_client.create_or_update_app.assert_called_once()
+
+ @patch('fedlearner_webconsole.serving.services.k8s_client')
+ def test_delete_deployment(self, mock_k8s_client: MagicMock):
+ mock_k8s_client.delete_app = MagicMock()
+ mock_k8s_client.delete_config_map = MagicMock()
+ mock_k8s_client.delete_service = MagicMock()
+
+ with db.session_scope() as session:
+            # Best practice: pass serving_model_id around when working across two sessions.
+ serving_model = session.query(ServingModel).get(self.serving_model_id)
+ service = ServingDeploymentService(session)
+ service.delete_deployment(serving_model)
+
+ mock_k8s_client.delete_app.assert_has_calls([
+ call(app_name=serving_model.serving_deployment.deployment_name,
+ group='apps',
+ version='v1',
+ plural='deployments',
+ namespace=Envs.K8S_NAMESPACE),
+ ])
+ mock_k8s_client.delete_config_map.assert_called_once_with(
+ name=f'{serving_model.serving_deployment.deployment_name}-config', namespace=Envs.K8S_NAMESPACE)
+ mock_k8s_client.delete_service.assert_called_once_with(name=serving_model.serving_deployment.deployment_name,
+ namespace=Envs.K8S_NAMESPACE)
+
+ def test_get_pods_info(self):
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).filter_by(id=self.serving_model_id).one()
+ deployment_name = serving_model.serving_deployment.deployment_name
+
+ info = ServingDeploymentService.get_pods_info(deployment_name)
+ self.assertEqual(len(info), 1)
+
+
+class SavedModelServiceTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+
+ with open(os.path.join(Envs.BASE_DIR, 'testing/test_data/saved_model.pbtxt'), 'rt', encoding='utf-8') as f:
+ self.saved_model_text = f.read()
+
+ self.saved_model_message = text_format.Parse(self.saved_model_text, saved_model_pb2.SavedModel())
+ self.graph = self.saved_model_message.meta_graphs[0].graph_def
+
+ def test_get_nodes_from_graph(self):
+ parse_example_node = SavedModelService.get_nodes_from_graph(
+ self.graph, ['ParseExample/ParseExample'])['ParseExample/ParseExample']
+ self.assertEqual(parse_example_node.name, 'ParseExample/ParseExample')
+ self.assertEqual(parse_example_node.op, 'ParseExample')
+
+ dense_nodes = SavedModelService.get_nodes_from_graph(
+ self.graph, ['ParseExample/ParseExample/dense_keys_0', 'ParseExample/ParseExample/dense_keys_1'])
+ self.assertEqual(len(dense_nodes), 2)
+
+ def test_get_parse_example_details(self):
+ signatures = SavedModelService.get_parse_example_details(self.saved_model_message.SerializeToString())
+ self.assertCountEqual([i.name for i in signatures.inputs], ['example_id', 'x'])
+ self.assertCountEqual([i.type for i in signatures.inputs], ['DT_STRING', 'DT_FLOAT'])
+ self.assertCountEqual([i.dim for i in signatures.inputs], [[], [392]])
+
+
+# Use BaseTestCase instead of NoWebServerTestCase to get system.variables.labels when generating the deployment yaml
+class ServingModelServiceTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ # insert project
+ with db.session_scope() as session:
+ project = Project()
+ project.name = 'test_project_name'
+ session.add(project)
+ session.flush([project])
+
+ model_job_group = ModelJobGroup()
+ session.add(model_job_group)
+ session.flush([model_job_group])
+
+ model_1 = Model()
+ model_1.name = 'test_model_name_1'
+ model_1.model_path = '/test_path_1/'
+ model_1.uuid = 'test_uuid_1'
+ model_1.project_id = project.id
+ model_1.version = 1
+ model_1.group_id = model_job_group.id
+
+ model_2 = Model()
+ model_2.name = 'test_model_name_2'
+ model_2.model_path = '/test_path_2/'
+ model_2.project_id = project.id
+ model_2.version = 2
+ model_2.group_id = model_job_group.id
+
+ session.add_all([model_1, model_2])
+ session.commit()
+ self.project_id = project.id
+ self.model_1_id = model_1.id
+ self.model_1_uuid = model_1.uuid
+ self.model_2_id = model_2.id
+ self.model_group_id = model_job_group.id
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_create_from_param(self, mock_create_deployment: MagicMock):
+ name = 'test-serving-service-1'
+ resource = serving_pb2.ServingServiceResource(
+ cpu='1',
+ memory='2',
+ replicas=3,
+ )
+ param = {
+ 'model_id': self.model_1_id,
+ 'name': name,
+ 'comment': '',
+ 'resource': resource,
+ 'is_local': True,
+ }
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name=param['name'],
+ is_local=param['is_local'],
+ comment=param['comment'],
+ model_id=param['model_id'],
+ model_group_id=None,
+ resource=param['resource'])
+ session.commit()
+ serving_model_id = serving_model.id
+
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertEqual(name, serving_model.name)
+ self.assertEqual('/test_path_1/exported_models', serving_model.model_path)
+ serving_deployment = serving_model.serving_deployment
+ deployment_name_substr = f'serving-{serving_model_id}-'
+ self.assertIn(deployment_name_substr, serving_deployment.deployment_name)
+ self.assertEqual(
+ serving_deployment.resource,
+ json.dumps({
+ 'cpu': resource.cpu,
+ 'memory': resource.memory,
+ 'replicas': resource.replicas,
+ }))
+ serving_negotiator = session.query(ServingNegotiator).filter_by(
+ serving_model_id=serving_model_id).one_or_none()
+ self.assertIsNotNone(serving_negotiator)
+ self.assertEqual(serving_negotiator.project_id, self.project_id)
+
+ mock_create_deployment.assert_called_once()
+
+ name = 'test-auto-update-1'
+ param = {
+ 'model_group_id': self.model_group_id,
+ 'name': name,
+ 'resource': resource,
+ 'is_local': True,
+ }
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name=param['name'],
+ is_local=param['is_local'],
+ comment=None,
+ model_id=None,
+ model_group_id=param['model_group_id'],
+ resource=param['resource'])
+ session.commit()
+ serving_model_id = serving_model.id
+
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertEqual(self.model_2_id, serving_model.model_id)
+ self.assertEqual('/test_path_2/exported_models', serving_model.model_path)
+
+ @patch('fedlearner_webconsole.utils.flask_utils.get_current_user', MagicMock(return_value=User(username='test')))
+ def test_create_remote_serving_from_param(self):
+ reckon_remote_serving = FakeRemoteServing()
+ register_remote_serving(FakeRemoteServing.SERVING_PLATFORM, reckon_remote_serving)
+ name = 'test-remote-serving-1'
+ remote_platform = serving_pb2.ServingServiceRemotePlatform(platform=FakeRemoteServing.SERVING_PLATFORM,
+ payload='test-payload')
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name=name,
+ is_local=True,
+ comment=None,
+ model_id=None,
+ model_group_id=self.model_group_id,
+ resource=None,
+ remote_platform=remote_platform)
+ session.commit()
+ serving_model_id = serving_model.id
+
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertEqual(FakeRemoteServing.DEPLOY_URL, serving_model.endpoint)
+ deploy_platform = serving_pb2.RemoteDeployConfig(
+ platform=FakeRemoteServing.SERVING_PLATFORM,
+ payload='test-payload',
+ deploy_id=1,
+ deploy_name=f'privacy-platform-test-remote-serving-{serving_model_id}',
+ model_src_path='',
+ )
+ self.assertEqual(json.dumps(MessageToDict(deploy_platform)),
+ serving_model.serving_deployment.deploy_platform)
+
+ @patch('fedlearner_webconsole.utils.flask_utils.get_current_user', MagicMock(return_value=User(username='test')))
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_get_detail(self, mock_create_deployment: MagicMock):
+ name = 'test-serving-service-1'
+ resource = serving_pb2.ServingServiceResource(
+ cpu='1',
+ memory='2',
+ replicas=3,
+ )
+ param = {
+ 'model_id': self.model_1_id,
+ 'name': name,
+ 'comment': '',
+ 'resource': resource,
+ 'is_local': False,
+ }
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name=param['name'],
+ is_local=param['is_local'],
+ comment=param['comment'],
+ model_id=param['model_id'],
+ model_group_id=None,
+ resource=param['resource'])
+ detail = serving_model_service.get_serving_service_detail(serving_model.id, serving_model.project_id)
+ self.assertEqual(name, detail.name)
+ self.assertEqual(ServingModelStatus.PENDING_ACCEPT.name, detail.status)
+ # Use assertRaises so the test fails if NotFoundException is not raised.
+ with self.assertRaises(NotFoundException):
+ serving_model_service.get_serving_service_detail(serving_model.id + 1)
+ with self.assertRaises(NotFoundException):
+ serving_model_service.get_serving_service_detail(serving_model.id, serving_model.project_id + 1)
+
+ # get remote serving detail
+ name = 'test-remote-serving-1'
+ reckon_remote_serving = FakeRemoteServing()
+ register_remote_serving(FakeRemoteServing.SERVING_PLATFORM, reckon_remote_serving)
+ remote_platform = serving_pb2.ServingServiceRemotePlatform(platform=FakeRemoteServing.SERVING_PLATFORM,
+ payload='test-payload')
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name=name,
+ is_local=False,
+ comment=None,
+ model_id=self.model_1_id,
+ model_group_id=None,
+ resource=None,
+ remote_platform=remote_platform)
+ detail = serving_model_service.get_serving_service_detail(serving_model.id, serving_model.project_id)
+ expected_detail = serving_pb2.ServingServiceDetail(id=serving_model.id,
+ project_id=self.project_id,
+ name=name,
+ model_id=self.model_1_id,
+ model_type=ModelType.NN_MODEL.name,
+ is_local=False,
+ endpoint='test_deploy_url',
+ instance_num_status='UNKNOWN',
+ status=ServingModelStatus.AVAILABLE.name,
+ support_inference=True)
+ expected_detail.remote_platform.CopyFrom(
+ serving_pb2.ServingServiceRemotePlatform(
+ platform='unittest_mock',
+ payload='test-payload',
+ ))
+ self.assertPartiallyEqual(to_dict(expected_detail),
+ to_dict(detail),
+ ignore_fields=['created_at', 'updated_at'])
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_set_ref(self, mock_create_deployment: MagicMock):
+ name = 'test-serving-service-1'
+ resource = serving_pb2.ServingServiceResource(
+ cpu='1',
+ memory='2',
+ replicas=3,
+ )
+ param = {
+ 'model_id': self.model_1_id,
+ 'name': name,
+ 'comment': '',
+ 'resource': resource,
+ 'is_local': False,
+ }
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name=param['name'],
+ is_local=param['is_local'],
+ comment=param['comment'],
+ model_id=param['model_id'],
+ model_group_id=None,
+ resource=param['resource'])
+ serving_service = serving_model.to_serving_service()
+ self.assertTrue(serving_service.is_local) # default value
+ serving_model_service = ServingModelService(session)
+ serving_model_service.set_resource_and_status_on_ref(serving_service, serving_model)
+ serving_model_service.set_is_local_on_ref(serving_service, serving_model)
+ self.assertEqual('UNKNOWN', serving_service.instance_num_status)
+ self.assertEqual('1', serving_service.resource.cpu)
+ self.assertEqual('2', serving_service.resource.memory)
+ self.assertEqual(3, serving_service.resource.replicas)
+ self.assertFalse(serving_service.is_local)
+
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_update_model(self, mock_create_deployment: MagicMock):
+ resource = serving_pb2.ServingServiceResource(
+ cpu='1',
+ memory='2',
+ replicas=3,
+ )
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name='test-serving-service-1',
+ is_local=True,
+ comment='',
+ model_id=self.model_1_id,
+ model_group_id=None,
+ resource=resource)
+
+ need_update = serving_model_service.update_model(model_id=None,
+ model_group_id=self.model_group_id,
+ serving_model=serving_model)
+ self.assertEqual(True, need_update)
+ self.assertEqual(self.model_2_id, serving_model.model_id)
+
+ need_update = serving_model_service.update_model(model_id=self.model_2_id,
+ model_group_id=self.model_group_id,
+ serving_model=serving_model)
+ self.assertEqual(False, need_update)
+ self.assertEqual(self.model_2_id, serving_model.model_id)
+ self.assertIsNone(serving_model.model_group_id)
+
+ need_update = serving_model_service.update_model(model_id=self.model_1_id,
+ model_group_id=self.model_group_id,
+ serving_model=serving_model)
+ self.assertEqual(True, need_update)
+ self.assertEqual(self.model_1_id, serving_model.model_id)
+ self.assertIsNone(serving_model.model_group_id)
+
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name='test-serving-service-2',
+ is_local=False,
+ comment='',
+ model_id=self.model_1_id,
+ model_group_id=None,
+ resource=resource)
+ with self.assertRaises(InvalidArgumentException):
+ serving_model_service.update_model(model_id=None,
+ model_group_id=self.model_group_id,
+ serving_model=serving_model)
+
+ @patch('fedlearner_webconsole.utils.flask_utils.get_current_user', MagicMock(return_value=User(username='test')))
+ @patch('fedlearner_webconsole.serving.services.k8s_client')
+ @patch('fedlearner_webconsole.serving.services.ServingDeploymentService.create_or_update_deployment')
+ def test_delete_serving(self, mock_create_deployment: MagicMock, mock_k8s_client: MagicMock):
+ mock_k8s_client.delete_config_map = MagicMock()
+ mock_k8s_client.delete_app = MagicMock()
+ mock_k8s_client.delete_service = MagicMock()
+ # delete serving inside platform
+ resource = serving_pb2.ServingServiceResource(
+ cpu='1',
+ memory='2',
+ replicas=3,
+ )
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name='test-serving-service-1',
+ is_local=True,
+ comment='',
+ model_id=self.model_1_id,
+ model_group_id=None,
+ resource=resource)
+ serving_model_id = serving_model.id
+ serving_deployment_id = serving_model.serving_deployment_id
+ serving_model_service.delete_serving_service(serving_model)
+ mock_k8s_client.delete_config_map.assert_called_once()
+ mock_k8s_client.delete_app.assert_called_once()
+ mock_k8s_client.delete_service.assert_called_once()
+ # check db
+ with db.session_scope() as session:
+ serving_model = session.query(ServingModel).get(serving_model_id)
+ self.assertIsNone(serving_model)
+ serving_deployment = session.query(ServingDeployment).get(serving_deployment_id)
+ self.assertIsNone(serving_deployment)
+ negotiator = session.query(ServingNegotiator).filter_by(serving_model_id=serving_model_id).one_or_none()
+ self.assertIsNone(negotiator)
+
+ # delete remote serving
+ reckon_remote_serving = FakeRemoteServing()
+ register_remote_serving(FakeRemoteServing.SERVING_PLATFORM, reckon_remote_serving)
+ remote_platform = serving_pb2.ServingServiceRemotePlatform(platform=FakeRemoteServing.SERVING_PLATFORM,
+ payload='test-payload')
+ with db.session_scope() as session:
+ serving_model_service = ServingModelService(session)
+ serving_model = serving_model_service.create_from_param(project_id=self.project_id,
+ name='test-remote-serving-1',
+ is_local=True,
+ comment=None,
+ model_id=None,
+ model_group_id=self.model_group_id,
+ resource=None,
+ remote_platform=remote_platform)
+ serving_model_service.delete_serving_service(serving_model)
+ # called times not increased
+ mock_k8s_client.delete_config_map.assert_called_once()
+ mock_k8s_client.delete_app.assert_called_once()
+ mock_k8s_client.delete_service.assert_called_once()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/serving_yaml_template.py b/web_console_v2/api/fedlearner_webconsole/serving/serving_yaml_template.py
new file mode 100644
index 000000000..bb7aadbf1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/serving_yaml_template.py
@@ -0,0 +1,154 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, Union
+from sqlalchemy.ext.declarative import DeclarativeMeta
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.utils.flask_utils import get_current_user
+from fedlearner_webconsole.utils.pp_yaml import compile_yaml_template,\
+ add_username_in_label, GenerateDictService
+
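+# Note: the JSON-like templates below are evaluated by compile_yaml_template;
+# inside them "self" refers to the serving record, "model" to the model record
+# and "system" to the system variables (see generate_serving_yaml below).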
+CONFIG_MAP_TEMPLATE = """{
+ "apiVersion": "v1",
+ "kind": "ConfigMap",
+ "metadata": {
+ "name": self.name + "-config"
+ },
+ "data": {
+ "config.pb": "model_config_list {\\n config {\\n name: '" + self.name + "'\\n base_path: '" + model.base_path + "'\\n model_platform: 'tensorflow'\\n }\\n}\\n"
+ }
+}"""
+
+DEPLOYMENT_TEMPLATE = """{
+ "apiVersion": "apps/v1",
+ "kind": "Deployment",
+ "metadata": {
+ "name": self.name,
+ "labels": system.variables.labels,
+ "annotations": {
+ "queue": "fedlearner",
+ "schedulerName": "batch",
+ "min-member": "1",
+ "resource-cpu": str(self.resource.resource.cpu),
+ "resource-mem": str(self.resource.resource.memory),
+ },
+ },
+ "spec": {
+ "selector": {
+ "matchLabels": {
+ "app": self.name
+ }
+ },
+ "replicas": int(self.resource.replicas),
+ "template": {
+ "metadata": {
+ "labels": {
+ "app": self.name
+ }
+ },
+ "spec": {
+ "volumes": [
+ {
+ "name": self.name+ "-config",
+ "configMap": {
+ "name": self.name + "-config"
+ }
+ }
+ ] + list(system.variables.volumes_list),
+ "containers": [
+ {
+ "name": self.name,
+ "image": system.variables.serving_image,
+ "resources": {
+ "limits": dict(self.resource.resource)
+ },
+ "args": [
+ "--port=8500",
+ "--rest_api_port=8501",
+ "--model_config_file=/app/config/config.pb"
+ ],
+ "env": system.basic_envs_list,
+ "ports": [
+ {
+ "containerPort": 8500,
+ "name": "grpc",
+ },
+ {
+ "containerPort": 8501,
+ "name": "restful",
+ }
+ ],
+ "volumeMounts": [
+ {
+ "name": self.name + "-config",
+ "mountPath": "/app/config/"
+ }
+ ] + list(system.variables.volume_mounts_list)
+ }
+ ]
+ }
+ }
+ }
+}"""
+
+SERVICE_TEMPLATE = """{
+ "apiVersion": "v1",
+ "kind": "Service",
+ "metadata": {
+ "name": self.name
+ },
+ "spec": {
+ "selector": {
+ "app": self.name
+ },
+ "ports": [
+ {
+ "port": 8501,
+ "targetPort": "restful",
+ "name": "restful",
+ },
+ {
+ "port": 8500,
+ "targetPort": "grpc",
+ "name": "grpc",
+ }
+ ]
+ }
+}"""
+
+
+def generate_self_dict(serving: Union[Dict, DeclarativeMeta]) -> Dict:
+ if not isinstance(serving, dict):
+ serving = serving.to_dict()
+ return serving
+
+
+def generate_model_dict(model: Union[Dict, DeclarativeMeta]) -> Dict:
+ if not isinstance(model, dict):
+ model = model.to_dict()
+ return model
+
+
+def generate_serving_yaml(serving: Dict[str, Union[Dict, DeclarativeMeta]], yaml_template: str,
+ session: Session) -> Dict:
+ result_dict = compile_yaml_template(
+ yaml_template,
+ post_processors=[
+ lambda loaded_json: add_username_in_label(loaded_json, getattr(get_current_user(), 'username', None))
+ ],
+ system=GenerateDictService(session).generate_system_dict(),
+ model=generate_model_dict(serving['model']),
+ self=generate_self_dict(serving['serving']))
+ return result_dict
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/serving_yaml_template_test.py b/web_console_v2/api/fedlearner_webconsole/serving/serving_yaml_template_test.py
new file mode 100644
index 000000000..89156d2cb
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/serving_yaml_template_test.py
@@ -0,0 +1,91 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch
+
+from google.protobuf.text_format import Parse
+from tensorflow_serving.config.model_server_config_pb2 import ModelServerConfig
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.serving.serving_yaml_template import (CONFIG_MAP_TEMPLATE, DEPLOYMENT_TEMPLATE,
+ SERVICE_TEMPLATE, generate_serving_yaml)
+
+
+class ServingYamlTemplateTest(unittest.TestCase):
+
+ def setUp(self):
+
+ self.patcher_generate_system_dict = patch(
+ 'fedlearner_webconsole.serving.serving_yaml_template.GenerateDictService.generate_system_dict')
+ mock = self.patcher_generate_system_dict.start()
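+ # Mocked system dict consumed by the serving templates: basic envs, labels,
+ # the serving image and volume settings.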
+ mock.return_value = {
+ 'basic_envs_list': {
+ 'name': 'HADOOP_HOME',
+ 'value': '/hadoop/'
+ },
+ 'variables': {
+ 'labels': {},
+ 'serving_image': 'dockerhub.com/fedlearner/serving:latest',
+ 'volumes_list': [{}],
+ 'volume_mounts_list': [{}],
+ }
+ }
+
+ self.serving = {
+ 'project': None,
+ 'model': {
+ 'base_path': '/test',
+ },
+ 'serving': {
+ 'name': 'serving-demo',
+ 'resource': {
+ 'resource': {
+ 'cpu': '4000m',
+ 'memory': '4Gi',
+ },
+ 'replicas': 2,
+ },
+ },
+ }
+ return super().setUp()
+
+ def tearDown(self):
+ self.patcher_generate_system_dict.stop()
+ return super().tearDown()
+
+ def test_config_map(self):
+ with db.session_scope() as session:
+ config_map_object = generate_serving_yaml(self.serving, CONFIG_MAP_TEMPLATE, session)
+ config = Parse(config_map_object['data']['config.pb'], ModelServerConfig())
+ self.assertEqual(config.model_config_list.config[0].base_path, '/test')
+
+ def test_deployment(self):
+ with db.session_scope() as session:
+ deployment_object = generate_serving_yaml(self.serving, DEPLOYMENT_TEMPLATE, session)
+ self.assertEqual('4000m',
+ deployment_object['spec']['template']['spec']['containers'][0]['resources']['limits']['cpu'])
+ self.assertEqual('serving-demo', deployment_object['metadata']['name'])
+ self.assertEqual('4000m', deployment_object['metadata']['annotations']['resource-cpu'])
+ self.assertEqual('4Gi', deployment_object['metadata']['annotations']['resource-mem'])
+
+ def test_service(self):
+ with db.session_scope() as session:
+ service_object = generate_serving_yaml(self.serving, SERVICE_TEMPLATE, session)
+ self.assertEqual('serving-demo', service_object['metadata']['name'])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/utils.py b/web_console_v2/api/fedlearner_webconsole/serving/utils.py
new file mode 100644
index 000000000..f055a28f0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/utils.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.exceptions import NotFoundException
+from fedlearner_webconsole.mmgr.models import Model
+from fedlearner_webconsole.serving.models import ServingNegotiator
+
+
+def get_model(model_id: int, session: Session) -> Model:
+ model = session.query(Model).get(model_id)
+ if model is None:
+ raise NotFoundException(f'[Serving] model {model_id} is not found')
+ return model
+
+
+def get_serving_negotiator_by_serving_model_id(serving_model_id: int, session: Session) -> Optional[ServingNegotiator]:
+ serving_negotiator = session.query(ServingNegotiator).filter_by(serving_model_id=serving_model_id).one_or_none()
+ return serving_negotiator
diff --git a/web_console_v2/api/fedlearner_webconsole/serving/utils_test.py b/web_console_v2/api/fedlearner_webconsole/serving/utils_test.py
new file mode 100644
index 000000000..15caa0607
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/serving/utils_test.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.mmgr.models import Model
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.serving.utils import get_model
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ServingServicesUtilsTest(NoWebServerTestCase):
+
+ def setUp(self):
+ self.maxDiff = None
+ super().setUp()
+ # insert project
+ with db.session_scope() as session:
+ project = Project()
+ project.name = 'test_project_name'
+ session.add(project)
+ session.flush([project])
+
+ model = Model()
+ model.name = 'test_model_name'
+ model.model_path = '/test_path/'
+ model.group_id = 1
+ model.uuid = 'test_uuid_1'
+ model.project_id = project.id
+
+ session.add(model)
+ session.commit()
+ self.project_id = project.id
+ self.model_id = model.id
+ self.model_uuid = model.uuid
+
+ def test_get_model(self):
+ with db.session_scope() as session:
+ model = get_model(self.model_id, session)
+ self.assertEqual(self.project_id, model.project_id)
+ self.assertEqual(self.model_id, model.id)
+ self.assertEqual(self.model_uuid, model.uuid)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/setting/BUILD.bazel
new file mode 100644
index 000000000..56802f042
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/setting/BUILD.bazel
@@ -0,0 +1,105 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_test",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "service_lib",
+ srcs = ["service.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:app_version_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "service_test",
+ srcs = [
+ "service_test.py",
+ ],
+ imports = ["../.."],
+ main = "service_test.py",
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:app_version_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/apis.py b/web_console_v2/api/fedlearner_webconsole/setting/apis.py
index 339406ae4..58aaaaabb 100644
--- a/web_console_v2/api/fedlearner_webconsole/setting/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/setting/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,99 +13,327 @@
# limitations under the License.
# coding: utf-8
+import logging
+from http import HTTPStatus
from pathlib import Path
+from flask_restful import Resource
+from google.protobuf.json_format import ParseDict, ParseError
+from marshmallow import fields
-from flask_restful import Resource, reqparse
-
-from fedlearner_webconsole.utils.k8s_client import k8s_client
-from fedlearner_webconsole.utils.decorators import jwt_required
-from fedlearner_webconsole.utils.decorators import admin_required
+from fedlearner_webconsole.k8s.k8s_client import k8s_client
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.proto.setting_pb2 import SystemVariables, SettingPb
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required, use_kwargs, use_args
+from fedlearner_webconsole.setting.service import DashboardService, SettingService
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import (NotFoundException, NoAccessException, InvalidArgumentException)
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from fedlearner_webconsole.flag.models import Flag
_POD_NAMESPACE = 'default'
# Ref: https://stackoverflow.com/questions/46046110/
# how-to-get-the-current-namespace-in-a-pod
-_k8s_namespace_file = Path(
- '/var/run/secrets/kubernetes.io/serviceaccount/namespace')
+_k8s_namespace_file = Path('/var/run/secrets/kubernetes.io/serviceaccount/namespace')
if _k8s_namespace_file.is_file():
- _POD_NAMESPACE = _k8s_namespace_file.read_text()
+ _POD_NAMESPACE = _k8s_namespace_file.read_text(encoding='utf-8')
+
+_SPECIAL_KEYS = ['webconsole_image', 'system_info', 'system_variables']
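+# Keys that SettingApi resolves specially instead of the generic settings lookup.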
-class SettingsApi(Resource):
- @jwt_required()
+class SettingApi(Resource):
+
+ @credentials_required
@admin_required
- def get(self):
- deployment = k8s_client.get_deployment(
- name='fedlearner-web-console-v2', namespace=_POD_NAMESPACE)
+ def _get_webconsole_image(self) -> SettingPb:
+ try:
+ deployment = k8s_client.get_deployment(name='fedlearner-web-console-v2')
+ image = deployment.spec.template.spec.containers[0].image
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'settings: get deployment: {str(e)}')
+ image = None
+ return SettingPb(
+ uniq_key='webconsole_image',
+ value=image,
+ )
+
+ @credentials_required
+ @admin_required
+ def _get_system_variables(self) -> SystemVariables:
+ with db.session_scope() as session:
+ return SettingService(session).get_system_variables()
+
+ def get(self, key: str):
+ """Gets a specific setting.
+ ---
+ tags:
+ - system
+ description: gets a specific setting.
+ parameters:
+ - in: path
+ name: key
+ schema:
+ type: string
+ required: true
+ responses:
+ 200:
+ description: the setting
+ content:
+ application/json:
+ schema:
+ oneOf:
+ - $ref: '#/definitions/fedlearner_webconsole.proto.SettingPb'
+ - $ref: '#/definitions/fedlearner_webconsole.proto.SystemVariables'
+ - $ref: '#/definitions/fedlearner_webconsole.proto.SystemInfo'
+ """
+ if key == 'webconsole_image':
+ return make_flask_response(self._get_webconsole_image())
+
+ if key == 'system_variables':
+ return make_flask_response(self._get_system_variables())
+
+ if key == 'system_info':
+ return make_flask_response(SettingService.get_system_info())
+
+ setting = None
+ if key not in _SPECIAL_KEYS:
+ with db.session_scope() as session:
+ setting = SettingService(session).get_setting(key)
+ if setting is None:
+ raise NotFoundException(message=f'Failed to find setting {key}')
+ return make_flask_response(setting.to_proto())
+
+ @credentials_required
+ @admin_required
+ @use_kwargs({'value': fields.String(required=True)})
+ def put(self, key: str, value: str):
+ """Updates a specific setting.
+ ---
+ tags:
+ - system
+ description: updates a specific setting.
+ parameters:
+ - in: path
+ name: key
+ schema:
+ type: string
+ required: true
+ - in: body
+ name: body
+ schema:
+ type: object
+ properties:
+ value:
+ type: string
+ required: true
+ responses:
+ 200:
+ description: the updated setting
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SettingPb'
+ """
+ if key in _SPECIAL_KEYS:
+ raise NoAccessException(message=f'Not able to update {key}')
+
+ with db.session_scope() as session:
+ setting = SettingService(session).create_or_update_setting(key, value)
+ return make_flask_response(setting.to_proto())
+
- return {
- 'data': {
- 'webconsole_image':
- deployment.spec.template.spec.containers[0].image
- }
- }
+class UpdateSystemVariablesApi(Resource):
- @jwt_required()
+ @credentials_required
@admin_required
- def patch(self):
- parser = reqparse.RequestParser()
- parser.add_argument('webconsole_image',
- type=str,
- required=False,
- default=None,
- help='image for webconsole')
- data = parser.parse_args()
-
- if data['webconsole_image']:
- new_image = data['webconsole_image']
- deployment = k8s_client.get_deployment('fedlearner-web-console-v2',
- _POD_NAMESPACE)
- spec = deployment.spec
- spec.template.spec.containers[0].image = new_image
- metadata = deployment.metadata
- k8s_client.create_or_update_deployment(
- metadata=metadata,
- spec=spec,
- name=metadata.name,
- namespace=metadata.namespace)
-
- return {'data': {}}
+ @use_args({'variables': fields.List(fields.Dict())})
+ def post(self, params: dict):
+ """Updates system variables.
+ ---
+ tags:
+ - system
+ description: updates all system variables.
+ parameters:
+ - in: body
+ name: body
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SystemVariables'
+ responses:
+ 200:
+ description: updated system variables
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SystemVariables'
+ """
+ try:
+ system_variables = ParseDict(params, SystemVariables())
+ except ParseError as e:
+ raise InvalidArgumentException(details=str(e)) from e
+ with db.session_scope() as session:
+ # TODO(xiangyuxuan.prs): check fixed flag
+ SettingService(session).set_system_variables(system_variables)
+ session.commit()
+ return make_flask_response(system_variables)
+
+
+class UpdateImageApi(Resource):
+
+ @credentials_required
+ @admin_required
+ @use_kwargs({'webconsole_image': fields.String(required=True)})
+ def post(self, webconsole_image: str):
+ """Updates webconsole image.
+ ---
+ tags:
+ - system
+ description: updates webconsole image.
+ parameters:
+ - in: body
+ name: body
+ schema:
+ type: object
+ properties:
+ webconsole_image:
+ type: string
+ required: true
+ responses:
+ 204:
+ description: updated successfully
+ """
+ deployment = k8s_client.get_deployment('fedlearner-web-console-v2')
+ spec = deployment.spec
+ spec.template.spec.containers[0].image = webconsole_image
+ metadata = deployment.metadata
+ k8s_client.create_or_update_deployment(metadata=metadata,
+ spec=spec,
+ name=metadata.name,
+ namespace=metadata.namespace)
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
class SystemPodLogsApi(Resource):
- @jwt_required()
+
+ @credentials_required
@admin_required
- def get(self, pod_name):
- parser = reqparse.RequestParser()
- parser.add_argument('tail_lines',
- type=int,
- location='args',
- required=True,
- help='tail lines is required')
- data = parser.parse_args()
- tail_lines = data['tail_lines']
- return {
- 'data':
- k8s_client.get_pod_log(name=pod_name,
- namespace=_POD_NAMESPACE,
- tail_lines=tail_lines).split('\n')
- }
+ @use_kwargs({'tail_lines': fields.Integer(required=True)}, location='query')
+ def get(self, pod_name: str, tail_lines: int):
+ """Gets webconsole pod logs.
+ ---
+ tags:
+ - system
+ description: gets webconsole pod logs.
+ parameters:
+ - in: path
+ name: pod_name
+ schema:
+ type: string
+ required: true
+ - in: query
+ name: tail_lines
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: logs
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+ """
+ return make_flask_response(
+ k8s_client.get_pod_log(name=pod_name, namespace=_POD_NAMESPACE, tail_lines=tail_lines).split('\n'))
class SystemPodsApi(Resource):
- @jwt_required()
+
+ @credentials_required
@admin_required
def get(self):
+ """Gets webconsole pods.
+ ---
+ tags:
+ - system
+ description: gets webconsole pods.
+ responses:
+ 200:
+ description: name list of pods
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+ """
webconsole_v2_pod_list = list(
- map(
- lambda pod: pod.metadata.name,
- k8s_client.get_pods(
- _POD_NAMESPACE,
- 'app.kubernetes.io/instance=fedlearner-web-console-v2').
- items))
- return {'data': webconsole_v2_pod_list}
+ map(lambda pod: pod.metadata.name,
+ k8s_client.get_pods(_POD_NAMESPACE, 'app.kubernetes.io/instance=fedlearner-web-console-v2').items))
+ return make_flask_response(webconsole_v2_pod_list)
+
+
+class VersionsApi(Resource):
+ # This is a system-level API; no JWT token is required for now.
+ def get(self):
+ """Gets the version info.
+ ---
+ tags:
+ - system
+ description: gets the version info.
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.ApplicationVersion'
+ """
+ return make_flask_response(SettingService.get_application_version().to_proto())
+
+
+class DashboardsApi(Resource):
+
+ @credentials_required
+ @admin_required
+ def get(self):
+ """Get dashboard information API
+ ---
+ tags:
+ - system
+ description: gets dashboard information.
+ responses:
+ 200:
+ description: a list of dashboard information; the 'overview' dashboard is always included.
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.DashboardInformation'
+ 500:
+ description: the dashboard configuration is invalid.
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ code:
+ type: integer
+ message:
+ type: string
+ """
+ if not Flag.DASHBOARD_ENABLED.value:
+ raise NoAccessException('dashboard is disabled; please enable the `DASHBOARD_ENABLED` flag to view it')
+ return make_flask_response(DashboardService().get_dashboards())
def initialize_setting_apis(api):
- api.add_resource(SettingsApi, '/settings')
+ api.add_resource(UpdateSystemVariablesApi, '/settings:update_system_variables')
+ api.add_resource(UpdateImageApi, '/settings:update_image')
+ api.add_resource(SettingApi, '/settings/<path:key>')
+ api.add_resource(VersionsApi, '/versions')
 api.add_resource(SystemPodLogsApi, '/system_pods/<string:pod_name>/logs')
api.add_resource(SystemPodsApi, '/system_pods/name')
+ api.add_resource(DashboardsApi, '/dashboards')
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/apis_test.py b/web_console_v2/api/fedlearner_webconsole/setting/apis_test.py
new file mode 100644
index 000000000..0e6ef20b2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/setting/apis_test.py
@@ -0,0 +1,322 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import json
+import logging
+import os
+import unittest
+from http import HTTPStatus
+from types import SimpleNamespace
+from unittest.mock import patch, MagicMock
+
+from google.protobuf.struct_pb2 import Value
+
+from envs import Envs
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.setting_pb2 import DashboardInformation, SystemVariables, SystemVariable, SystemInfo
+from fedlearner_webconsole.setting.apis import _POD_NAMESPACE
+from fedlearner_webconsole.setting.models import Setting
+from fedlearner_webconsole.setting.service import SettingService
+
+from testing.common import BaseTestCase
+
+
+class SettingApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ setting = Setting(uniq_key='key1', value='value 1')
+ session.add(setting)
+ session.commit()
+
+ @patch('fedlearner_webconsole.setting.apis.k8s_client')
+ def test_get_webconsole_image(self, mock_k8s_client: MagicMock):
+ deployment = SimpleNamespace(
+ **{
+ 'metadata':
+ SimpleNamespace(**{
+ 'name': 'fedlearner-web-console-v2',
+ 'namespace': 'testns'
+ }),
+ 'spec':
+ SimpleNamespace(
+ **{
+ 'template':
+ SimpleNamespace(
+ **{
+ 'spec':
+ SimpleNamespace(
+ **{'containers': [SimpleNamespace(**{'image': 'fedlearner:test'})]})
+ })
+ })
+ })
+ mock_k8s_client.get_deployment = MagicMock(return_value=deployment)
+ resp = self.get_helper('/api/v2/settings/webconsole_image')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ self.signin_as_admin()
+ resp = self.get_helper('/api/v2/settings/webconsole_image')
+ self.assertResponseDataEqual(resp, {
+ 'uniq_key': 'webconsole_image',
+ 'value': 'fedlearner:test',
+ })
+
+ def test_get_system_variables(self):
+ system_variables = SystemVariables(variables=[
+ SystemVariable(name='test1', value_type=SystemVariable.ValueType.INT, value=Value(number_value=1))
+ ])
+ with db.session_scope() as session:
+ SettingService(session).set_system_variables(system_variables)
+ session.commit()
+ resp = self.get_helper('/api/v2/settings/system_variables')
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ self.signin_as_admin()
+ resp = self.get_helper('/api/v2/settings/system_variables')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ resp, {'variables': [{
+ 'name': 'test1',
+ 'value': 1.0,
+ 'value_type': 'INT',
+ 'fixed': False
+ }]})
+
+ def test_get(self):
+ resp = self.get_helper('/api/v2/settings/key1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp)['value'], 'value 1')
+ # Black list one
+ resp = self.get_helper('/api/v2/settings/variables')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ resp = self.get_helper('/api/v2/settings/key2')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_put(self):
+ resp = self.put_helper('/api/v2/settings/key1', data={'value': 'new value'})
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+ self.signin_as_admin()
+ resp = self.put_helper('/api/v2/settings/key1', data={'value': 'new value'})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp)['value'], 'new value')
+ with db.session_scope() as session:
+ setting = session.query(Setting).filter_by(uniq_key='key1').first()
+ self.assertEqual(setting.value, 'new value')
+ # Black list one
+ resp = self.put_helper('/api/v2/settings/system_variables', data={'value': 'new value'})
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+
+
+class SettingsApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ self._system_pods = SimpleNamespace(
+ **{
+ 'items': [
+ SimpleNamespace(**{'metadata': SimpleNamespace(**{'name': 'fake-fedlearner-web-console-v2-1'})}),
+ SimpleNamespace(**{'metadata': SimpleNamespace(**{'name': 'fake-fedlearner-web-console-v2-2'})}),
+ ]
+ })
+ self._system_pod_log = 'log1\nlog2'
+ self._mock_k8s_client = MagicMock()
+ self._mock_k8s_client.get_pods = MagicMock(return_value=self._system_pods)
+ self._mock_k8s_client.get_pod_log = MagicMock(return_value=self._system_pod_log)
+ self.signin_as_admin()
+
+ def test_get_system_pods(self):
+ with patch('fedlearner_webconsole.setting.apis.k8s_client', self._mock_k8s_client):
+ resp = self.get_helper('/api/v2/system_pods/name')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp),
+ ['fake-fedlearner-web-console-v2-1', 'fake-fedlearner-web-console-v2-2'])
+
+ def test_get_system_pods_log(self):
+ fake_pod_name = 'fake-fedlearner-web-console-v2-1'
+ with patch('fedlearner_webconsole.setting.apis.k8s_client', self._mock_k8s_client):
+ resp = self.get_helper(f'/api/v2/system_pods/{fake_pod_name}/logs?tail_lines=100')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp), ['log1', 'log2'])
+ self._mock_k8s_client.get_pod_log.assert_called_with(name=fake_pod_name,
+ namespace=_POD_NAMESPACE,
+ tail_lines=100)
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info',
+ lambda: SystemInfo(name='hahaha', domain_name='fl-test.com', pure_domain_name='test'))
+ def test_get_own_info_api(self):
+ resp = self.get_helper('/api/v2/settings/system_info')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'name': 'hahaha', 'domain_name': 'fl-test.com', 'pure_domain_name': 'test'})
+
+
+class UpdateSystemVariablesApiTest(BaseTestCase):
+
+ def test_post_no_permission(self):
+ resp = self.post_helper('/api/v2/settings:update_system_variables', data={'variables': []})
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+
+ def test_post_invalid_argument(self):
+ system_variables = SystemVariables(variables=[
+ SystemVariable(name='test1', value_type=SystemVariable.ValueType.INT, value=Value(number_value=1))
+ ])
+ with db.session_scope() as session:
+ SettingService(session).set_system_variables(system_variables)
+ session.commit()
+
+ self.signin_as_admin()
+ resp = self.post_helper('/api/v2/settings:update_system_variables', data={'variables': [{'h': 'ff'}]})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertIn('Failed to parse variables field', json.loads(resp.data).get('details'))
+
+ with db.session_scope() as session:
+ self.assertEqual(system_variables, SettingService(session).get_system_variables())
+
+ def test_post_200(self):
+ self.signin_as_admin()
+ resp = self.post_helper('/api/v2/settings:update_system_variables',
+ data={'variables': [{
+ 'name': 'new_var',
+ 'value': 2,
+ 'value_type': 'INT'
+ }]})
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ resp, {'variables': [{
+ 'name': 'new_var',
+ 'value': 2.0,
+ 'value_type': 'INT',
+ 'fixed': False
+ }]})
+
+ expected_system_variables = SystemVariables(variables=[
+ SystemVariable(name='new_var', value_type=SystemVariable.ValueType.INT, value=Value(number_value=2))
+ ])
+ with db.session_scope() as session:
+ self.assertEqual(expected_system_variables, SettingService(session).get_system_variables())
+
+
+class UpdateImageApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ @patch('fedlearner_webconsole.setting.apis.k8s_client')
+ def test_post(self, mock_k8s_client: MagicMock):
+ deployment = SimpleNamespace(
+ **{
+ 'metadata':
+ SimpleNamespace(**{
+ 'name': 'fedlearner-web-console-v2',
+ 'namespace': 'testns'
+ }),
+ 'spec':
+ SimpleNamespace(
+ **{
+ 'template':
+ SimpleNamespace(
+ **{
+ 'spec':
+ SimpleNamespace(
+ **{'containers': [SimpleNamespace(**{'image': 'fedlearner:test'})]})
+ })
+ })
+ })
+ mock_k8s_client.get_deployment = MagicMock(return_value=deployment)
+ mock_k8s_client.create_or_update_deployment = MagicMock()
+
+ resp = self.post_helper('/api/v2/settings:update_image', data={'webconsole_image': 'test-new-image'})
+ self.assertEqual(resp.status_code, HTTPStatus.UNAUTHORIZED)
+
+ self.signin_as_admin()
+ resp = self.post_helper('/api/v2/settings:update_image', data={'webconsole_image': 'test-new-image'})
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ _, kwargs = mock_k8s_client.create_or_update_deployment.call_args
+ self.assertEqual(kwargs['spec'].template.spec.containers[0].image, 'test-new-image')
+ self.assertEqual(kwargs['name'], deployment.metadata.name)
+ self.assertEqual(kwargs['namespace'], deployment.metadata.namespace)
+
+
+class VersionsApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def test_get_version_api(self):
+ resp = self.get_helper('/api/v2/versions')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertEqual(self.get_response_data(resp)['branch_name'], '')
+
+ content = """
+ revision:f09d681b4eda01f053cc1a645fa6fc0775852a48
+ branch name:release-2.0.1
+ version:2.0.1.5
+ pub date:Fri Jul 16 12:23:19 CST 2021
+ """
+ application_version_path = os.path.join(Envs.BASE_DIR, '../current_revision')
+ with open(application_version_path, 'wt', encoding='utf-8') as f:
+ f.write(content)
+
+ resp = self.get_helper('/api/v2/versions')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(
+ resp, {
+ 'pub_date': 'Fri Jul 16 12:23:19 CST 2021',
+ 'revision': 'f09d681b4eda01f053cc1a645fa6fc0775852a48',
+ 'branch_name': 'release-2.0.1',
+ 'version': '2.0.1.5',
+ })
+
+ os.remove(application_version_path)
+
+
+class DashboardsApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def test_get_flag_disable(self):
+ self.signin_as_admin()
+ get_dashboard_response = self.get_helper('/api/v2/dashboards')
+ self.assertEqual(get_dashboard_response.status_code, HTTPStatus.FORBIDDEN)
+
+ @patch('fedlearner_webconsole.flag.models.Flag.DASHBOARD_ENABLED.value', True)
+ @patch('fedlearner_webconsole.setting.apis.DashboardService.get_dashboards')
+ def test_get(self, mock_get_dashboards: MagicMock):
+ mock_get_dashboards.return_value = [DashboardInformation()]
+ get_dashboard_response = self.get_helper('/api/v2/dashboards')
+ self.assertEqual(get_dashboard_response.status_code, HTTPStatus.UNAUTHORIZED)
+
+ mock_get_dashboards.reset_mock()
+ self.signin_as_admin()
+ mock_get_dashboards.return_value = [DashboardInformation()]
+ get_dashboard_response = self.get_helper('/api/v2/dashboards')
+ self.assertEqual(get_dashboard_response.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(get_dashboard_response, [{'name': '', 'uuid': '', 'url': ''}])
+
+ mock_get_dashboards.reset_mock()
+ mock_get_dashboards.side_effect = InternalException('')
+ get_dashboard_response = self.get_helper('/api/v2/dashboards')
+ self.assertEqual(get_dashboard_response.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/models.py b/web_console_v2/api/fedlearner_webconsole/setting/models.py
index 7d46db01f..f1b5bae1b 100644
--- a/web_console_v2/api/fedlearner_webconsole/setting/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/setting/models.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,16 +13,26 @@
# limitations under the License.
# coding: utf-8
-# pylint: disable=raise-missing-from
-
-from sqlalchemy import UniqueConstraint
+from sqlalchemy import UniqueConstraint, func
from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.proto.setting_pb2 import SettingPb
class Setting(db.Model):
__tablename__ = 'settings_v2'
- __table_args__ = (UniqueConstraint('key', name='uniq_key'),
- default_table_args('this is webconsole settings table'))
- id = db.Column(db.Integer, primary_key=True, comment='id')
- key = db.Column(db.String(255), nullable=False, comment='key')
+ __table_args__ = (UniqueConstraint('uniq_key',
+ name='uniq_key'), default_table_args('this is webconsole settings table'))
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ uniq_key = db.Column(db.String(255), nullable=False, comment='uniq_key')
value = db.Column(db.Text, comment='value')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created at')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ onupdate=func.now(),
+ server_default=func.now(),
+ comment='updated at')
+
+ def to_proto(self):
+ return SettingPb(
+ uniq_key=self.uniq_key,
+ value=self.value,
+ )
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/models_test.py b/web_console_v2/api/fedlearner_webconsole/setting/models_test.py
new file mode 100644
index 000000000..1fc037f0c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/setting/models_test.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.setting_pb2 import SettingPb
+from fedlearner_webconsole.setting.models import Setting
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class SettingTest(NoWebServerTestCase):
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ setting = Setting(
+ uniq_key='test',
+ value='test value',
+ )
+ session.add(setting)
+ session.commit()
+ with db.session_scope() as session:
+ setting = session.query(Setting).filter_by(uniq_key='test').first()
+ self.assertEqual(setting.to_proto(), SettingPb(
+ uniq_key='test',
+ value='test value',
+ ))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/service.py b/web_console_v2/api/fedlearner_webconsole/setting/service.py
new file mode 100644
index 000000000..b34213a1b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/setting/service.py
@@ -0,0 +1,141 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+import os
+import json
+from typing import List, Optional
+from google.protobuf import text_format
+from google.protobuf.json_format import MessageToDict, Parse, ParseDict, ParseError
+from sqlalchemy.orm.session import Session
+from envs import Envs
+from fedlearner_webconsole.proto import setting_pb2
+from fedlearner_webconsole.setting.models import Setting
+from fedlearner_webconsole.proto.setting_pb2 import SystemVariables
+from fedlearner_webconsole.utils.app_version import ApplicationVersion
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+from fedlearner_webconsole.exceptions import InternalException
+
+
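+# Parses the "current_revision" file content: one "key:value" pair per line
+# (revision, branch name, version, pub date); missing entries stay None.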
+def parse_application_version(content: str) -> ApplicationVersion:
+ revision, branch_name, version, pub_date = None, None, None, None
+ for line in content.split('\n'):
+ if line.find(':') == -1:
+ continue
+ key, value = line.split(':', 1)
+ key, value = key.strip(), value.strip()
+ if value == '':
+ continue
+ if key == 'revision':
+ revision = value
+ elif key == 'branch name':
+ branch_name = value
+ elif key == 'version':
+ version = value
+ elif key == 'pub date':
+ pub_date = value
+
+ return ApplicationVersion(revision=revision, branch_name=branch_name, version=version, pub_date=pub_date)
+
+
+class SettingService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def get_setting(self, key: str) -> Optional[Setting]:
+ return self._session.query(Setting).filter_by(uniq_key=key).first()
+
+ def create_or_update_setting(self, key: str, value: str) -> Setting:
+ setting = self._session.query(Setting).filter_by(uniq_key=key).first()
+ if setting is None:
+ setting = Setting(uniq_key=key, value=value)
+ self._session.add(setting)
+ self._session.commit()
+ else:
+ setting.value = value
+ self._session.commit()
+ return setting
+
+ def get_system_variables(self) -> SystemVariables:
+ result = SystemVariables()
+ setting = self.get_setting('system_variables')
+ if setting is None:
+ return result
+ text_format.Parse(setting.value, result)
+ return result
+
+ def set_system_variables(self, system_variables: SystemVariables):
+ self.create_or_update_setting('system_variables', text_format.MessageToString(system_variables))
+
+ @staticmethod
+ def get_application_version() -> ApplicationVersion:
+ application_version_path = os.path.join(Envs.BASE_DIR, '../current_revision')
+ if not os.path.exists(application_version_path):
+ content = ''
+ else:
+ with open(application_version_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ return parse_application_version(content)
+
+ def get_namespace(self) -> str:
+ return self.get_system_variables_dict().get('namespace', 'default')
+
+ def get_system_variables_dict(self) -> dict:
+ variables = self.get_system_variables().variables
+ return {
+ var.name: MessageToDict(var.value, preserving_proto_field_name=True, including_default_value_fields=True)
+ for var in variables
+ }
+
+ @staticmethod
+ def get_system_info() -> setting_pb2.SystemInfo:
+ system_info: setting_pb2.SystemInfo = Parse(Envs.SYSTEM_INFO, setting_pb2.SystemInfo())
+ system_info.pure_domain_name = get_pure_domain_name(system_info.domain_name) or ''
+ return system_info
+
+
+class DashboardService(object):
+ # Reference: https://discuss.elastic.co/t/kibana-g-and-a-parameters-in-the-dashboards-url-string/264642
+ DASHBOARD_FMT_STR = '{kibana_address}/app/kibana#/dashboard/{object_uuid}?_a=(filters:!((query:(match_phrase:(service.environment:{cluster})))))' # pylint:disable=line-too-long
+
+ REQUIRED_DASHBOARD = frozenset(['overview'])
+
+ @staticmethod
+ def _validate_saved_object_uuid(saved_object_uuid: str) -> bool:
+ if not isinstance(saved_object_uuid, str) or not saved_object_uuid:
+ return False
+ return True
+
+ def get_dashboards(self) -> List[setting_pb2.DashboardInformation]:
+ dashboard_list = json.loads(Envs.KIBANA_DASHBOARD_LIST)
+ if not DashboardService.REQUIRED_DASHBOARD.issubset({d['name'] for d in dashboard_list}):
+ raise InternalException(
+ f'failed to find required dashboard {list(DashboardService.REQUIRED_DASHBOARD)} uuid')
+ try:
+ dashboard_information_list = []
+ for item in dashboard_list:
+ dashboard_information = ParseDict(item, setting_pb2.DashboardInformation(), ignore_unknown_fields=False)
+ if not self._validate_saved_object_uuid(dashboard_information.uuid):
+ raise InternalException(f'invalid uuid for dashboard {dashboard_information.name}')
+
+ dashboard_information.url = DashboardService.DASHBOARD_FMT_STR.format(
+ kibana_address=Envs.KIBANA_ADDRESS, object_uuid=dashboard_information.uuid, cluster=Envs.CLUSTER)
+ dashboard_information_list.append(dashboard_information)
+ return dashboard_information_list
+ except ParseError as err:
+ msg = f'invalid `KIBANA_DASHBOARD_LIST`, details: {err}'
+ logging.warning(msg)
+ raise InternalException(msg) from err
diff --git a/web_console_v2/api/fedlearner_webconsole/setting/service_test.py b/web_console_v2/api/fedlearner_webconsole/setting/service_test.py
new file mode 100644
index 000000000..b774daeac
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/setting/service_test.py
@@ -0,0 +1,156 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+# pylint: disable=protected-access
+import json
+import unittest
+from unittest.mock import patch
+
+from google.protobuf.json_format import ParseDict
+
+from fedlearner_webconsole.initial_db import initial_db
+from fedlearner_webconsole.proto.setting_pb2 import DashboardInformation, SystemInfo, SystemVariables
+from fedlearner_webconsole.setting.models import Setting
+from fedlearner_webconsole.setting.service import DashboardService, parse_application_version, SettingService
+from fedlearner_webconsole.utils.app_version import Version
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.exceptions import InternalException
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ initial_db()
+
+ def test_get_setting(self):
+ with db.session_scope() as session:
+ setting = Setting(uniq_key='test_key1', value='test value 1')
+ session.add(setting)
+ session.commit()
+ # A new session
+ with db.session_scope() as session:
+ setting = SettingService(session).get_setting('test_key1')
+ self.assertEqual(setting.value, 'test value 1')
+ setting = SettingService(session).get_setting('100')
+ self.assertIsNone(setting)
+
+ def test_set_setting(self):
+ # A new setting
+ with db.session_scope() as session:
+ setting = SettingService(session).create_or_update_setting(key='k1', value='v1')
+ self.assertEqual(setting.uniq_key, 'k1')
+ self.assertEqual(setting.value, 'v1')
+ setting_in_db = \
+ session.query(Setting).filter_by(uniq_key='k1').first()
+ self.assertEqual(setting_in_db.value, 'v1')
+ # Existing setting
+ with db.session_scope() as session:
+ SettingService(session).create_or_update_setting(key='k1', value='v2')
+ setting_in_db = \
+ session.query(Setting).filter_by(uniq_key='k1').first()
+ self.assertEqual(setting_in_db.value, 'v2')
+
+ def test_parse_application_version(self):
+ content = """
+ revision:f09d681b4eda01f053cc1a645fa6fc0775852a48
+ branch name:release-2.0.1
+ version:2.0.1.5
+ pub date:Fri Jul 16 12:23:19 CST 2021
+ """
+ application_version = parse_application_version(content)
+ self.assertEqual(application_version.revision, 'f09d681b4eda01f053cc1a645fa6fc0775852a48')
+ self.assertEqual(application_version.branch_name, 'release-2.0.1')
+ self.assertEqual(application_version.version, Version('2.0.1.5'))
+ self.assertEqual(application_version.pub_date, 'Fri Jul 16 12:23:19 CST 2021')
+
+ content = """
+ revision:f09d681b4eda01f053cc1a645fa6fc0775852a48
+ branch name:master
+ version:
+ pub date:Fri Jul 16 12:23:19 CST 2021
+ """
+ application_version = parse_application_version(content)
+ self.assertEqual(application_version.revision, 'f09d681b4eda01f053cc1a645fa6fc0775852a48')
+ self.assertEqual(application_version.branch_name, 'master')
+ self.assertIsNone(application_version.version.version)
+ self.assertEqual(application_version.pub_date, 'Fri Jul 16 12:23:19 CST 2021')
+
+ def test_get_variable_by_key(self):
+ with db.session_scope() as session:
+ self.assertEqual(SettingService(session).get_system_variables_dict()['namespace'], 'default')
+ self.assertIsNone(SettingService(session).get_system_variables_dict().get('not-existed'))
+
+ def test_get_system_variables_dict(self):
+ test_data = {'variables': [{'name': 'a', 'value': 2}, {'name': 'b', 'value': []}]}
+ with db.session_scope() as session:
+ SettingService(session).set_system_variables(ParseDict(test_data, SystemVariables()))
+ self.assertEqual(SettingService(session).get_system_variables_dict(), {'a': 2, 'b': []})
+
+ @patch('envs.Envs.SYSTEM_INFO',
+ json.dumps({
+ 'name': 'hahaha',
+ 'domain_name': 'fl-test.com',
+ 'pure_domain_name': 'test'
+ }))
+ def test_get_system_info(self):
+ with db.session_scope() as session:
+ system_info = SettingService(session).get_system_info()
+ self.assertEqual(system_info, SystemInfo(name='hahaha', domain_name='fl-test.com', pure_domain_name='test'))
+
+
+class DashboardServiceTest(unittest.TestCase):
+
+ def test_validate_saved_object_uuid(self):
+ self.assertFalse(DashboardService._validate_saved_object_uuid(''))
+ self.assertFalse(DashboardService._validate_saved_object_uuid(None))
+ self.assertFalse(DashboardService._validate_saved_object_uuid(1))
+ self.assertTrue(DashboardService._validate_saved_object_uuid('c4c0af20-d03c-11ec-9be6-d5c22c92cd59'))
+
+ def test_get_dashboards(self):
+ with patch('envs.Envs.KIBANA_DASHBOARD_LIST', '[]'):
+ with self.assertRaises(InternalException) as cm:
+ DashboardService().get_dashboards()
+ self.assertEqual(cm.exception.details, 'failed to find required dashboard [\'overview\'] uuid')
+ with patch('envs.Envs.KIBANA_DASHBOARD_LIST', json.dumps([{'name': 'overview', 'uuid': 1}])):
+ with self.assertRaises(InternalException) as cm:
+ DashboardService().get_dashboards()
+ self.assertEqual(
+ cm.exception.details, 'invalid `KIBANA_DASHBOARD_LIST`, '
+ 'details: Failed to parse uuid field: expected string or bytes-like object.')
+ with patch('envs.Envs.KIBANA_DASHBOARD_LIST', json.dumps([{'name': 'overview', 'test': 1}])):
+ with self.assertRaises(InternalException) as cm:
+ DashboardService().get_dashboards()
+ self.assertIn(
+ 'invalid `KIBANA_DASHBOARD_LIST`, details: Message type "fedlearner_webconsole.proto.DashboardInformation" has no field named "test".', # pylint: disable=line-too-long
+ cm.exception.details)
+ with patch('envs.Envs.KIBANA_DASHBOARD_LIST', json.dumps([{'name': 'overview', 'uuid': '1'}])):
+ self.assertEqual(
+ DashboardService().get_dashboards(),
+ [
+ DashboardInformation(
+ name='overview',
+ uuid='1',
+ # pylint: disable=line-too-long
+ url=
+ 'localhost:1993/app/kibana#/dashboard/1?_a=(filters:!((query:(match_phrase:(service.environment:default)))))',
+ )
+ ])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/sparkapp/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/sparkapp/BUILD.bazel
new file mode 100644
index 000000000..c796eb015
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sparkapp/BUILD.bazel
@@ -0,0 +1,89 @@
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "schema_lib",
+ srcs = ["schema.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:images_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "schema_lib_test",
+ srcs = ["schema_test.py"],
+ imports = ["../.."],
+ main = "schema_test.py",
+ deps = [
+ ":schema_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "service_lib",
+ srcs = ["service.py"],
+ imports = ["../.."],
+ deps = [
+ ":schema_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/k8s:k8s_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:file_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "service_lib_test",
+ srcs = ["service_test.py"],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "service_test.py",
+ deps = [
+ ":schema_lib",
+ ":service_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":schema_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ srcs = ["apis_test.py"],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/sparkapp/apis.py b/web_console_v2/api/fedlearner_webconsole/sparkapp/apis.py
index 70dfc6339..feca6f55a 100644
--- a/web_console_v2/api/fedlearner_webconsole/sparkapp/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/sparkapp/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,54 +13,165 @@
# limitations under the License.
# coding: utf-8
-import base64
from http import HTTPStatus
+import logging
-from flask import request
from flask_restful import Api, Resource
+from marshmallow import Schema, fields, post_load
+from webargs.flaskparser import use_args, use_kwargs
from fedlearner_webconsole.sparkapp.schema import SparkAppConfig
-from fedlearner_webconsole.utils.decorators import jwt_required
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
from fedlearner_webconsole.sparkapp.service import SparkAppService
-from fedlearner_webconsole.exceptions import (InvalidArgumentException,
- NotFoundException)
+from fedlearner_webconsole.exceptions import (InternalException, NotFoundException)
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from fedlearner_webconsole.swagger.models import schema_manager
-class SparkAppsApi(Resource):
- @jwt_required()
- def post(self):
- service = SparkAppService()
- data = request.json
+class SparkAppPodParameter(Schema):
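+ """Resource spec for a driver or executor pod: cores, memory, instance count, volume mounts and env vars."""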
+ cores = fields.Integer(required=True)
+ memory = fields.String(required=True)
+ instances = fields.Integer(required=False, load_default=1)
+ core_limit = fields.String(required=False)
+ volume_mounts = fields.List(fields.Dict(fields.String, fields.String), required=False)
+ envs = fields.Dict(fields.String, fields.String)
- try:
- config = SparkAppConfig.from_dict(data)
- if config.files:
- config.files = base64.b64decode(config.files)
- except ValueError as err:
- raise InvalidArgumentException(details=err)
- res = service.submit_sparkapp(config=config)
- return {'data': res.to_dict()}, HTTPStatus.CREATED
+class SparkAppCreateParameter(Schema):
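+ """Creation payload for a sparkapp; post_load turns the validated dict into a SparkAppConfig."""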
+ name = fields.String(required=True)
+ files = fields.String(required=False, load_default=None)
+ files_path = fields.String(required=False, load_default=None)
+ image_url = fields.String(required=False, load_default=None)
+ volumes = fields.List(fields.Dict(fields.String, fields.String), required=False, load_default=[])
+ driver_config = fields.Nested(SparkAppPodParameter)
+ executor_config = fields.Nested(SparkAppPodParameter)
+ py_files = fields.List(fields.String, required=False, load_default=[])
+ command = fields.List(fields.String, required=False, load_default=[])
+ main_application = fields.String(required=True)
+
+ @post_load
+ def make_spark_app_config(self, data, **kwargs):
+ del kwargs
+ return SparkAppConfig.from_dict(data)
+
+
+class SparkAppsApi(Resource):
+
+ @credentials_required
+ @use_args(SparkAppCreateParameter())
+ def post(self, config: SparkAppConfig):
+ """Create sparkapp
+ ---
+ tags:
+ - sparkapp
+ description: Create sparkapp
+ parameters:
+ - in: body
+ name: body
+ schema:
+ $ref: '#/definitions/SparkAppCreateParameter'
+ responses:
+ 201:
+ description: The sparkapp is created
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SparkAppInfo'
+ """
+ service = SparkAppService()
+ return make_flask_response(data=service.submit_sparkapp(config=config), status=HTTPStatus.CREATED)
class SparkAppApi(Resource):
- @jwt_required()
+
+ @credentials_required
def get(self, sparkapp_name: str):
+ """Get sparkapp status
+ ---
+ tags:
+ - sparkapp
+ description: Get sparkapp status
+ parameters:
+ - in: path
+ name: sparkapp_name
+ schema:
+ type: string
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.SparkAppInfo'
+ """
service = SparkAppService()
- return {
- 'data': service.get_sparkapp_info(sparkapp_name).to_dict()
- }, HTTPStatus.OK
+ return make_flask_response(data=service.get_sparkapp_info(sparkapp_name))
- @jwt_required()
+ @credentials_required
def delete(self, sparkapp_name: str):
+ """Delete a sparkapp whether the existence of sparkapp
+ ---
+ tags:
+ - sparkapp
+ description: Delete a sparkapp whether the existence of sparkapp
+ parameters:
+ - in: path
+ name: sparkapp_name
+ schema:
+ type: string
+ responses:
+ 204:
+ description: finish sparkapp deletion
+ """
service = SparkAppService()
try:
- sparkapp_info = service.delete_sparkapp(sparkapp_name)
- return {'data': sparkapp_info.to_dict()}, HTTPStatus.OK
+ service.delete_sparkapp(sparkapp_name)
except NotFoundException:
- return {'data': {'name': sparkapp_name}}, HTTPStatus.OK
+ logging.warning(f'[sparkapp] could not find sparkapp {sparkapp_name}')
+
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class SparkAppLogApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'lines': fields.Integer(required=True, help='lines is required')}, location='query')
+ def get(self, sparkapp_name: str, lines: int):
+ """Get sparkapp logs
+ ---
+ tags:
+ - sparkapp
+ description: Get sparkapp logs
+ parameters:
+ - in: path
+ name: sparkapp_name
+ schema:
+ type: string
+ - in: query
+ name: lines
+ schema:
+ type: integer
+ responses:
+ 200:
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ type: string
+ """
+ max_limit = 10000
+ if lines is None or lines > max_limit:
+ lines = max_limit
+ service = SparkAppService()
+ try:
+ return make_flask_response(data=service.get_sparkapp_log(sparkapp_name, lines))
+ except Exception as e: # pylint: disable=broad-except
+ raise InternalException(details=f'error {e}') from e
def initialize_sparkapps_apis(api: Api):
api.add_resource(SparkAppsApi, '/sparkapps')
api.add_resource(SparkAppApi, '/sparkapps/<string:sparkapp_name>')
+ api.add_resource(SparkAppLogApi, '/sparkapps/<string:sparkapp_name>/log')
+
+ schema_manager.append(SparkAppCreateParameter)
diff --git a/web_console_v2/api/fedlearner_webconsole/sparkapp/apis_test.py b/web_console_v2/api/fedlearner_webconsole/sparkapp/apis_test.py
new file mode 100644
index 000000000..27a0b332a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sparkapp/apis_test.py
@@ -0,0 +1,114 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from http import HTTPStatus
+import os
+import unittest
+import base64
+
+from unittest.mock import MagicMock, patch
+from fedlearner_webconsole.proto import sparkapp_pb2
+
+from testing.common import BaseTestCase
+from envs import Envs
+
+BASE_DIR = Envs.BASE_DIR
+
+
+class SparkAppApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._upload_path = os.path.join(BASE_DIR, 'test')
+ self._upload_path_patcher = patch('fedlearner_webconsole.sparkapp.service.UPLOAD_PATH', self._upload_path)
+ self._upload_path_patcher.start()
+
+ def tearDown(self):
+ self._upload_path_patcher.stop()
+ super().tearDown()
+
+ @patch('fedlearner_webconsole.sparkapp.service.SparkAppService.submit_sparkapp')
+ def test_submit_sparkapp(self, mock_submit_sparkapp: MagicMock):
+ mock_submit_sparkapp.return_value = sparkapp_pb2.SparkAppInfo()
+ tarball_file_path = os.path.join(BASE_DIR, 'testing/test_data/sparkapp.tar')
+ with open(tarball_file_path, 'rb') as f:
+ files_bin = f.read()
+
+ post_response = self.post_helper(
+ '/api/v2/sparkapps', {
+ 'name': 'fl-transformer-yaml',
+ 'files': base64.b64encode(files_bin).decode(),
+ 'image_url': 'dockerhub.com',
+ 'driver_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'core_limit': '4000m',
+ },
+ 'executor_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'instances': 5,
+ },
+ 'command': ['data.csv', 'data.rd'],
+ 'main_application': '${prefix}/convertor.py'
+ })
+ self.assertEqual(post_response.status_code, HTTPStatus.CREATED, post_response.json)
+ mock_submit_sparkapp.assert_called_once()
+ _, kwargs = mock_submit_sparkapp.call_args
+ self.assertEqual(kwargs['config'].name, 'fl-transformer-yaml')
+
+ mock_submit_sparkapp.reset_mock()
+ mock_submit_sparkapp.return_value = sparkapp_pb2.SparkAppInfo()
+ post_response = self.post_helper(
+ '/api/v2/sparkapps', {
+ 'name': 'fl-transformer-yaml',
+ 'image_url': 'dockerhub.com',
+ 'driver_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'core_limit': '4000m',
+ },
+ 'executor_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'instances': 5,
+ },
+ 'command': ['data.csv', 'data.rd'],
+ 'main_application': '${prefix}/convertor.py'
+ })
+ self.assertEqual(post_response.status_code, HTTPStatus.CREATED, post_response.json)
+ mock_submit_sparkapp.assert_called_once()
+ _, kwargs = mock_submit_sparkapp.call_args
+ self.assertEqual(kwargs['config'].name, 'fl-transformer-yaml')
+
+ @patch('fedlearner_webconsole.sparkapp.service.SparkAppService.get_sparkapp_info')
+ def test_get_sparkapp_info(self, mock_get_sparkapp: MagicMock):
+ mock_get_sparkapp.return_value = sparkapp_pb2.SparkAppInfo()
+
+ get_response = self.get_helper('/api/v2/sparkapps/fl-transformer-yaml')
+ self.assertEqual(get_response.status_code, HTTPStatus.OK)
+
+ mock_get_sparkapp.assert_called_once_with('fl-transformer-yaml')
+
+ @patch('fedlearner_webconsole.sparkapp.service.SparkAppService.delete_sparkapp')
+ def test_delete_sparkapp(self, mock_delete_sparkapp: MagicMock):
+ mock_delete_sparkapp.return_value = sparkapp_pb2.SparkAppInfo()
+ resp = self.delete_helper('/api/v2/sparkapps/fl-transformer-yaml')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ mock_delete_sparkapp.assert_called_once_with('fl-transformer-yaml')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/sparkapp/schema.py b/web_console_v2/api/fedlearner_webconsole/sparkapp/schema.py
index 31d91f44a..9b65b04ce 100644
--- a/web_console_v2/api/fedlearner_webconsole/sparkapp/schema.py
+++ b/web_console_v2/api/fedlearner_webconsole/sparkapp/schema.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,26 +13,28 @@
# limitations under the License.
# coding: utf-8
-from envs import Envs
+import base64
+import logging
+from typing import Optional
+from google.protobuf.json_format import ParseDict, MessageToDict
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.utils.images import generate_unified_version_image
+from fedlearner_webconsole.proto import sparkapp_pb2
-from fedlearner_webconsole.utils.mixins import from_dict_mixin, to_dict_mixin
-SPARK_POD_CONFIG_SERILIZE_FIELDS = [
- 'cores', 'memory', 'instances', 'core_limit', 'envs', 'volume_mounts'
-]
+class SparkPodConfig(object):
+ def __init__(self, spark_pod_config: sparkapp_pb2.SparkPodConfig):
+ self._spark_pod_config = spark_pod_config
-@to_dict_mixin(to_dict_fields=SPARK_POD_CONFIG_SERILIZE_FIELDS,
- ignore_none=True)
-@from_dict_mixin(from_dict_fields=SPARK_POD_CONFIG_SERILIZE_FIELDS)
-class SparkPodConfig(object):
- def __init__(self):
- self.cores = None
- self.memory = None
- self.instances = None
- self.core_limit = None
- self.volume_mounts = []
- self.envs = {}
+ @classmethod
+ def from_dict(cls, inputs: dict) -> 'SparkPodConfig':
+ spark_pod_config = sparkapp_pb2.SparkPodConfig()
+ # Default to an empty mapping so pod configs without envs still parse.
+ envs = inputs.pop('envs', {})
+ inputs['env'] = [{'name': k, 'value': v} for k, v in envs.items()]
+ spark_pod_config = ParseDict(inputs, spark_pod_config, ignore_unknown_fields=True)
+ return cls(spark_pod_config)
def build_config(self) -> dict:
""" build config for sparkoperator api
@@ -41,171 +43,172 @@ def build_config(self) -> dict:
Returns:
dict: part of sparkoperator body
"""
- config = {
- 'cores': self.cores,
- 'memory': self.memory,
- }
- if self.instances:
- config['instances'] = self.instances
- if self.core_limit:
- config['coreLimit'] = self.core_limit
- if self.envs and len(self.envs) > 0:
- config['env'] = [{
- 'name': k,
- 'value': v
- } for k, v in self.envs.items()]
- if self.volume_mounts and len(self.volume_mounts) > 0:
- config['volumeMounts'] = self.volume_mounts
-
- return config
-
-
-SPARK_APP_CONFIG_SERILIZE_FIELDS = [
- 'name', 'files', 'files_path', 'volumes', 'image_url', 'driver_config',
- 'executor_config', 'command', 'main_application', 'py_files'
-]
-SPARK_APP_CONFIG_REQUIRED_FIELDS = ['name', 'image_url']
-
-
-@to_dict_mixin(to_dict_fields=SPARK_APP_CONFIG_SERILIZE_FIELDS,
- ignore_none=True)
-@from_dict_mixin(from_dict_fields=SPARK_APP_CONFIG_SERILIZE_FIELDS,
- required_fields=SPARK_APP_CONFIG_REQUIRED_FIELDS)
+ return MessageToDict(self._spark_pod_config,
+ including_default_value_fields=False,
+ preserving_proto_field_name=False)
+
+
class SparkAppConfig(object):
- def __init__(self):
- self.name = None
- # local files should be compressed to submit spark
- self.files = None
- # if nas/hdfs has those files, such as analyzer, only need files path \
- # to submit spark
- self.files_path = None
- self.image_url = None
- self.volumes = []
- self.driver_config = SparkPodConfig()
- self.executor_config = SparkPodConfig()
- self.py_files = []
- self.command = []
- self.main_application = None
-
- def _replace_placeholder_with_real_path(self, exper: str,
- sparkapp_path: str):
+
+ def __init__(self, spark_app_config: sparkapp_pb2.SparkAppConfig):
+ self._spark_app_config = spark_app_config
+ self.files: Optional[bytes] = None
+
+ @property
+ def files_path(self):
+ return self._spark_app_config.files_path
+
+ @property
+ def name(self):
+ return self._spark_app_config.name
+
+ @classmethod
+ def from_dict(cls, inputs: dict) -> 'SparkAppConfig':
+ self = cls(sparkapp_pb2.SparkAppConfig())
+ if 'files' in inputs:
+ input_files = inputs.pop('files')
+ if isinstance(input_files, str):
+ self.files = base64.b64decode(input_files)
+ elif isinstance(input_files, (bytearray, bytes)):
+ self.files = input_files
+ else:
+ logging.debug(f'[SparkAppConfig]: ignoring files field, expected type is str or bytes, \
+ actually got {type(input_files)}')
+ self._spark_app_config = ParseDict(inputs, self._spark_app_config, ignore_unknown_fields=True)
+ return self
+
+ def _replace_placeholder_with_real_path(self, exper: str, sparkapp_path: str) -> str:
""" replace ${prefix} with real path
Args:
+ exper (str): sparkapp expression in body
sparkapp_path (str): sparkapp real path
+
+ Returns:
+ return the real path without ${prefix} expression
"""
return exper.replace('${prefix}', sparkapp_path)
def build_config(self, sparkapp_path: str) -> dict:
- return {
- 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
- 'kind': 'SparkApplication',
- 'metadata': {
- 'name': self.name,
- 'namespace': Envs.K8S_NAMESPACE,
- 'labels': Envs.K8S_LABEL_INFO
- },
- 'spec': {
- 'type':
- 'Python',
- 'pythonVersion':
- '3',
- 'mode':
- 'cluster',
- 'image':
- self.image_url,
- 'imagePullPolicy':
- 'Always',
- 'volumes':
- self.volumes,
- 'mainApplicationFile':
- self._replace_placeholder_with_real_path(
- self.main_application, sparkapp_path),
- 'arguments': [
- self._replace_placeholder_with_real_path(c, sparkapp_path)
- for c in self.command
- ],
- 'deps': {
- 'pyFiles': [
- self._replace_placeholder_with_real_path(
- f, sparkapp_path) for f in self.py_files
- ]
- },
- 'sparkConf': {
- 'spark.shuffle.service.enabled': 'false',
- },
- 'sparkVersion':
- '3.0.0',
- 'restartPolicy': {
- 'type': 'Never',
- },
- 'dynamicAllocation': {
- 'enabled': False,
- },
- 'driver': {
- **self.driver_config.build_config(),
- 'labels': {
- 'version': '3.0.0'
+ # sparkapp configuration limitation: initial executor instances must be within [5, 30];
+ # any excess is moved to dynamicAllocation.maxExecutors.
+ if self._spark_app_config.executor_config.instances > 30:
+ self._spark_app_config.dynamic_allocation.max_executors = self._spark_app_config.executor_config.instances
+ self._spark_app_config.executor_config.instances = 30
+
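+ # Fill in cluster-level defaults (namespace, labels, image, volumes, volume mounts, envs) from system variables.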
+ with db.session_scope() as session:
+ setting_service = SettingService(session)
+ sys_variables = setting_service.get_system_variables_dict()
+ namespace = setting_service.get_namespace()
+ labels = sys_variables.get('labels')
+ if not self._spark_app_config.image_url:
+ self._spark_app_config.image_url = generate_unified_version_image(sys_variables.get('spark_image'))
+ for volume in sys_variables.get('volumes_list', []):
+ self._spark_app_config.volumes.append(
+ ParseDict(volume, sparkapp_pb2.Volume(), ignore_unknown_fields=True))
+ for volume_mount in sys_variables.get('volume_mounts_list', []):
+ volume_mount_pb = ParseDict(volume_mount, sparkapp_pb2.VolumeMount(), ignore_unknown_fields=True)
+ self._spark_app_config.executor_config.volume_mounts.append(volume_mount_pb)
+ self._spark_app_config.driver_config.volume_mounts.append(volume_mount_pb)
+ envs_list = []
+ for env in sys_variables.get('envs_list', []):
+ envs_list.append(ParseDict(env, sparkapp_pb2.Env()))
+ self._spark_app_config.driver_config.env.extend(envs_list)
+ self._spark_app_config.executor_config.env.extend(envs_list)
+ base_config = {
+ 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
+ 'kind': 'SparkApplication',
+ 'metadata': {
+ 'name': self._spark_app_config.name,
+ 'namespace': namespace,
+ 'labels': labels,
+ # Aimed at resource queue management.
+ # It should also work fine where there is no resource queue service.
+ 'annotations': {
+ 'queue': 'fedlearner-spark',
+ 'schedulerName': 'batch',
},
- 'serviceAccount': 'spark',
},
- 'executor': {
- **self.executor_config.build_config(),
- 'labels': {
- 'version': '3.0.0'
+ 'spec': {
+ 'type':
+ 'Python',
+ 'timeToLiveSeconds':
+ 1800,
+ 'pythonVersion':
+ '3',
+ 'mode':
+ 'cluster',
+ 'image':
+ self._spark_app_config.image_url,
+ 'imagePullPolicy':
+ 'IfNotPresent',
+ 'volumes': [
+ MessageToDict(volume, including_default_value_fields=False, preserving_proto_field_name=False)
+ for volume in self._spark_app_config.volumes
+ ],
+ 'arguments': [
+ self._replace_placeholder_with_real_path(c, sparkapp_path)
+ for c in self._spark_app_config.command
+ ],
+ 'sparkConf': {
+ 'spark.shuffle.service.enabled': 'false',
},
+ 'sparkVersion':
+ '3.0.0',
+ 'restartPolicy': {
+ 'type': 'Never',
+ },
+ 'dynamicAllocation':
+ MessageToDict(self._spark_app_config.dynamic_allocation,
+ including_default_value_fields=False,
+ preserving_proto_field_name=False),
+ 'driver': {
+ **SparkPodConfig(self._spark_app_config.driver_config).build_config(),
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ 'serviceAccount': 'spark',
+ },
+ 'executor': {
+ **SparkPodConfig(self._spark_app_config.executor_config).build_config(),
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ }
}
}
- }
-
-
-SPARK_APP_INFO_SERILIZE_FIELDS = [
- 'name', 'namespace', 'command', 'driver', 'executor', 'image_url',
- 'main_application', 'spark_version', 'type', 'state'
-]
-
-
-@to_dict_mixin(to_dict_fields=SPARK_APP_INFO_SERILIZE_FIELDS, ignore_none=True)
-@from_dict_mixin(from_dict_fields=SPARK_APP_INFO_SERILIZE_FIELDS)
-class SparkAppInfo(object):
- @classmethod
- def from_k8s_resp(cls, resp):
- sparkapp_info = cls()
- if 'name' in resp['metadata']:
- sparkapp_info.name = resp['metadata']['name']
- elif 'name' in resp['details']:
- sparkapp_info.name = resp['details']['name']
- sparkapp_info.namespace = resp['metadata'].get('namespace', None)
- sparkapp_info.state = None
- if 'status' in resp:
- if isinstance(resp['status'], str):
- sparkapp_info.state = None
- elif isinstance(resp['status'], dict):
- sparkapp_info.state = resp.get('status',
- {}).get('applicationState',
- {}).get('state', None)
- sparkapp_info.command = resp.get('spec', {}).get('arguments', None)
- sparkapp_info.executor = SparkPodConfig.from_dict(
- resp.get('spec', {}).get('executor', {}))
- sparkapp_info.driver = SparkPodConfig.from_dict(
- resp.get('spec', {}).get('driver', {}))
- sparkapp_info.image_url = resp.get('spec', {}).get('image', None)
- sparkapp_info.main_application = resp.get('spec', {}).get(
- 'mainApplicationFile', None)
- sparkapp_info.spark_version = resp.get('spec',
- {}).get('sparkVersion', None)
- sparkapp_info.type = resp.get('spec', {}).get('type', None)
-
- return sparkapp_info
-
- def __init__(self):
- self.name = None
- self.state = None
- self.namespace = None
- self.command = None
- self.driver = SparkPodConfig()
- self.executor = SparkPodConfig()
- self.image_url = None
- self.main_application = None
- self.spark_version = None
- self.type = None
+ if self._spark_app_config.main_application:
+ base_config['spec']['mainApplicationFile'] = self._replace_placeholder_with_real_path(
+ self._spark_app_config.main_application, sparkapp_path)
+ if self._spark_app_config.py_files:
+ base_config['spec']['deps'] = {
+ 'pyFiles': [
+ self._replace_placeholder_with_real_path(f, sparkapp_path)
+ for f in self._spark_app_config.py_files
+ ]
+ }
+ return base_config
+
+
+def from_k8s_resp(resp: dict) -> sparkapp_pb2.SparkAppInfo:
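+ """Converts a k8s SparkApplication response (create/get body or deletion Status) into a SparkAppInfo proto."""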
+ sparkapp_info = sparkapp_pb2.SparkAppInfo()
+ if 'name' in resp['metadata']:
+ sparkapp_info.name = resp['metadata']['name']
+ elif 'name' in resp['details']:
+ sparkapp_info.name = resp['details']['name']
+ sparkapp_info.namespace = resp['metadata'].get('namespace', '')
+ if 'status' in resp:
+ if isinstance(resp['status'], str):
+ sparkapp_info.state = resp['status']
+ elif isinstance(resp['status'], dict):
+ sparkapp_info.state = resp.get('status', {}).get('applicationState', {}).get('state', '')
+ sparkapp_info.command.extend(resp.get('spec', {}).get('arguments', []))
+ sparkapp_info.executor.MergeFrom(
+ ParseDict(resp.get('spec', {}).get('executor', {}), sparkapp_pb2.SparkPodConfig(), ignore_unknown_fields=True))
+ sparkapp_info.driver.MergeFrom(
+ ParseDict(resp.get('spec', {}).get('driver', {}), sparkapp_pb2.SparkPodConfig(), ignore_unknown_fields=True))
+ sparkapp_info.image_url = resp.get('spec', {}).get('image', '')
+ sparkapp_info.main_application = resp.get('spec', {}).get('mainApplicationFile', '')
+ sparkapp_info.spark_version = resp.get('spec', {}).get('sparkVersion', '3')
+ sparkapp_info.type = resp.get('spec', {}).get('type', '')
+
+ return sparkapp_info
diff --git a/web_console_v2/api/fedlearner_webconsole/sparkapp/schema_test.py b/web_console_v2/api/fedlearner_webconsole/sparkapp/schema_test.py
new file mode 100644
index 000000000..3aef5a4d1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sparkapp/schema_test.py
@@ -0,0 +1,202 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.proto import sparkapp_pb2
+from fedlearner_webconsole.sparkapp.schema import SparkAppConfig, SparkPodConfig, from_k8s_resp
+
+
+class SparkAppSchemaTest(NoWebServerTestCase):
+
+ def test_spark_pod_config(self):
+ inputs = {'cores': 1, 'memory': '200m', 'core_limit': '4000m', 'envs': {'HELLO': '1'}}
+ spark_pod_config = SparkPodConfig.from_dict(inputs)
+ config = spark_pod_config.build_config()
+ self.assertDictEqual(config, {
+ 'cores': 1,
+ 'memory': '200m',
+ 'coreLimit': '4000m',
+ 'env': [{
+ 'name': 'HELLO',
+ 'value': '1'
+ }]
+ })
+
+ def test_sparkapp_config(self):
+ inputs = {
+ 'name': 'test',
+ 'files': bytes(100),
+ 'image_url': 'dockerhub.com',
+ 'driver_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'core_limit': '4000m',
+ 'envs': {
+ 'HELLO': '1'
+ },
+ 'volumeMounts': [{
+ 'mountPath': '/data',
+ 'name': 'data'
+ }]
+ },
+ 'executor_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'instances': 64,
+ 'envs': {
+ 'HELLO': '1'
+ },
+ 'volumeMounts': [{
+ 'mountPath': '/data',
+ 'name': 'data',
+ 'unknown': '2',
+ }]
+ },
+ 'command': ['hhh', 'another'],
+ 'main_application': '${prefix}/main.py',
+ 'volumes': [{
+ 'name': 'data',
+ 'hostPath': {
+ 'path': '/data',
+ },
+ 'unknown': '1',
+ }]
+ }
+ sparkapp_config = SparkAppConfig.from_dict(inputs)
+ config = sparkapp_config.build_config('./test')
+ self.assertEqual(config['spec']['mainApplicationFile'], './test/main.py')
+ self.assertNotIn('instances', config['spec']['driver'])
+ self.assertEqual([{'name': 'data', 'hostPath': {'path': '/data',}}], config['spec']['volumes'])
+ self.assertEqual(config['spec']['executor']['instances'], 30)
+ self.assertEqual(config['spec']['dynamicAllocation']['maxExecutors'], 64)
+
+ def test_sparkapp_dynamic_allocation(self):
+ inputs = {
+ 'name': 'test',
+ 'image_url': 'test.com/test/hhh:1',
+ 'dynamic_allocation': {
+ 'enabled': True,
+ 'initialExecutors': 2,
+ 'minExecutors': 2,
+ 'maxExecutors': 10
+ }
+ }
+ sparkapp_config: SparkAppConfig = SparkAppConfig.from_dict(inputs)
+ config = sparkapp_config.build_config('./test')
+ print(config['spec']['dynamicAllocation'])
+ self.assertEqual(len(config['spec']['dynamicAllocation']), 4)
+ self.assertTrue(config['spec']['dynamicAllocation']['enabled'])
+
+ def test_sparkapp_info(self):
+ resp = {
+ 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
+ 'kind': 'SparkApplication',
+ 'metadata': {
+ 'creationTimestamp':
+ '2021-05-18T08:59:16Z',
+ 'generation':
+ 1,
+ 'name':
+ 'fl-transformer-yaml',
+ 'namespace':
+ 'fedlearner',
+ 'resourceVersion':
+ '432649442',
+ 'selfLink':
+ '/apis/sparkoperator.k8s.io/v1beta2/namespaces/fedlearner/sparkapplications/fl-transformer-yaml',
+ 'uid':
+ '52d66d27-b7b7-11eb-b9df-b8599fdb0aac'
+ },
+ 'spec': {
+ 'arguments': ['data.csv', 'data_tfrecords/'],
+ 'driver': {
+ 'coreLimit': '4000m',
+ 'cores': 1,
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ 'memory': '512m',
+ 'serviceAccount': 'spark',
+ 'volumeMounts': [{
+ 'mountPath': '/data',
+ 'name': 'data',
+ 'readOnly': True
+ }],
+ },
+ 'dynamicAllocation': {
+ 'enabled': False
+ },
+ 'executor': {
+ 'cores': 1,
+ 'instances': 1,
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ 'memory': '512m',
+ 'volumeMounts': [{
+ 'mountPath': '/data',
+ 'name': 'data',
+ 'readOnly': True
+ }],
+ },
+ 'image': 'dockerhub.com',
+ 'imagePullPolicy': 'Always',
+ 'mainApplicationFile': 'transformer.py',
+ 'mode': 'cluster',
+ 'pythonVersion': '3',
+ 'restartPolicy': {
+ 'type': 'Never'
+ },
+ 'sparkConf': {
+ 'spark.shuffle.service.enabled': 'false'
+ },
+ 'sparkVersion': '3.0.0',
+ 'type': 'Python',
+ },
+ 'status': {
+ 'applicationState': {
+ 'state': 'COMPLETED'
+ },
+ 'driverInfo': {
+ 'podName': 'fl-transformer-yaml-driver',
+ 'webUIAddress': '11.249.131.12:4040',
+ 'webUIPort': 4040,
+ 'webUIServiceName': 'fl-transformer-yaml-ui-svc'
+ },
+ 'executionAttempts': 1,
+ 'executorState': {
+ 'fl-transformer-yaml-bdc15979a314310b-exec-1': 'PENDING',
+ 'fl-transformer-yaml-bdc15979a314310b-exec-2': 'COMPLETED'
+ },
+ 'lastSubmissionAttemptTime': '2021-05-18T10:31:13Z',
+ 'sparkApplicationId': 'spark-a380bfd520164d828a334bcb3a6404f9',
+ 'submissionAttempts': 1,
+ 'submissionID': '5bc7e2e7-cc0f-420c-8bc7-138b651a1dde',
+ 'terminationTime': '2021-05-18T10:32:08Z'
+ }
+ }
+
+ sparkapp_info = from_k8s_resp(resp)
+ self.assertEqual(sparkapp_info.namespace, 'fedlearner')
+ self.assertEqual(sparkapp_info.name, 'fl-transformer-yaml')
+ self.assertEqual(sparkapp_info.driver.volume_mounts[0],
+ sparkapp_pb2.VolumeMount(mount_path='/data', name='data', read_only=True))
+ self.assertEqual(sparkapp_info.executor.volume_mounts[0],
+ sparkapp_pb2.VolumeMount(mount_path='/data', name='data', read_only=True))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/sparkapp/service.py b/web_console_v2/api/fedlearner_webconsole/sparkapp/service.py
index 21e612777..68a14f4be 100644
--- a/web_console_v2/api/fedlearner_webconsole/sparkapp/service.py
+++ b/web_console_v2/api/fedlearner_webconsole/sparkapp/service.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,29 +20,23 @@
from typing import Tuple
from envs import Envs
+from fedlearner_webconsole.proto import sparkapp_pb2
from fedlearner_webconsole.utils.file_manager import FileManager
-from fedlearner_webconsole.sparkapp.schema import SparkAppConfig, SparkAppInfo
-from fedlearner_webconsole.utils.k8s_client import k8s_client
-from fedlearner_webconsole.utils.tars import TarCli
+from fedlearner_webconsole.sparkapp.schema import SparkAppConfig, from_k8s_resp
+from fedlearner_webconsole.k8s.k8s_client import (SPARKOPERATOR_CUSTOM_GROUP, SPARKOPERATOR_CUSTOM_VERSION, CrdKind,
+ k8s_client, SPARKOPERATOR_NAMESPACE)
+from fedlearner_webconsole.utils.file_operator import FileOperator
UPLOAD_PATH = Envs.STORAGE_ROOT
class SparkAppService(object):
+
def __init__(self) -> None:
self._base_dir = os.path.join(UPLOAD_PATH, 'sparkapp')
- self._file_client = FileManager()
-
- self._file_client.mkdir(self._base_dir)
-
- def _clear_and_make_an_empty_dir(self, dir_name: str):
- try:
- self._file_client.remove(dir_name)
- except Exception as err: # pylint: disable=broad-except
- logging.error('failed to remove %s with exception %s', dir_name,
- err)
- finally:
- self._file_client.mkdir(dir_name)
+ self._file_manager = FileManager()
+ self._file_operator = FileOperator()
+ self._file_manager.mkdir(self._base_dir)
def _get_sparkapp_upload_path(self, name: str) -> Tuple[bool, str]:
"""get upload path for specific sparkapp
@@ -57,50 +51,10 @@ def _get_sparkapp_upload_path(self, name: str) -> Tuple[bool, str]:
"""
sparkapp_path = os.path.join(self._base_dir, name)
- existable = False
- try:
- self._file_client.ls(sparkapp_path)
- existable = True
- except ValueError:
- existable = False
-
+ existable = self._file_manager.isdir(sparkapp_path)
return existable, sparkapp_path
- def _copy_files_to_target_filesystem(self, source_filesystem_path: str,
- target_filesystem_path: str) -> bool:
- """ copy files to remote filesystem
- - untar if file is tared
- - copy files to remote filesystem
-
- Args:
- source_filesystem_path (str): local filesystem
- target_filesystem_path (str): remote filesystem
-
- Returns:
- bool: whether success
- """
- temp_path = source_filesystem_path
- if source_filesystem_path.find('.tar') != -1:
- temp_path = os.path.abspath(
- os.path.join(source_filesystem_path, '../tmp'))
- os.makedirs(temp_path)
- TarCli.untar_file(source_filesystem_path, temp_path)
-
- for root, dirs, files in os.walk(temp_path):
- relative_path = os.path.relpath(root, temp_path)
- for f in files:
- file_path = os.path.join(root, f)
- remote_file_path = os.path.join(target_filesystem_path,
- relative_path, f)
- self._file_client.copy(file_path, remote_file_path)
- for d in dirs:
- remote_dir_path = os.path.join(target_filesystem_path,
- relative_path, d)
- self._file_client.mkdir(remote_dir_path)
-
- return True
-
- def submit_sparkapp(self, config: SparkAppConfig) -> SparkAppInfo:
+ def submit_sparkapp(self, config: SparkAppConfig) -> sparkapp_pb2.SparkAppInfo:
"""submit sparkapp
Args:
@@ -112,25 +66,27 @@ def submit_sparkapp(self, config: SparkAppConfig) -> SparkAppInfo:
Returns:
SparkAppInfo: resp of sparkapp
"""
+ logging.info(f'submit sparkapp with config:{config}')
sparkapp_path = config.files_path
- if config.files_path is None:
+ if not config.files_path:
_, sparkapp_path = self._get_sparkapp_upload_path(config.name)
- self._clear_and_make_an_empty_dir(sparkapp_path)
+ self._file_operator.clear_and_make_an_empty_dir(sparkapp_path)
- with tempfile.TemporaryDirectory() as temp_dir:
- tar_path = os.path.join(temp_dir, 'files.tar')
- with open(tar_path, 'wb') as fwrite:
- fwrite.write(config.files)
- self._copy_files_to_target_filesystem(
- source_filesystem_path=tar_path,
- target_filesystem_path=sparkapp_path)
+ # In case there are no files to upload
+ if config.files is not None:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ tar_path = os.path.join(temp_dir, 'files.tar')
+ with open(tar_path, 'wb') as fwrite:
+ fwrite.write(config.files)
+ self._file_operator.copy_to(tar_path, sparkapp_path, extract=True)
config_dict = config.build_config(sparkapp_path)
logging.info(f'submit sparkapp, config: {config_dict}')
- resp = k8s_client.create_sparkapplication(config_dict)
- return SparkAppInfo.from_k8s_resp(resp)
+ resp = k8s_client.create_app(config_dict, SPARKOPERATOR_CUSTOM_GROUP, SPARKOPERATOR_CUSTOM_VERSION,
+ CrdKind.SPARK_APPLICATION.value)
+ return from_k8s_resp(resp)
- def get_sparkapp_info(self, name: str) -> SparkAppInfo:
+ def get_sparkapp_info(self, name: str) -> sparkapp_pb2.SparkAppInfo:
""" get sparkapp info
Args:
@@ -143,9 +99,21 @@ def get_sparkapp_info(self, name: str) -> SparkAppInfo:
SparkAppInfo: resp of sparkapp
"""
resp = k8s_client.get_sparkapplication(name)
- return SparkAppInfo.from_k8s_resp(resp)
+ return from_k8s_resp(resp)
+
+ def get_sparkapp_log(self, name: str, lines: int) -> str:
+ """ get sparkapp log
+
+ Args:
+ name (str): sparkapp name
+ lines (int): max lines of log
+
+ Returns:
+ str: sparkapp log
+ """
+ return k8s_client.get_pod_log(f'{name}-driver', SPARKOPERATOR_NAMESPACE, tail_lines=lines)
- def delete_sparkapp(self, name: str) -> SparkAppInfo:
+ def delete_sparkapp(self, name: str) -> sparkapp_pb2.SparkAppInfo:
"""delete sparkapp
- delete sparkapp. If failed, raise exception
- delete the tmp filesystem
@@ -162,9 +130,9 @@ def delete_sparkapp(self, name: str) -> SparkAppInfo:
"""
existable, sparkapp_path = self._get_sparkapp_upload_path(name)
if existable:
- self._file_client.remove(sparkapp_path)
+ self._file_manager.remove(sparkapp_path)
resp = k8s_client.delete_sparkapplication(name)
- sparkapp_info = SparkAppInfo.from_k8s_resp(resp)
+ sparkapp_info = from_k8s_resp(resp)
return sparkapp_info
diff --git a/web_console_v2/api/fedlearner_webconsole/sparkapp/service_test.py b/web_console_v2/api/fedlearner_webconsole/sparkapp/service_test.py
new file mode 100644
index 000000000..a485547cd
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sparkapp/service_test.py
@@ -0,0 +1,278 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import os
+import shutil
+import unittest
+
+from unittest.mock import MagicMock, patch
+
+from envs import Envs
+from fedlearner_webconsole.sparkapp.schema import SparkAppConfig
+from fedlearner_webconsole.sparkapp.service import SparkAppService
+from testing.common import NoWebServerTestCase
+
+BASE_DIR = Envs.BASE_DIR
+
+
+class SparkAppServiceTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ self._upload_path = os.path.join(BASE_DIR, 'test-spark')
+ os.makedirs(self._upload_path)
+ self._patch_upload_path = patch('fedlearner_webconsole.sparkapp.service.UPLOAD_PATH', self._upload_path)
+ self._patch_upload_path.start()
+ self._sparkapp_service = SparkAppService()
+
+ def tearDown(self) -> None:
+ self._patch_upload_path.stop()
+ shutil.rmtree(self._upload_path)
+ return super().tearDown()
+
+ def _get_tar_file_path(self) -> str:
+ return os.path.join(BASE_DIR, 'testing/test_data/sparkapp.tar')
+
+ def test_get_sparkapp_upload_path(self):
+ existable, sparkapp_path = self._sparkapp_service._get_sparkapp_upload_path('test') # pylint: disable=protected-access
+ self.assertFalse(existable)
+
+ os.makedirs(sparkapp_path)
+ existable, _ = self._sparkapp_service._get_sparkapp_upload_path('test') # pylint: disable=protected-access
+ self.assertTrue(existable)
+
+ @patch('fedlearner_webconsole.k8s.k8s_client.k8s_client.create_app')
+ def test_submit_sparkapp(self, mock_create_app: MagicMock):
+ mock_create_app.return_value = {
+ 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
+ 'kind': 'SparkApplication',
+ 'metadata': {
+ 'creationTimestamp':
+ '2021-05-18T08:59:16Z',
+ 'generation':
+ 1,
+ 'name':
+ 'fl-transformer-yaml',
+ 'namespace':
+ 'fedlearner',
+ 'resourceVersion':
+ '432649442',
+ 'selfLink':
+ '/apis/sparkoperator.k8s.io/v1beta2/namespaces/fedlearner/sparkapplications/fl-transformer-yaml',
+ 'uid':
+ '52d66d27-b7b7-11eb-b9df-b8599fdb0aac'
+ },
+ 'spec': {
+ 'arguments': ['data.csv', 'data_tfrecords/'],
+ 'driver': {
+ 'coreLimit': '4000m',
+ 'cores': 1,
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ 'memory': '512m',
+ 'serviceAccount': 'spark',
+ },
+ 'dynamicAllocation': {
+ 'enabled': False
+ },
+ 'executor': {
+ 'cores': 1,
+ 'instances': 1,
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ 'memory': '512m',
+ },
+ 'image': 'dockerhub.com',
+ 'imagePullPolicy': 'Always',
+ 'mainApplicationFile': 'transformer.py',
+ 'mode': 'cluster',
+ 'pythonVersion': '3',
+ 'restartPolicy': {
+ 'type': 'Never'
+ },
+ 'sparkConf': {
+ 'spark.shuffle.service.enabled': 'false'
+ },
+ 'sparkVersion': '3.0.0',
+ 'type': 'Python',
+ },
+ 'status': {
+ 'applicationState': {
+ 'state': 'COMPLETED'
+ },
+ 'driverInfo': {
+ 'podName': 'fl-transformer-yaml-driver',
+ 'webUIAddress': '11.249.131.12:4040',
+ 'webUIPort': 4040,
+ 'webUIServiceName': 'fl-transformer-yaml-ui-svc'
+ },
+ 'executionAttempts': 1,
+ 'executorState': {
+ 'fl-transformer-yaml-bdc15979a314310b-exec-1': 'PENDING',
+ 'fl-transformer-yaml-bdc15979a314310b-exec-2': 'COMPLETED'
+ },
+ 'lastSubmissionAttemptTime': '2021-05-18T10:31:13Z',
+ 'sparkApplicationId': 'spark-a380bfd520164d828a334bcb3a6404f9',
+ 'submissionAttempts': 1,
+ 'submissionID': '5bc7e2e7-cc0f-420c-8bc7-138b651a1dde',
+ 'terminationTime': '2021-05-18T10:32:08Z'
+ }
+ }
+
+ tarball_file_path = os.path.join(BASE_DIR, 'testing/test_data/sparkapp.tar')
+ with open(tarball_file_path, 'rb') as f:
+ files_bin = f.read()
+
+ inputs = {
+ 'name': 'fl-transformer-yaml',
+ 'files': files_bin,
+ 'image_url': 'dockerhub.com',
+ 'driver_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'coreLimit': '4000m',
+ },
+ 'executor_config': {
+ 'cores': 1,
+ 'memory': '200m',
+ 'instances': 5,
+ },
+ 'command': ['data.csv', 'data.rd'],
+ 'main_application': '${prefix}/convertor.py'
+ }
+ config = SparkAppConfig.from_dict(inputs)
+ resp = self._sparkapp_service.submit_sparkapp(config)
+
+ self.assertTrue(
+ os.path.exists(os.path.join(self._upload_path, 'sparkapp', 'fl-transformer-yaml', 'convertor.py')))
+ mock_create_app.assert_called_once()
+ args = mock_create_app.call_args[0]
+ self.assertEqual(len(args), 4)
+ self.assertEqual(args[1:], ('sparkoperator.k8s.io', 'v1beta2', 'sparkapplications'))
+ self.assertTrue(resp.namespace, 'fedlearner')
+
+ @patch('fedlearner_webconsole.k8s.k8s_client.k8s_client.get_sparkapplication')
+ def test_get_sparkapp_info(self, mock_get_sparkapp: MagicMock):
+ mock_get_sparkapp.return_value = {
+ 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
+ 'kind': 'SparkApplication',
+ 'metadata': {
+ 'creationTimestamp':
+ '2021-05-18T08:59:16Z',
+ 'generation':
+ 1,
+ 'name':
+ 'fl-transformer-yaml',
+ 'namespace':
+ 'fedlearner',
+ 'resourceVersion':
+ '432649442',
+ 'selfLink':
+ '/apis/sparkoperator.k8s.io/v1beta2/namespaces/fedlearner/sparkapplications/fl-transformer-yaml',
+ 'uid':
+ '52d66d27-b7b7-11eb-b9df-b8599fdb0aac'
+ },
+ 'spec': {
+ 'arguments': ['data.csv', 'data_tfrecords/'],
+ 'driver': {
+ 'coreLimit': '4000m',
+ 'cores': 1,
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ 'memory': '512m',
+ 'serviceAccount': 'spark',
+ },
+ 'dynamicAllocation': {
+ 'enabled': False
+ },
+ 'executor': {
+ 'cores': 1,
+ 'instances': 1,
+ 'labels': {
+ 'version': '3.0.0'
+ },
+ 'memory': '512m',
+ },
+ 'image': 'dockerhub.com',
+ 'imagePullPolicy': 'Always',
+ 'mainApplicationFile': 'transformer.py',
+ 'mode': 'cluster',
+ 'pythonVersion': '3',
+ 'restartPolicy': {
+ 'type': 'Never'
+ },
+ 'sparkConf': {
+ 'spark.shuffle.service.enabled': 'false'
+ },
+ 'sparkVersion': '3.0.0',
+ 'type': 'Python',
+ },
+ 'status': {
+ 'applicationState': {
+ 'state': 'COMPLETED'
+ },
+ 'driverInfo': {
+ 'podName': 'fl-transformer-yaml-driver',
+ 'webUIAddress': '11.249.131.12:4040',
+ 'webUIPort': 4040,
+ 'webUIServiceName': 'fl-transformer-yaml-ui-svc'
+ },
+ 'executionAttempts': 1,
+ 'executorState': {
+ 'fl-transformer-yaml-bdc15979a314310b-exec-1': 'PENDING',
+ 'fl-transformer-yaml-bdc15979a314310b-exec-2': 'COMPLETED'
+ },
+ 'lastSubmissionAttemptTime': '2021-05-18T10:31:13Z',
+ 'sparkApplicationId': 'spark-a380bfd520164d828a334bcb3a6404f9',
+ 'submissionAttempts': 1,
+ 'submissionID': '5bc7e2e7-cc0f-420c-8bc7-138b651a1dde',
+ 'terminationTime': '2021-05-18T10:32:08Z'
+ }
+ }
+
+ resp = self._sparkapp_service.get_sparkapp_info('fl-transformer-yaml')
+
+ mock_get_sparkapp.assert_called_once()
+ self.assertTrue(resp.namespace, 'fedlearner')
+
+ @patch('fedlearner_webconsole.sparkapp.service.SparkAppService._get_sparkapp_upload_path')
+ @patch('fedlearner_webconsole.utils.file_manager.FileManager.remove')
+ @patch('fedlearner_webconsole.k8s.k8s_client.k8s_client.delete_sparkapplication')
+ def test_delete_sparkapp(self, mock_delete_sparkapp: MagicMock, mock_file_mananger_remove: MagicMock,
+ mock_upload_path: MagicMock):
+ mock_delete_sparkapp.return_value = {
+ 'kind': 'Status',
+ 'apiVersion': 'v1',
+ 'metadata': {},
+ 'status': 'Success',
+ 'details': {
+ 'name': 'fl-transformer-yaml',
+ 'group': 'sparkoperator.k8s.io',
+ 'kind': 'sparkapplications',
+ 'uid': '52d66d27-b7b7-11eb-b9df-b8599fdb0aac'
+ }
+ }
+ mock_upload_path.return_value = (True, 'test')
+ resp = self._sparkapp_service.delete_sparkapp(name='fl-transformer-yaml')
+ mock_delete_sparkapp.assert_called_once()
+ mock_file_mananger_remove.assert_called_once()
+ self.assertTrue(resp.name, 'fl-transformer-yaml')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/swagger/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/swagger/BUILD.bazel
new file mode 100644
index 000000000..38a8e3d24
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/swagger/BUILD.bazel
@@ -0,0 +1,12 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "@common_marshmallow//:pkg",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/swagger/__init__.py b/web_console_v2/api/fedlearner_webconsole/swagger/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/swagger/models.py b/web_console_v2/api/fedlearner_webconsole/swagger/models.py
new file mode 100644
index 000000000..b2eddbe52
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/swagger/models.py
@@ -0,0 +1,35 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+
+from marshmallow import Schema
+
+
+class _SchemaManager(object):
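+ """Registry of marshmallow schemas used when generating the Swagger definitions."""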
+
+ def __init__(self):
+ self._schemas = []
+
+ def append(self, schema: Schema):
+ if schema in self._schemas:
+ return
+ self._schemas.append(schema)
+
+ def get_schemas(self) -> List[Schema]:
+ return self._schemas
+
+
+schema_manager = _SchemaManager()
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/BUILD.bazel
new file mode 100644
index 000000000..68001c2c8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/BUILD.bazel
@@ -0,0 +1,5 @@
+filegroup(
+ name = "sys_preset_templates",
+ srcs = glob(["**/*.json"]),
+ visibility = ["//visibility:public"],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-fed-left.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-fed-left.json
new file mode 100644
index 000000000..0165805ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-fed-left.json
@@ -0,0 +1,1994 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "e2e-test",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": true,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job-streaming",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": \"Streaming\"\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": str(workflow.variables.input_base_dir)\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(workflow.variables.file_wildcard)\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(1024)\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": 
str(workflow.variables.input_data_format)\n },\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [],
+ "easy_mode": true,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job-psi",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": \"PSI\"\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": str(workflow.variables.input_base_dir)\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(workflow.variables.file_wildcard)\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(1024)\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": str(workflow.variables.input_data_format)\n 
},\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "raw-data-job-streaming"
+ },
+ {
+ "source": "raw-data-job-psi"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "DATA_JOIN",
+ "name": "data-join-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(workflow.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(workflow.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": \"--batch_mode\"\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(0)\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(999999999999)\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-streaming'].name)\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n list(system.variables.volume_mounts_list)\n ,\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n list(system.variables.volumes_list)\n\n }\n },\n \"pair\": true,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": 
\"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-streaming'].name)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"MIN_MATCHING_WINDOW\",\n \"value\": str(1024)\n },\n {\n \"name\": \"MAX_MATCHING_WINDOW\",\n \"value\": str(4096)\n },\n {\n \"name\": \"RAW_DATA_ITER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n list(system.variables.volume_mounts_list)\n ,\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n list(system.variables.volumes_list)\n\n }\n },\n \"pair\": true,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "raw-data-job-streaming"
+ },
+ {
+ "source": "raw-data-job-psi"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "PSI_DATA_JOIN",
+ "name": "psi-data-join-job",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_pem",
+ "tag": "",
+ "typed_value": "-----BEGIN RSA PUBLIC KEY-----\nMIGJAoGBAMZYpBzYDnROmrqC8LhDXhgW13E/JuTUHkHKsGwPScnp5TAueqo53ayu\nYzSlLrI+yQp206Kb/C+w/VdWJcLLAjAUBGqfZvCnsmpfOMt+s3JrNH24RCg282m/\nnIdpoVqb7SEDFlJPq3s0g/oZ5v0c74Yy5J/DuuaWcuU7URuYRbbnAgMBAAE=\n-----END RSA PUBLIC KEY-----",
+ "value": "-----BEGIN RSA PUBLIC KEY-----\nMIGJAoGBAMZYpBzYDnROmrqC8LhDXhgW13E/JuTUHkHKsGwPScnp5TAueqo53ayu\nYzSlLrI+yQp206Kb/C+w/VdWJcLLAjAUBGqfZvCnsmpfOMt+s3JrNH24RCg282m/\nnIdpoVqb7SEDFlJPq3s0g/oZ5v0c74Yy5J/DuuaWcuU7URuYRbbnAgMBAAE=\n-----END RSA PUBLIC KEY-----",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"TextArea\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_private_key_path",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_path",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(workflow.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(workflow.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(0)\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(999999999999)\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-psi'].name)\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": 
\"--batch_mode\"\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-psi'].name)\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": str(self.variables.rsa_key_pem)\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": str(self.variables.rsa_key_path)\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": str(self.variables.rsa_private_key_path)\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": \"data.aml.fl\"\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(None)\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(None)\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(False)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": int(int(workflow.variables.num_partitions))\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "data-join-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": true,
+ "job_type": "NN_MODEL_TRANINING",
+ "name": "nn-train",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "epoch_num",
+ "tag": "",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "code_tar",
+ "tag": "",
+ "typed_value": {
+ "follower/config.py": "leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\nleader_label_name = ['label']\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\n",
+ "follower/main.py": "# Copyright 2020 The FedLearner Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# coding: utf-8\n# pylint: disable=no-else-return, inconsistent-return-statements\n\nimport os\nimport logging\nimport numpy as np\nimport tensorflow.compat.v1 as tf\nimport fedlearner.trainer as flt\nfrom config import *\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\n\nROLE = 'follower'\n\nparser = flt.trainer_worker.create_argument_parser()\nparser.add_argument('--batch-size',\n type=int,\n default=100,\n help='Training batch size.')\nargs = parser.parse_args()\n\n\ndef input_fn(bridge, trainer_master):\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\n trainer_master).make_dataset()\n\n def parse_fn(example):\n feature_map = dict()\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\n for name in follower_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n features = tf.parse_example(example, features=feature_map)\n return features, dict(y=tf.constant(0))\n\n dataset = dataset.map(map_func=parse_fn,\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\n return dataset\n\n\ndef serving_input_receiver_fn():\n feature_map = {\n \"example_id\": tf.FixedLenFeature([], tf.string),\n \"raw_id\": tf.FixedLenFeature([], tf.string),\n }\n for name in follower_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\n features = tf.parse_example(record_batch, features=feature_map)\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\n\n\ndef model_fn(model, features, labels, mode):\n logging.info('model_fn: mode %s', mode)\n x = [\n tf.expand_dims(features[name], axis=-1)\n for name in follower_feature_names\n ]\n x = tf.concat(x, axis=-1)\n\n w1f = tf.get_variable(\n 'w1l',\n shape=[len(follower_feature_names),\n len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\n b1f = tf.get_variable('b1l',\n shape=[len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.zeros_initializer())\n\n act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)\n\n if mode == tf.estimator.ModeKeys.PREDICT:\n return model.make_spec(mode=mode, predictions=act1_f)\n\n if mode == tf.estimator.ModeKeys.TRAIN:\n gact1_f = model.send('act1_f', act1_f, require_grad=True)\n elif mode == tf.estimator.ModeKeys.EVAL:\n model.send('act1_f', act1_f, require_grad=False)\n\n #acc = model.recv('acc', tf.float32, require_grad=False)\n auc = model.recv('auc', tf.float32, require_grad=False)\n loss = model.recv('loss', tf.float32, require_grad=False)\n logging_hook = tf.train.LoggingTensorHook({\n 'auc': auc, 'loss': loss,\n }, 
every_n_iter=10)\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\n\n global_step = tf.train.get_or_create_global_step()\n if mode == tf.estimator.ModeKeys.TRAIN:\n optimizer = tf.train.GradientDescentOptimizer(0.1)\n train_op = model.minimize(optimizer,\n act1_f,\n grad_loss=gact1_f,\n global_step=global_step)\n return model.make_spec(mode,\n loss=tf.math.reduce_mean(act1_f),\n train_op=train_op,\n training_hooks=[logging_hook, step_metric_hook])\n if mode == tf.estimator.ModeKeys.EVAL:\n fake_loss = tf.reduce_mean(act1_f)\n return model.make_spec(mode=mode, loss=fake_loss,\n evaluation_hooks=[logging_hook, step_metric_hook])\n\n\nif __name__ == '__main__':\n logging.basicConfig(level=logging.INFO)\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\n serving_input_receiver_fn)\n",
+ "leader/config.py": "leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\nleader_label_name = ['label']\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\n",
+ "leader/main.py": "# Copyright 2020 The FedLearner Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# coding: utf-8\n# pylint: disable=no-else-return, inconsistent-return-statements\n\nimport os\nimport logging\nimport tensorflow.compat.v1 as tf\nimport fedlearner.trainer as flt\nfrom config import *\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\n\nROLE = 'leader'\n\nparser = flt.trainer_worker.create_argument_parser()\nparser.add_argument('--batch-size',\n type=int,\n default=100,\n help='Training batch size.')\nargs = parser.parse_args()\n\n\ndef input_fn(bridge, trainer_master):\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\n trainer_master).make_dataset()\n\n def parse_fn(example):\n feature_map = dict()\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\n for name in leader_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n label_map = {}\n for name in leader_label_name:\n label_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n features = tf.parse_example(example, features=feature_map)\n labels = tf.parse_example(example, features=label_map)\n return features, labels\n\n dataset = dataset.map(map_func=parse_fn,\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\n return dataset\n\n\ndef serving_input_receiver_fn():\n feature_map = {\n \"example_id\": tf.FixedLenFeature([], tf.string),\n \"raw_id\": tf.FixedLenFeature([], tf.string),\n }\n for name in leader_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\n features = tf.parse_example(record_batch, features=feature_map)\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\n\n\ndef model_fn(model, features, labels, mode):\n logging.info('model_fn: mode %s', mode)\n x = [\n tf.expand_dims(features[name], axis=-1)\n for name in leader_feature_names\n ]\n x = tf.concat(x, axis=-1)\n\n w1l = tf.get_variable(\n 'w1l',\n shape=[len(leader_feature_names),\n len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\n b1l = tf.get_variable('b1l',\n shape=[len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.zeros_initializer())\n\n act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)\n if mode == tf.estimator.ModeKeys.TRAIN:\n act1_f = model.recv('act1_f', tf.float32, require_grad=True)\n elif mode == tf.estimator.ModeKeys.EVAL:\n act1_f = model.recv('act1_f', tf.float32, require_grad=False)\n else:\n act1_f = features['act1_f']\n logits = act1_l + act1_f\n pred = tf.math.sigmoid(logits)\n\n if mode == tf.estimator.ModeKeys.PREDICT:\n return model.make_spec(mode=mode, predictions=pred)\n\n y = 
[tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]\n y = tf.concat(y, axis=-1)\n\n loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n _, auc = tf.metrics.auc(labels=y, predictions=pred)\n #correct = tf.nn.in_top_k(predictions=logits, targets=y, k=1)\n #acc = tf.reduce_mean(input_tensor=tf.cast(correct, tf.float32))\n logging_hook = tf.train.LoggingTensorHook({\n # 'acc': acc,\n 'auc': auc,\n 'loss': loss,\n }, every_n_iter=10)\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\n #model.send('acc', acc, require_grad=False)\n model.send('auc', auc, require_grad=False)\n model.send('loss', loss, require_grad=False)\n\n global_step = tf.train.get_or_create_global_step()\n if mode == tf.estimator.ModeKeys.TRAIN:\n optimizer = tf.train.AdamOptimizer(1e-4)\n train_op = model.minimize(optimizer, loss, global_step=global_step)\n return model.make_spec(mode=mode,\n loss=loss,\n train_op=train_op,\n training_hooks=[logging_hook, step_metric_hook])\n\n if mode == tf.estimator.ModeKeys.EVAL:\n loss_pair = tf.metrics.mean(loss)\n return model.make_spec(mode=mode,\n loss=loss,\n eval_metric_ops={'loss': loss_pair},\n evaluation_hooks=[logging_hook, step_metric_hook])\n\n\nif __name__ == '__main__':\n logging.basicConfig(level=logging.INFO)\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\n serving_input_receiver_fn)\n",
+ "main.py": ""
+ },
+ "value": "{\"main.py\":\"\",\"leader/main.py\":\"# Copyright 2020 The FedLearner Authors. All Rights Reserved.\\n#\\n# Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n# you may not use this file except in compliance with the License.\\n# You may obtain a copy of the License at\\n#\\n# http://www.apache.org/licenses/LICENSE-2.0\\n#\\n# Unless required by applicable law or agreed to in writing, software\\n# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n# See the License for the specific language governing permissions and\\n# limitations under the License.\\n\\n# coding: utf-8\\n# pylint: disable=no-else-return, inconsistent-return-statements\\n\\nimport os\\nimport logging\\nimport tensorflow.compat.v1 as tf\\nimport fedlearner.trainer as flt\\nfrom config import *\\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\\n\\nROLE = 'leader'\\n\\nparser = flt.trainer_worker.create_argument_parser()\\nparser.add_argument('--batch-size',\\n type=int,\\n default=100,\\n help='Training batch size.')\\nargs = parser.parse_args()\\n\\n\\ndef input_fn(bridge, trainer_master):\\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\\n trainer_master).make_dataset()\\n\\n def parse_fn(example):\\n feature_map = dict()\\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\\n for name in leader_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n label_map = {}\\n for name in leader_label_name:\\n label_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n features = tf.parse_example(example, features=feature_map)\\n labels = tf.parse_example(example, features=label_map)\\n return features, labels\\n\\n dataset = dataset.map(map_func=parse_fn,\\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\\n return dataset\\n\\n\\ndef serving_input_receiver_fn():\\n feature_map = {\\n \\\"example_id\\\": tf.FixedLenFeature([], tf.string),\\n \\\"raw_id\\\": tf.FixedLenFeature([], tf.string),\\n }\\n for name in leader_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\\n features = tf.parse_example(record_batch, features=feature_map)\\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\\n\\n\\ndef model_fn(model, features, labels, mode):\\n logging.info('model_fn: mode %s', mode)\\n x = [\\n tf.expand_dims(features[name], axis=-1)\\n for name in leader_feature_names\\n ]\\n x = tf.concat(x, axis=-1)\\n\\n w1l = tf.get_variable(\\n 'w1l',\\n shape=[len(leader_feature_names),\\n len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\\n b1l = tf.get_variable('b1l',\\n shape=[len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.zeros_initializer())\\n\\n act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n act1_f = model.recv('act1_f', tf.float32, require_grad=True)\\n elif mode == tf.estimator.ModeKeys.EVAL:\\n act1_f = model.recv('act1_f', tf.float32, require_grad=False)\\n else:\\n act1_f = features['act1_f']\\n logits = act1_l + 
act1_f\\n pred = tf.math.sigmoid(logits)\\n\\n if mode == tf.estimator.ModeKeys.PREDICT:\\n return model.make_spec(mode=mode, predictions=pred)\\n\\n y = [tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]\\n y = tf.concat(y, axis=-1)\\n\\n loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\\n _, auc = tf.metrics.auc(labels=y, predictions=pred)\\n #correct = tf.nn.in_top_k(predictions=logits, targets=y, k=1)\\n #acc = tf.reduce_mean(input_tensor=tf.cast(correct, tf.float32))\\n logging_hook = tf.train.LoggingTensorHook({\\n # 'acc': acc,\\n 'auc': auc,\\n 'loss': loss,\\n }, every_n_iter=10)\\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\\n #model.send('acc', acc, require_grad=False)\\n model.send('auc', auc, require_grad=False)\\n model.send('loss', loss, require_grad=False)\\n\\n global_step = tf.train.get_or_create_global_step()\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n optimizer = tf.train.AdamOptimizer(1e-4)\\n train_op = model.minimize(optimizer, loss, global_step=global_step)\\n return model.make_spec(mode=mode,\\n loss=loss,\\n train_op=train_op,\\n training_hooks=[logging_hook, step_metric_hook])\\n\\n if mode == tf.estimator.ModeKeys.EVAL:\\n loss_pair = tf.metrics.mean(loss)\\n return model.make_spec(mode=mode,\\n loss=loss,\\n eval_metric_ops={'loss': loss_pair},\\n evaluation_hooks=[logging_hook, step_metric_hook])\\n\\n\\nif __name__ == '__main__':\\n logging.basicConfig(level=logging.INFO)\\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\\n serving_input_receiver_fn)\\n\",\"follower/main.py\":\"# Copyright 2020 The FedLearner Authors. All Rights Reserved.\\n#\\n# Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n# you may not use this file except in compliance with the License.\\n# You may obtain a copy of the License at\\n#\\n# http://www.apache.org/licenses/LICENSE-2.0\\n#\\n# Unless required by applicable law or agreed to in writing, software\\n# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n# See the License for the specific language governing permissions and\\n# limitations under the License.\\n\\n# coding: utf-8\\n# pylint: disable=no-else-return, inconsistent-return-statements\\n\\nimport os\\nimport logging\\nimport numpy as np\\nimport tensorflow.compat.v1 as tf\\nimport fedlearner.trainer as flt\\nfrom config import *\\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\\n\\nROLE = 'follower'\\n\\nparser = flt.trainer_worker.create_argument_parser()\\nparser.add_argument('--batch-size',\\n type=int,\\n default=100,\\n help='Training batch size.')\\nargs = parser.parse_args()\\n\\n\\ndef input_fn(bridge, trainer_master):\\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\\n trainer_master).make_dataset()\\n\\n def parse_fn(example):\\n feature_map = dict()\\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\\n for name in follower_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n features = tf.parse_example(example, features=feature_map)\\n return features, dict(y=tf.constant(0))\\n\\n dataset = dataset.map(map_func=parse_fn,\\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\\n return dataset\\n\\n\\ndef serving_input_receiver_fn():\\n feature_map = {\\n \\\"example_id\\\": tf.FixedLenFeature([], 
tf.string),\\n \\\"raw_id\\\": tf.FixedLenFeature([], tf.string),\\n }\\n for name in follower_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\\n features = tf.parse_example(record_batch, features=feature_map)\\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\\n\\n\\ndef model_fn(model, features, labels, mode):\\n logging.info('model_fn: mode %s', mode)\\n x = [\\n tf.expand_dims(features[name], axis=-1)\\n for name in follower_feature_names\\n ]\\n x = tf.concat(x, axis=-1)\\n\\n w1f = tf.get_variable(\\n 'w1l',\\n shape=[len(follower_feature_names),\\n len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\\n b1f = tf.get_variable('b1l',\\n shape=[len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.zeros_initializer())\\n\\n act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)\\n\\n if mode == tf.estimator.ModeKeys.PREDICT:\\n return model.make_spec(mode=mode, predictions=act1_f)\\n\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n gact1_f = model.send('act1_f', act1_f, require_grad=True)\\n elif mode == tf.estimator.ModeKeys.EVAL:\\n model.send('act1_f', act1_f, require_grad=False)\\n\\n #acc = model.recv('acc', tf.float32, require_grad=False)\\n auc = model.recv('auc', tf.float32, require_grad=False)\\n loss = model.recv('loss', tf.float32, require_grad=False)\\n logging_hook = tf.train.LoggingTensorHook({\\n 'auc': auc, 'loss': loss,\\n }, every_n_iter=10)\\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\\n\\n global_step = tf.train.get_or_create_global_step()\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n optimizer = tf.train.GradientDescentOptimizer(0.1)\\n train_op = model.minimize(optimizer,\\n act1_f,\\n grad_loss=gact1_f,\\n global_step=global_step)\\n return model.make_spec(mode,\\n loss=tf.math.reduce_mean(act1_f),\\n train_op=train_op,\\n training_hooks=[logging_hook, step_metric_hook])\\n if mode == tf.estimator.ModeKeys.EVAL:\\n fake_loss = tf.reduce_mean(act1_f)\\n return model.make_spec(mode=mode, loss=fake_loss,\\n evaluation_hooks=[logging_hook, step_metric_hook])\\n\\n\\nif __name__ == '__main__':\\n logging.basicConfig(level=logging.INFO)\\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\\n serving_input_receiver_fn)\\n\",\"follower/config.py\":\"leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\\nleader_label_name = ['label']\\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\\n\",\"leader/config.py\":\"leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\\nleader_label_name = ['label']\\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\\n\"}",
+ "value_type": "CODE",
+ "widget_schema": "{\"component\":\"Code\",\"required\":true}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"activeDeadlineSeconds\": 1200, \n \"fedReplicaSpecs\": {\n \"Master\": {\n \"backoffLimit\": 1,\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(int(self.variables.epoch_num))\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(None)\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(None)\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": workflow.jobs['data-join-job'].name\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": \"\"\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": \"\"\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": self.variables.code_tar\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": \"\" and project.variables.storage_root_path + \"/job_output/\" + \"\" + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": \"\"\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n\n }\n },\n \"replicas\": int(1)\n },\n \"PS\": {\n \"backoffLimit\": 1,\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": 
list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(1)\n },\n \"Worker\": {\n \"backoffLimit\": 6,\n \"mustSuccess\": True,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": \"train\"\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(1)\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": \"\"\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": self.variables.code_tar\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(1000)\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(None)\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(None)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\":[\"/bin/bash\",\"-c\"],\n \"args\": [\"export WORKER_RANK=$$INDEX && export PEER_ADDR=$$SERVICE_ID && /app/deploy/scripts/trainer/run_trainer_worker.sh\"],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(1)\n }\n }\n }\n}"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "data-join-job"
+ },
+ {
+ "source": "psi-data-join-job"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "TREE_MODEL_TRAINING",
+ "name": "tree-train",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_depth",
+ "tag": "",
+ "typed_value": "3",
+ "value": "3",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_iters",
+ "tag": "",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_ext",
+ "tag": "",
+ "typed_value": ".data",
+ "value": ".data",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_type",
+ "tag": "",
+ "typed_value": "tfrecord",
+ "value": "tfrecord",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"csv\",\"tfrecord\"]}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(workflow.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(workflow.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": \"train\"\n },\n {\n \"name\": \"LOSS_TYPE\",\n \"value\": \"logistic\"\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": str(workflow.jobs['psi-data-join-job'].name)\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"VALIDATION_DATA_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"NO_DATA\",\n \"value\": str(False)\n },\n {\n \"name\": \"FILE_EXT\",\n \"value\": str(self.variables.file_ext)\n },\n {\n \"name\": \"FILE_TYPE\",\n \"value\": str(self.variables.file_type)\n },\n {\n \"name\": \"LOAD_MODEL_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_MODEL_NAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(1)\n },\n {\n \"name\": \"LEARNING_RATE\",\n \"value\": str(0.3)\n },\n {\n \"name\": \"MAX_ITERS\",\n \"value\": str(int(self.variables.max_iters))\n },\n {\n \"name\": \"MAX_DEPTH\",\n \"value\": str(int(self.variables.max_depth))\n },\n {\n \"name\": \"MAX_BINS\",\n \"value\": str(33)\n },\n {\n \"name\": \"L2_REGULARIZATION\",\n \"value\": str(1.0)\n },\n {\n \"name\": \"NUM_PARALLEL\",\n \"value\": str(1)\n },\n {\n \"name\": \"VERIFY_EXAMPLE_IDS\",\n \"value\": str(False)\n },\n {\n \"name\": \"IGNORE_FIELDS\",\n \"value\": \"\"\n },\n {\n \"name\": \"CAT_FIELDS\",\n \"value\": \"\"\n },\n {\n \"name\": \"LABEL_FIELD\",\n \"value\": \"label\"\n },\n {\n \"name\": \"SEND_SCORES_TO_FOLLOWER\",\n \"value\": str(False)\n },\n {\n \"name\": \"SEND_METRICS_TO_FOLLOWER\",\n \"value\": str(False)\n },\n {\n \"name\": \"ENABLE_PACKING\",\n \"value\": str(True)\n },\n {\n \"name\": \"ES_BATCH_SIZE\",\n \"value\": str(10)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": 
\"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_tree_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"4Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"4Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": 1\n }\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "image",
+ "tag": "",
+ "typed_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "num_partitions",
+ "tag": "",
+ "typed_value": "1",
+ "value": "1",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "role",
+ "tag": "",
+ "typed_value": "Follower",
+ "value": "Follower",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"]}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "input_base_dir",
+ "tag": "",
+ "typed_value": "/app/deploy/integrated_test/credit_default",
+ "value": "/app/deploy/integrated_test/credit_default",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_wildcard",
+ "tag": "",
+ "typed_value": "*host.csv",
+ "value": "*host.csv",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "input_data_format",
+ "tag": "",
+ "typed_value": "CSV_DICT",
+ "value": "CSV_DICT",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"CSV_DICT\",\"TF_RECORD\"]}"
+ }
+ ]
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "data-join-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": ${Slot_batch_mode}\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(${Slot_start_time})\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(${Slot_end_time})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n ${Slot_volume_mounts}\n ,\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n ${Slot_volumes}\n\n }\n },\n \"pair\": true,\n \"replicas\": ${Slot_master_replicas}\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n 
\"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(${Slot_data_block_dump_interval})\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(${Slot_data_block_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(${Slot_example_id_dump_interval})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(${Slot_example_id_dump_threshold})\n },\n {\n \"name\": \"MIN_MATCHING_WINDOW\",\n \"value\": str(${Slot_min_matching_window})\n },\n {\n \"name\": \"MAX_MATCHING_WINDOW\",\n \"value\": str(${Slot_max_matching_window})\n },\n {\n \"name\": \"RAW_DATA_ITER\",\n \"value\": ${Slot_raw_data_iter}\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(${Slot_enable_negative_example_generator})\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n ${Slot_volume_mounts}\n ,\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n ${Slot_volumes}\n\n }\n },\n \"pair\": true,\n \"replicas\": ${Slot_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_mode": {
+ "default": "",
+ "default_value": "--batch_mode",
+ "help": "如果为空则为常驻求交",
+ "label": "是否为批处理模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次data block,小于0则无此限制",
+ "label": "数据dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_block_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多多少个样本就dump为一个data block,小于等于0则无此限制",
+ "label": "数据dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_enable_negative_example_generator": {
+ "default": "",
+ "default_value": false,
+ "help": "建议不修改,是否开启负采样,当follower求交时遇到无法匹配上的leader的example id,会以negative_sampling_rate为概率生成一个新的样本。",
+        "label": "是否开启负采样",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_end_time": {
+ "default": "",
+ "default_value": 999999999999.0,
+ "help": "建议不修改,使用自这个时间以前的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据末尾时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+        "help": "建议不修改,最多多少个example id就dump一次,小于等于0则无此限制",
+ "label": "数据id dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_max_matching_window": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,the max matching window for example join. <=0 means window size is infinite",
+ "label": "最大匹配滑窗",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_min_matching_window": {
+ "default": "",
+ "default_value": 1024.0,
+        "help": "建议不修改,the min matching window for example join, <=0 means window size is infinite",
+ "label": "最小匹配滑窗",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_negative_sampling_rate": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,负采样比例,当follower求交时遇到无法匹配上的leader的example id,会以此概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "建议修改,求交后数据分区的数量,建议和raw_data一致",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_iter": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "raw_data文件类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_raw_data_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,原始数据的发布地址,根据参数内容在portal_publish_dir地址下寻找",
+ "label": "raw_data名字",
+ "reference": "workflow.jobs['raw-data-job-streaming'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_start_time": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,使用自这个时间起的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据起始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "nn-train": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(${Slot_epoch_num})\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(${Slot_start_date})\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(${Slot_end_date})\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": ${Slot_online_training}\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": ${Slot_checkpoint_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": ${Slot_load_checkpoint_filename}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": ${Slot_load_checkpoint_filename_with_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": ${Slot_load_checkpoint_from_job} and ${Slot_storage_root_path} + \"/job_output/\" + ${Slot_load_checkpoint_from_job} + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": ${Slot_export_path}\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_master_replicas})\n },\n \"PS\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n 
\"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + ${Slot_ps_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_ps_replicas})\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(${Slot_save_checkpoint_steps})\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(${Slot_save_checkpoint_secs})\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(${Slot_summary_save_steps})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_trainer_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_worker_replicas})\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_checkpoint_path": {
+ "default": "",
+ "default_value": "",
+ "help": "不建议修改,checkpoint输出路径,建议为空,会默认使用{storage_root_path}/job_output/{job_name}/checkpoints,强烈建议保持空值",
+ "label": "CHECKPOINT_PATH",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_key": {
+ "default": "",
+ "default_value": "",
+ "help": "代码tar包地址,如果为空则使用code tar",
+ "label": "模型代码路径",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_tar": {
+ "default": "",
+ "default_value": "",
+ "help": "代码包,variable中请使用代码类型",
+ "label": "代码",
+ "reference": "self.variables.code_tar",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,求交任务的名字",
+ "label": "数据源",
+ "reference": "workflow.jobs['data-join-job'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_end_date": {
+ "default": "",
+ "default_value": null,
+ "help": "training data end date",
+ "label": "结束时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_epoch_num": {
+ "default": "",
+ "default_value": 1.0,
+        "help": "number of epochs for training, not supported in online training",
+ "label": "epoch数量",
+ "reference": "self.variables.epoch_num",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_export_path": {
+ "default": "",
+ "default_value": "",
+ "help": "使用默认空值,将把models保存到$OUTPUT_BASE_DIR/exported_models 路径下。",
+ "label": "EXPORT_PATH",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模板不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_load_checkpoint_filename": {
+ "default": "",
+ "default_value": "",
+ "help": "加载checkpoint_path下的相对路径的checkpoint, 默认会加载checkpoint_path下的latest checkpoint",
+ "label": "LOAD_CHECKPOINT_FILENAME",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_filename_with_path": {
+ "default": "",
+ "default_value": "",
+ "help": "加载绝对路径下的checkpoint,需要细致到文件名",
+ "label": "从绝对路径加载checkpoint",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_from_job": {
+ "default": "",
+ "default_value": "",
+ "help": "指定任务名job_output下的latest checkpoint",
+ "label": "以任务名加载checkpoint",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_mode": {
+ "default": "",
+ "default_value": "train",
+ "help": "choices:['train','eval'] 训练还是验证",
+ "label": "模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_online_training": {
+ "default": "",
+ "default_value": "",
+        "help": "['','--online_training'] 否 是,whether the train master runs for online training",
+ "label": "是否在线训练",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_cpu": {
+ "default": "",
+ "default_value": "2000m",
+        "help": "PS Pod 所分配的CPU资源(request和limit一致)",
+ "label": "PS的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,ps pod额外的环境变量",
+ "label": "PS额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_ps_memory": {
+ "default": "",
+ "default_value": "3Gi",
+        "help": "PS Pod 所分配的内存资源(request和limit一致)",
+ "label": "PS的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的PS Pods数量",
+ "label": "PS的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_save_checkpoint_secs": {
+ "default": "",
+ "default_value": null,
+        "help": "int, Number of secs between checkpoints.",
+ "label": "SAVE_CHECKPOINT_SECS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_save_checkpoint_steps": {
+ "default": "",
+ "default_value": 1000.0,
+ "help": "int, Number of steps between checkpoints.",
+ "label": "SAVE_CHECKPOINT_STEPS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_sparse_estimator": {
+ "default": "",
+ "default_value": false,
+        "help": "bool, default False. Whether to use sparse estimator.",
+ "label": "SPARSE_ESTIMATOR",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_start_date": {
+ "default": "",
+ "default_value": null,
+ "help": "training data start date",
+ "label": "开始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_suffle_data_block": {
+ "default": "",
+ "default_value": "",
+ "help": "['','--shuffle_data_block'] 否 是,shuffle the data block or not",
+ "label": "是否shuffle数据块",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_summary_save_steps": {
+ "default": "",
+ "default_value": null,
+ "help": "int, Number of steps to save summary files.",
+ "label": "SUMMARY_SAVE_STEPS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_verbosity": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "int, Logging level",
+ "label": "日志等级",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Worker Pods数量",
+ "label": "Worker的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ }
+ },
+ "variables": []
+ },
+ "psi-data-join-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(${Slot_start_time})\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(${Slot_end_time})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": ${Slot_batch_mode}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + 
\"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": ${Slot_rsa_key_pem}\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": ${Slot_rsa_key_path}\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": ${Slot_rsa_private_key_path}\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": ${Slot_kms_key_name}\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": ${Slot_kms_client}\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": ${Slot_psi_raw_data_iter}\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": ${Slot_data_block_builder}\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": ${Slot_psi_output_builder}\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(${Slot_data_block_dump_interval})\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(${Slot_data_block_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(${Slot_example_id_dump_interval})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(${Slot_example_id_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(${Slot_psi_read_ahead_size})\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(${Slot_run_merger_read_ahead_buffer})\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(${Slot_enable_negative_example_generator})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_partition_num})\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_mode": {
+ "default": "",
+ "default_value": "--batch_mode",
+ "help": "如果为空则为常驻求交",
+ "label": "是否为批处理模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "data block output数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次data block,小于0则无此限制",
+ "label": "数据dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_block_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多多少个样本就dump为一个data block,小于等于0则无此限制",
+ "label": "数据dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_enable_negative_example_generator": {
+ "default": "",
+ "default_value": false,
+ "help": "建议不修改,是否开启负采样,当follower求交时遇到无法匹配上的leader的example id,会以negative_sampling_rate为概率生成一个新的样本。",
+        "label": "是否开启负采样",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_end_time": {
+ "default": "",
+ "default_value": 999999999999.0,
+ "help": "建议不修改,使用自这个时间以前的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据末尾时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+        "help": "建议不修改,最多多少个example id就dump一次,小于等于0则无此限制",
+ "label": "数据id dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_kms_client": {
+ "default": "",
+ "default_value": "data.aml.fl",
+ "help": "kms client",
+ "label": "kms client",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_kms_key_name": {
+ "default": "",
+ "default_value": "",
+ "help": "kms中的密钥名称,站内镜像需使用KMS",
+ "label": "密钥名称",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_negative_sampling_rate": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,负采样比例,当follower求交时遇到无法匹配上的leader的example id,会以此概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "建议修改,求交后数据分区的数量,建议和raw_data一致",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_psi_output_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "PSI output数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_psi_raw_data_iter": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "raw data数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_psi_read_ahead_size": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, the read ahead size for raw data",
+ "label": "psi_read_ahead_size",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,原始数据的发布地址,根据参数内容在portal_publish_dir地址下寻找",
+ "label": "raw_data名字",
+ "reference": "workflow.jobs['raw-data-job-psi'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_path": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA公钥或私钥的地址,在无RSA_KEY_PEM时必填",
+ "label": "RSA钥匙地址",
+ "reference": "self.variables.rsa_key_path",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_pem": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA公钥,follower需提供",
+ "label": "RSA公钥",
+ "reference": "self.variables.rsa_key_pem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_private_key_path": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA私钥的地址, leader必填",
+ "label": "RSA私钥地址",
+ "reference": "self.variables.rsa_private_key_path",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_run_merger_read_ahead_buffer": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, sort run merger read ahead buffer",
+ "label": "run_merger_read_ahead_buffer",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_start_time": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,使用自这个时间起的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据起始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "raw-data-job-psi": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": 
\"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "PSI",
+        "help": "运行过一次后修改无效!! the type of data portal, choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+        "help": "文件名称的通配符, 将会读取input_base_dir下所有满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "workflow.variables.file_wildcard",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "workflow.variables.input_base_dir",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "workflow.variables.input_data_format",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+        "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\". Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "raw-data-job-streaming": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": 
\"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "Streaming",
+        "help": "运行过一次后修改无效!! the type of data portal, choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+        "help": "文件名称的通配符, 将会读取input_base_dir下所有满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "workflow.variables.file_wildcard",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "workflow.variables.input_base_dir",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "workflow.variables.input_data_format",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+        "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\". Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "tree-train": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"LOSS_TYPE\",\n \"value\": ${Slot_loss_type}\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": ${Slot_data_path}\n },\n {\n \"name\": \"VALIDATION_DATA_PATH\",\n \"value\": ${Slot_validation_data_path}\n },\n {\n \"name\": \"NO_DATA\",\n \"value\": str(${Slot_no_data})\n },\n {\n \"name\": \"FILE_EXT\",\n \"value\": ${Slot_file_ext}\n },\n {\n \"name\": \"FILE_TYPE\",\n \"value\": ${Slot_file_type}\n },\n {\n \"name\": \"LOAD_MODEL_PATH\",\n \"value\": ${Slot_load_model_path}\n },\n {\n \"name\": \"LOAD_MODEL_NAME\",\n \"value\": ${Slot_load_model_name}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"LEARNING_RATE\",\n \"value\": str(${Slot_learning_rate})\n },\n {\n \"name\": \"MAX_ITERS\",\n \"value\": str(${Slot_max_iters})\n },\n {\n \"name\": \"MAX_DEPTH\",\n \"value\": str(${Slot_max_depth})\n },\n {\n \"name\": \"MAX_BINS\",\n \"value\": str(${Slot_max_bins})\n },\n {\n \"name\": \"L2_REGULARIZATION\",\n \"value\": str(${Slot_l2_regularization})\n },\n {\n \"name\": \"NUM_PARALLEL\",\n \"value\": str(${Slot_num_parallel})\n },\n {\n \"name\": \"VERIFY_EXAMPLE_IDS\",\n \"value\": str(${Slot_verify_example_ids})\n },\n {\n \"name\": \"IGNORE_FIELDS\",\n \"value\": ${Slot_ignore_fields}\n },\n {\n \"name\": \"CAT_FIELDS\",\n \"value\": ${Slot_cat_fields}\n },\n {\n \"name\": \"LABEL_FIELD\",\n \"value\": ${Slot_label_field}\n },\n {\n \"name\": \"SEND_SCORES_TO_FOLLOWER\",\n \"value\": str(${Slot_send_scores_to_follower})\n },\n {\n \"name\": \"SEND_METRICS_TO_FOLLOWER\",\n \"value\": str(${Slot_send_metrics_to_follower})\n },\n {\n \"name\": \"ENABLE_PACKING\",\n \"value\": str(${Slot_enable_packing})\n },\n {\n \"name\": \"ES_BATCH_SIZE\",\n \"value\": str(${Slot_es_batch_size})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n 
\"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_tree_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_mem}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_mem}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": 1\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_cat_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "类别类型特征,特征的值需要是非负整数。以逗号分隔如:alive,country,sex",
+ "label": "类别类型特征",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_path": {
+ "default": "",
+ "default_value": "",
+ "help": "数据存放位置",
+ "label": "数据存放位置",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "default": "",
+ "default_value": "",
+ "help": "求交数据集名称",
+ "label": "求交数据集名称",
+ "reference": "workflow.jobs['psi-data-join-job'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_enable_packing": {
+ "default": "",
+ "default_value": true,
+ "help": "是否开启优化",
+ "label": "是否开启优化",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_es_batch_size": {
+ "default": "",
+ "default_value": 10.0,
+ "help": "ES_BATCH_SIZE",
+ "label": "ES_BATCH_SIZE",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_file_ext": {
+ "default": "",
+ "default_value": ".data",
+ "help": "文件后缀",
+ "label": "文件后缀",
+ "reference": "self.variables.file_ext",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_file_type": {
+ "default": "",
+ "default_value": "csv",
+ "help": "文件类型,csv或tfrecord",
+ "label": "文件类型,csv或tfrecord",
+ "reference": "self.variables.file_type",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_ignore_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "以逗号分隔如:name,age,sex",
+ "label": "不入模的特征",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_l2_regularization": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "L2惩罚系数",
+ "label": "L2惩罚系数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_label_field": {
+ "default": "",
+ "default_value": "label",
+ "help": "label特征名",
+ "label": "label特征名",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_learning_rate": {
+ "default": "",
+ "default_value": 0.3,
+ "help": "学习率",
+ "label": "学习率",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_load_model_name": {
+ "default": "",
+ "default_value": "",
+ "help": "按任务名称加载模型,{STORAGE_ROOT_PATH}/job_output/{LOAD_MODEL_NAME}/exported_models",
+ "label": "模型任务名称",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_model_path": {
+ "default": "",
+ "default_value": "",
+ "help": "模型文件地址",
+ "label": "模型文件地址",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_loss_type": {
+ "default": "",
+ "default_value": "logistic",
+ "help": "损失函数类型,logistic或mse,默认logistic",
+ "label": "损失函数类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_max_bins": {
+ "default": "",
+ "default_value": 33.0,
+ "help": "最大分箱数",
+ "label": "最大分箱数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_max_depth": {
+ "default": "",
+ "default_value": 3.0,
+ "help": "最大深度",
+ "label": "最大深度",
+ "reference": "self.variables.max_depth",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_max_iters": {
+ "default": "",
+ "default_value": 5.0,
+ "help": "树的数量",
+ "label": "迭代数",
+ "reference": "self.variables.max_iters",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_mode": {
+ "default": "",
+ "default_value": "train",
+ "help": "任务类型,train或eval",
+ "label": "任务类型,train或eval",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_no_data": {
+ "default": "",
+ "default_value": false,
+ "help": "Leader是否没数据",
+ "label": "Leader是否没数据",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_num_parallel": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "进程数量",
+ "label": "进程数量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_send_metrics_to_follower": {
+ "default": "",
+ "default_value": false,
+ "help": "是否发送指标到follower",
+ "label": "是否发送指标到follower",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_send_scores_to_follower": {
+ "default": "",
+ "default_value": false,
+ "help": "是否发送结果到follower",
+ "label": "是否发送结果到follower",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_validation_data_path": {
+ "default": "",
+ "default_value": "",
+ "help": "验证数据集地址",
+ "label": "验证数据集地址",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_verbosity": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "日志输出等级",
+ "label": "日志输出等级",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_verify_example_ids": {
+ "default": "",
+ "default_value": false,
+ "help": "是否检查example_id对齐 If set to true, the first column of the data will be treated as example ids that must match between leader and follower",
+ "label": "是否检查example_id对齐",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "所需CPU",
+ "label": "所需CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_mem": {
+ "default": "",
+ "default_value": "4Gi",
+ "help": "所需内存",
+ "label": "所需内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ }
+ }
+ },
+ "group_alias": "e2e-test",
+ "name": "e2e-fed-left"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-fed-right.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-fed-right.json
new file mode 100644
index 000000000..42d0fa7fc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-fed-right.json
@@ -0,0 +1,1994 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "e2e-test",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": true,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job-streaming",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": \"Streaming\"\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": str(workflow.variables.input_base_dir)\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(workflow.variables.file_wildcard)\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(1024)\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": 
str(workflow.variables.input_data_format)\n },\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [],
+ "easy_mode": true,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job-psi",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": \"PSI\"\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": str(workflow.variables.input_base_dir)\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(workflow.variables.file_wildcard)\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(1024)\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": str(workflow.variables.input_data_format)\n 
},\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "raw-data-job-streaming"
+ },
+ {
+ "source": "raw-data-job-psi"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "DATA_JOIN",
+ "name": "data-join-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(workflow.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(workflow.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": \"--batch_mode\"\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(0)\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(999999999999)\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-streaming'].name)\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n list(system.variables.volume_mounts_list)\n ,\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n list(system.variables.volumes_list)\n\n }\n },\n \"pair\": true,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": 
\"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-streaming'].name)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"MIN_MATCHING_WINDOW\",\n \"value\": str(1024)\n },\n {\n \"name\": \"MAX_MATCHING_WINDOW\",\n \"value\": str(4096)\n },\n {\n \"name\": \"RAW_DATA_ITER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n list(system.variables.volume_mounts_list)\n ,\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n list(system.variables.volumes_list)\n\n }\n },\n \"pair\": true,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "raw-data-job-streaming"
+ },
+ {
+ "source": "raw-data-job-psi"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "PSI_DATA_JOIN",
+ "name": "psi-data-join-job",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_pem",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"TextArea\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_private_key_path",
+ "tag": "",
+ "typed_value": "/app/deploy/integrated_test/rsa_private.key",
+ "value": "/app/deploy/integrated_test/rsa_private.key",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_path",
+ "tag": "",
+ "typed_value": "/app/deploy/integrated_test/rsa_private.key",
+ "value": "/app/deploy/integrated_test/rsa_private.key",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(workflow.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(workflow.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(0)\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(999999999999)\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-psi'].name)\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": 
\"--batch_mode\"\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-psi'].name)\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": str(self.variables.rsa_key_pem)\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": str(self.variables.rsa_key_path)\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": str(self.variables.rsa_private_key_path)\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": \"data.aml.fl\"\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(None)\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(None)\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(False)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": int(int(workflow.variables.num_partitions))\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "data-join-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": true,
+ "job_type": "NN_MODEL_TRANINING",
+ "name": "nn-train",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "epoch_num",
+ "tag": "",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "code_tar",
+ "tag": "",
+ "typed_value": {
+ "follower/config.py": "leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\nleader_label_name = ['label']\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\n",
+ "follower/main.py": "# Copyright 2020 The FedLearner Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# coding: utf-8\n# pylint: disable=no-else-return, inconsistent-return-statements\n\nimport os\nimport logging\nimport numpy as np\nimport tensorflow.compat.v1 as tf\nimport fedlearner.trainer as flt\nfrom config import *\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\n\nROLE = 'follower'\n\nparser = flt.trainer_worker.create_argument_parser()\nparser.add_argument('--batch-size',\n type=int,\n default=100,\n help='Training batch size.')\nargs = parser.parse_args()\n\n\ndef input_fn(bridge, trainer_master):\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\n trainer_master).make_dataset()\n\n def parse_fn(example):\n feature_map = dict()\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\n for name in follower_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n features = tf.parse_example(example, features=feature_map)\n return features, dict(y=tf.constant(0))\n\n dataset = dataset.map(map_func=parse_fn,\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\n return dataset\n\n\ndef serving_input_receiver_fn():\n feature_map = {\n \"example_id\": tf.FixedLenFeature([], tf.string),\n \"raw_id\": tf.FixedLenFeature([], tf.string),\n }\n for name in follower_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\n features = tf.parse_example(record_batch, features=feature_map)\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\n\n\ndef model_fn(model, features, labels, mode):\n logging.info('model_fn: mode %s', mode)\n x = [\n tf.expand_dims(features[name], axis=-1)\n for name in follower_feature_names\n ]\n x = tf.concat(x, axis=-1)\n\n w1f = tf.get_variable(\n 'w1l',\n shape=[len(follower_feature_names),\n len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\n b1f = tf.get_variable('b1l',\n shape=[len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.zeros_initializer())\n\n act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)\n\n if mode == tf.estimator.ModeKeys.PREDICT:\n return model.make_spec(mode=mode, predictions=act1_f)\n\n if mode == tf.estimator.ModeKeys.TRAIN:\n gact1_f = model.send('act1_f', act1_f, require_grad=True)\n elif mode == tf.estimator.ModeKeys.EVAL:\n model.send('act1_f', act1_f, require_grad=False)\n\n #acc = model.recv('acc', tf.float32, require_grad=False)\n auc = model.recv('auc', tf.float32, require_grad=False)\n loss = model.recv('loss', tf.float32, require_grad=False)\n logging_hook = tf.train.LoggingTensorHook({\n 'auc': auc, 'loss': loss,\n }, 
every_n_iter=10)\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\n\n global_step = tf.train.get_or_create_global_step()\n if mode == tf.estimator.ModeKeys.TRAIN:\n optimizer = tf.train.GradientDescentOptimizer(0.1)\n train_op = model.minimize(optimizer,\n act1_f,\n grad_loss=gact1_f,\n global_step=global_step)\n return model.make_spec(mode,\n loss=tf.math.reduce_mean(act1_f),\n train_op=train_op,\n training_hooks=[logging_hook, step_metric_hook])\n if mode == tf.estimator.ModeKeys.EVAL:\n fake_loss = tf.reduce_mean(act1_f)\n return model.make_spec(mode=mode, loss=fake_loss,\n evaluation_hooks=[logging_hook, step_metric_hook])\n\n\nif __name__ == '__main__':\n logging.basicConfig(level=logging.INFO)\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\n serving_input_receiver_fn)\n",
+ "leader/config.py": "leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\nleader_label_name = ['label']\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\n",
+ "leader/main.py": "# Copyright 2020 The FedLearner Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# coding: utf-8\n# pylint: disable=no-else-return, inconsistent-return-statements\n\nimport os\nimport logging\nimport tensorflow.compat.v1 as tf\nimport fedlearner.trainer as flt\nfrom config import *\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\n\nROLE = 'leader'\n\nparser = flt.trainer_worker.create_argument_parser()\nparser.add_argument('--batch-size',\n type=int,\n default=100,\n help='Training batch size.')\nargs = parser.parse_args()\n\n\ndef input_fn(bridge, trainer_master):\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\n trainer_master).make_dataset()\n\n def parse_fn(example):\n feature_map = dict()\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\n for name in leader_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n label_map = {}\n for name in leader_label_name:\n label_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n features = tf.parse_example(example, features=feature_map)\n labels = tf.parse_example(example, features=label_map)\n return features, labels\n\n dataset = dataset.map(map_func=parse_fn,\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\n return dataset\n\n\ndef serving_input_receiver_fn():\n feature_map = {\n \"example_id\": tf.FixedLenFeature([], tf.string),\n \"raw_id\": tf.FixedLenFeature([], tf.string),\n }\n for name in leader_feature_names:\n feature_map[name] = tf.FixedLenFeature([],\n tf.float32,\n default_value=0.0)\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\n features = tf.parse_example(record_batch, features=feature_map)\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\n\n\ndef model_fn(model, features, labels, mode):\n logging.info('model_fn: mode %s', mode)\n x = [\n tf.expand_dims(features[name], axis=-1)\n for name in leader_feature_names\n ]\n x = tf.concat(x, axis=-1)\n\n w1l = tf.get_variable(\n 'w1l',\n shape=[len(leader_feature_names),\n len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\n b1l = tf.get_variable('b1l',\n shape=[len(leader_label_name)],\n dtype=tf.float32,\n initializer=tf.zeros_initializer())\n\n act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)\n if mode == tf.estimator.ModeKeys.TRAIN:\n act1_f = model.recv('act1_f', tf.float32, require_grad=True)\n elif mode == tf.estimator.ModeKeys.EVAL:\n act1_f = model.recv('act1_f', tf.float32, require_grad=False)\n else:\n act1_f = features['act1_f']\n logits = act1_l + act1_f\n pred = tf.math.sigmoid(logits)\n\n if mode == tf.estimator.ModeKeys.PREDICT:\n return model.make_spec(mode=mode, predictions=pred)\n\n y = 
[tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]\n y = tf.concat(y, axis=-1)\n\n loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n _, auc = tf.metrics.auc(labels=y, predictions=pred)\n #correct = tf.nn.in_top_k(predictions=logits, targets=y, k=1)\n #acc = tf.reduce_mean(input_tensor=tf.cast(correct, tf.float32))\n logging_hook = tf.train.LoggingTensorHook({\n # 'acc': acc,\n 'auc': auc,\n 'loss': loss,\n }, every_n_iter=10)\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\n #model.send('acc', acc, require_grad=False)\n model.send('auc', auc, require_grad=False)\n model.send('loss', loss, require_grad=False)\n\n global_step = tf.train.get_or_create_global_step()\n if mode == tf.estimator.ModeKeys.TRAIN:\n optimizer = tf.train.AdamOptimizer(1e-4)\n train_op = model.minimize(optimizer, loss, global_step=global_step)\n return model.make_spec(mode=mode,\n loss=loss,\n train_op=train_op,\n training_hooks=[logging_hook, step_metric_hook])\n\n if mode == tf.estimator.ModeKeys.EVAL:\n loss_pair = tf.metrics.mean(loss)\n return model.make_spec(mode=mode,\n loss=loss,\n eval_metric_ops={'loss': loss_pair},\n evaluation_hooks=[logging_hook, step_metric_hook])\n\n\nif __name__ == '__main__':\n logging.basicConfig(level=logging.INFO)\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\n serving_input_receiver_fn)\n",
+ "main.py": ""
+ },
+ "value": "{\"main.py\":\"\",\"leader/main.py\":\"# Copyright 2020 The FedLearner Authors. All Rights Reserved.\\n#\\n# Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n# you may not use this file except in compliance with the License.\\n# You may obtain a copy of the License at\\n#\\n# http://www.apache.org/licenses/LICENSE-2.0\\n#\\n# Unless required by applicable law or agreed to in writing, software\\n# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n# See the License for the specific language governing permissions and\\n# limitations under the License.\\n\\n# coding: utf-8\\n# pylint: disable=no-else-return, inconsistent-return-statements\\n\\nimport os\\nimport logging\\nimport tensorflow.compat.v1 as tf\\nimport fedlearner.trainer as flt\\nfrom config import *\\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\\n\\nROLE = 'leader'\\n\\nparser = flt.trainer_worker.create_argument_parser()\\nparser.add_argument('--batch-size',\\n type=int,\\n default=100,\\n help='Training batch size.')\\nargs = parser.parse_args()\\n\\n\\ndef input_fn(bridge, trainer_master):\\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\\n trainer_master).make_dataset()\\n\\n def parse_fn(example):\\n feature_map = dict()\\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\\n for name in leader_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n label_map = {}\\n for name in leader_label_name:\\n label_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n features = tf.parse_example(example, features=feature_map)\\n labels = tf.parse_example(example, features=label_map)\\n return features, labels\\n\\n dataset = dataset.map(map_func=parse_fn,\\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\\n return dataset\\n\\n\\ndef serving_input_receiver_fn():\\n feature_map = {\\n \\\"example_id\\\": tf.FixedLenFeature([], tf.string),\\n \\\"raw_id\\\": tf.FixedLenFeature([], tf.string),\\n }\\n for name in leader_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\\n features = tf.parse_example(record_batch, features=feature_map)\\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\\n\\n\\ndef model_fn(model, features, labels, mode):\\n logging.info('model_fn: mode %s', mode)\\n x = [\\n tf.expand_dims(features[name], axis=-1)\\n for name in leader_feature_names\\n ]\\n x = tf.concat(x, axis=-1)\\n\\n w1l = tf.get_variable(\\n 'w1l',\\n shape=[len(leader_feature_names),\\n len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\\n b1l = tf.get_variable('b1l',\\n shape=[len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.zeros_initializer())\\n\\n act1_l = tf.nn.bias_add(tf.matmul(x, w1l), b1l)\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n act1_f = model.recv('act1_f', tf.float32, require_grad=True)\\n elif mode == tf.estimator.ModeKeys.EVAL:\\n act1_f = model.recv('act1_f', tf.float32, require_grad=False)\\n else:\\n act1_f = features['act1_f']\\n logits = act1_l + 
act1_f\\n pred = tf.math.sigmoid(logits)\\n\\n if mode == tf.estimator.ModeKeys.PREDICT:\\n return model.make_spec(mode=mode, predictions=pred)\\n\\n y = [tf.expand_dims(labels[name], axis=-1) for name in leader_label_name]\\n y = tf.concat(y, axis=-1)\\n\\n loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\\n _, auc = tf.metrics.auc(labels=y, predictions=pred)\\n #correct = tf.nn.in_top_k(predictions=logits, targets=y, k=1)\\n #acc = tf.reduce_mean(input_tensor=tf.cast(correct, tf.float32))\\n logging_hook = tf.train.LoggingTensorHook({\\n # 'acc': acc,\\n 'auc': auc,\\n 'loss': loss,\\n }, every_n_iter=10)\\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\\n #model.send('acc', acc, require_grad=False)\\n model.send('auc', auc, require_grad=False)\\n model.send('loss', loss, require_grad=False)\\n\\n global_step = tf.train.get_or_create_global_step()\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n optimizer = tf.train.AdamOptimizer(1e-4)\\n train_op = model.minimize(optimizer, loss, global_step=global_step)\\n return model.make_spec(mode=mode,\\n loss=loss,\\n train_op=train_op,\\n training_hooks=[logging_hook, step_metric_hook])\\n\\n if mode == tf.estimator.ModeKeys.EVAL:\\n loss_pair = tf.metrics.mean(loss)\\n return model.make_spec(mode=mode,\\n loss=loss,\\n eval_metric_ops={'loss': loss_pair},\\n evaluation_hooks=[logging_hook, step_metric_hook])\\n\\n\\nif __name__ == '__main__':\\n logging.basicConfig(level=logging.INFO)\\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\\n serving_input_receiver_fn)\\n\",\"follower/main.py\":\"# Copyright 2020 The FedLearner Authors. All Rights Reserved.\\n#\\n# Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\\n# you may not use this file except in compliance with the License.\\n# You may obtain a copy of the License at\\n#\\n# http://www.apache.org/licenses/LICENSE-2.0\\n#\\n# Unless required by applicable law or agreed to in writing, software\\n# distributed under the License is distributed on an \\\"AS IS\\\" BASIS,\\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\\n# See the License for the specific language governing permissions and\\n# limitations under the License.\\n\\n# coding: utf-8\\n# pylint: disable=no-else-return, inconsistent-return-statements\\n\\nimport os\\nimport logging\\nimport numpy as np\\nimport tensorflow.compat.v1 as tf\\nimport fedlearner.trainer as flt\\nfrom config import *\\nfrom fedlearner.trainer.trainer_worker import StepLossAucMetricsHook\\n\\nROLE = 'follower'\\n\\nparser = flt.trainer_worker.create_argument_parser()\\nparser.add_argument('--batch-size',\\n type=int,\\n default=100,\\n help='Training batch size.')\\nargs = parser.parse_args()\\n\\n\\ndef input_fn(bridge, trainer_master):\\n dataset = flt.data.DataBlockLoader(args.batch_size, ROLE, bridge,\\n trainer_master).make_dataset()\\n\\n def parse_fn(example):\\n feature_map = dict()\\n feature_map['example_id'] = tf.FixedLenFeature([], tf.string)\\n feature_map['raw_id'] = tf.FixedLenFeature([], tf.string)\\n for name in follower_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n features = tf.parse_example(example, features=feature_map)\\n return features, dict(y=tf.constant(0))\\n\\n dataset = dataset.map(map_func=parse_fn,\\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\\n return dataset\\n\\n\\ndef serving_input_receiver_fn():\\n feature_map = {\\n \\\"example_id\\\": tf.FixedLenFeature([], 
tf.string),\\n \\\"raw_id\\\": tf.FixedLenFeature([], tf.string),\\n }\\n for name in follower_feature_names:\\n feature_map[name] = tf.FixedLenFeature([],\\n tf.float32,\\n default_value=0.0)\\n record_batch = tf.placeholder(dtype=tf.string, name='examples')\\n features = tf.parse_example(record_batch, features=feature_map)\\n features['act1_f'] = tf.placeholder(dtype=tf.float32, name='act1_f')\\n receiver_tensors = {'examples': record_batch, 'act1_f': features['act1_f']}\\n return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)\\n\\n\\ndef model_fn(model, features, labels, mode):\\n logging.info('model_fn: mode %s', mode)\\n x = [\\n tf.expand_dims(features[name], axis=-1)\\n for name in follower_feature_names\\n ]\\n x = tf.concat(x, axis=-1)\\n\\n w1f = tf.get_variable(\\n 'w1l',\\n shape=[len(follower_feature_names),\\n len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.random_uniform_initializer(-0.01, 0.01))\\n b1f = tf.get_variable('b1l',\\n shape=[len(leader_label_name)],\\n dtype=tf.float32,\\n initializer=tf.zeros_initializer())\\n\\n act1_f = tf.nn.bias_add(tf.matmul(x, w1f), b1f)\\n\\n if mode == tf.estimator.ModeKeys.PREDICT:\\n return model.make_spec(mode=mode, predictions=act1_f)\\n\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n gact1_f = model.send('act1_f', act1_f, require_grad=True)\\n elif mode == tf.estimator.ModeKeys.EVAL:\\n model.send('act1_f', act1_f, require_grad=False)\\n\\n #acc = model.recv('acc', tf.float32, require_grad=False)\\n auc = model.recv('auc', tf.float32, require_grad=False)\\n loss = model.recv('loss', tf.float32, require_grad=False)\\n logging_hook = tf.train.LoggingTensorHook({\\n 'auc': auc, 'loss': loss,\\n }, every_n_iter=10)\\n step_metric_hook = StepLossAucMetricsHook(loss_tensor=loss, auc_tensor=auc)\\n\\n global_step = tf.train.get_or_create_global_step()\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n optimizer = tf.train.GradientDescentOptimizer(0.1)\\n train_op = model.minimize(optimizer,\\n act1_f,\\n grad_loss=gact1_f,\\n global_step=global_step)\\n return model.make_spec(mode,\\n loss=tf.math.reduce_mean(act1_f),\\n train_op=train_op,\\n training_hooks=[logging_hook, step_metric_hook])\\n if mode == tf.estimator.ModeKeys.EVAL:\\n fake_loss = tf.reduce_mean(act1_f)\\n return model.make_spec(mode=mode, loss=fake_loss,\\n evaluation_hooks=[logging_hook, step_metric_hook])\\n\\n\\nif __name__ == '__main__':\\n logging.basicConfig(level=logging.INFO)\\n flt.trainer_worker.train(ROLE, args, input_fn, model_fn,\\n serving_input_receiver_fn)\\n\",\"follower/config.py\":\"leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\\nleader_label_name = ['label']\\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\\n\",\"leader/config.py\":\"leader_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12']\\nleader_label_name = ['label']\\nfollower_feature_names = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\\n\"}",
+ "value_type": "CODE",
+ "widget_schema": "{\"component\":\"Code\",\"required\":true}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"activeDeadlineSeconds\": 1200, \n \"fedReplicaSpecs\": {\n \"Master\": {\n \"backoffLimit\": 1,\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(int(self.variables.epoch_num))\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(None)\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(None)\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": workflow.jobs['data-join-job'].name\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": \"\"\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": \"\"\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": self.variables.code_tar\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": \"\" and project.variables.storage_root_path + \"/job_output/\" + \"\" + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": \"\"\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n\n }\n },\n \"replicas\": int(1)\n },\n \"PS\": {\n \"backoffLimit\": 1,\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": 
list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(1)\n },\n \"Worker\": {\n \"backoffLimit\": 6,\n \"mustSuccess\": True,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": \"train\"\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(1)\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": \"\"\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": self.variables.code_tar\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(1000)\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(None)\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(None)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\":[\"/bin/bash\",\"-c\"],\n \"args\": [\"export WORKER_RANK=$$INDEX && export PEER_ADDR=$$SERVICE_ID && /app/deploy/scripts/trainer/run_trainer_worker.sh\"],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(1)\n }\n }\n }\n}"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "data-join-job"
+ },
+ {
+ "source": "psi-data-join-job"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "TREE_MODEL_TRAINING",
+ "name": "tree-train",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_depth",
+ "tag": "",
+ "typed_value": "3",
+ "value": "3",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_iters",
+ "tag": "",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_ext",
+ "tag": "",
+ "typed_value": ".data",
+ "value": ".data",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_type",
+ "tag": "",
+ "typed_value": "tfrecord",
+ "value": "tfrecord",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(workflow.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(workflow.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": \"train\"\n },\n {\n \"name\": \"LOSS_TYPE\",\n \"value\": \"logistic\"\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": str(workflow.jobs['psi-data-join-job'].name)\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"VALIDATION_DATA_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"NO_DATA\",\n \"value\": str(False)\n },\n {\n \"name\": \"FILE_EXT\",\n \"value\": str(self.variables.file_ext)\n },\n {\n \"name\": \"FILE_TYPE\",\n \"value\": str(self.variables.file_type)\n },\n {\n \"name\": \"LOAD_MODEL_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_MODEL_NAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(1)\n },\n {\n \"name\": \"LEARNING_RATE\",\n \"value\": str(0.3)\n },\n {\n \"name\": \"MAX_ITERS\",\n \"value\": str(int(self.variables.max_iters))\n },\n {\n \"name\": \"MAX_DEPTH\",\n \"value\": str(int(self.variables.max_depth))\n },\n {\n \"name\": \"MAX_BINS\",\n \"value\": str(33)\n },\n {\n \"name\": \"L2_REGULARIZATION\",\n \"value\": str(1.0)\n },\n {\n \"name\": \"NUM_PARALLEL\",\n \"value\": str(1)\n },\n {\n \"name\": \"VERIFY_EXAMPLE_IDS\",\n \"value\": str(False)\n },\n {\n \"name\": \"IGNORE_FIELDS\",\n \"value\": \"\"\n },\n {\n \"name\": \"CAT_FIELDS\",\n \"value\": \"\"\n },\n {\n \"name\": \"LABEL_FIELD\",\n \"value\": \"label\"\n },\n {\n \"name\": \"SEND_SCORES_TO_FOLLOWER\",\n \"value\": str(False)\n },\n {\n \"name\": \"SEND_METRICS_TO_FOLLOWER\",\n \"value\": str(False)\n },\n {\n \"name\": \"ENABLE_PACKING\",\n \"value\": str(True)\n },\n {\n \"name\": \"ES_BATCH_SIZE\",\n \"value\": str(10)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": 
\"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_tree_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"4Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"4Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": 1\n }\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "image",
+ "tag": "",
+ "typed_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "num_partitions",
+ "tag": "",
+ "typed_value": "1",
+ "value": "1",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "role",
+ "tag": "",
+ "typed_value": "Leader",
+ "value": "Leader",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"]}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "input_base_dir",
+ "tag": "",
+ "typed_value": "/app/deploy/integrated_test/credit_default",
+ "value": "/app/deploy/integrated_test/credit_default",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_wildcard",
+ "tag": "",
+ "typed_value": "*guest.csv",
+ "value": "*guest.csv",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "input_data_format",
+ "tag": "",
+ "typed_value": "CSV_DICT",
+ "value": "CSV_DICT",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"CSV_DICT\",\"TF_RECORD\"]}"
+ }
+ ]
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "data-join-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": ${Slot_batch_mode}\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(${Slot_start_time})\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(${Slot_end_time})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n ${Slot_volume_mounts}\n ,\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n ${Slot_volumes}\n\n }\n },\n \"pair\": true,\n \"replicas\": ${Slot_master_replicas}\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n 
\"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(${Slot_data_block_dump_interval})\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(${Slot_data_block_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(${Slot_example_id_dump_interval})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(${Slot_example_id_dump_threshold})\n },\n {\n \"name\": \"MIN_MATCHING_WINDOW\",\n \"value\": str(${Slot_min_matching_window})\n },\n {\n \"name\": \"MAX_MATCHING_WINDOW\",\n \"value\": str(${Slot_max_matching_window})\n },\n {\n \"name\": \"RAW_DATA_ITER\",\n \"value\": ${Slot_raw_data_iter}\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(${Slot_enable_negative_example_generator})\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n ${Slot_volume_mounts}\n ,\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n ${Slot_volumes}\n\n }\n },\n \"pair\": true,\n \"replicas\": ${Slot_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_mode": {
+ "default": "",
+ "default_value": "--batch_mode",
+ "help": "如果为空则为常驻求交",
+ "label": "是否为批处理模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次data block,小于0则无此限制",
+ "label": "数据dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_block_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多多少个样本就dump为一个data block,小于等于0则无此限制",
+ "label": "数据dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_enable_negative_example_generator": {
+ "default": "",
+ "default_value": false,
+ "help": "建议不修改,是否开启负采样,当follower求交时遇到无法匹配上的leader的example id,会以negative_sampling_rate为概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_end_time": {
+ "default": "",
+ "default_value": 999999999999.0,
+ "help": "建议不修改,使用自这个时间以前的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据末尾时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_max_matching_window": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,the max matching window for example join. <=0 means window size is infinite",
+ "label": "最大匹配滑窗",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_min_matching_window": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "建议不修改,the min matching window for example join ,<=0 means window size is infinite",
+ "label": "最小匹配滑窗",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_negative_sampling_rate": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,负采样比例,当follower求交时遇到无法匹配上的leader的example id,会以此概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "建议修改,求交后数据分区的数量,建议和raw_data一致",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_iter": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "raw_data文件类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_raw_data_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,原始数据的发布地址,根据参数内容在portal_publish_dir地址下寻找",
+ "label": "raw_data名字",
+ "reference": "workflow.jobs['raw-data-job-streaming'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_start_time": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,使用自这个时间起的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据起始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "nn-train": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(${Slot_epoch_num})\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(${Slot_start_date})\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(${Slot_end_date})\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": ${Slot_online_training}\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": ${Slot_checkpoint_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": ${Slot_load_checkpoint_filename}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": ${Slot_load_checkpoint_filename_with_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": ${Slot_load_checkpoint_from_job} and ${Slot_storage_root_path} + \"/job_output/\" + ${Slot_load_checkpoint_from_job} + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": ${Slot_export_path}\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_master_replicas})\n },\n \"PS\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n 
\"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + ${Slot_ps_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_ps_replicas})\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(${Slot_save_checkpoint_steps})\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(${Slot_save_checkpoint_secs})\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(${Slot_summary_save_steps})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_trainer_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_worker_replicas})\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_checkpoint_path": {
+ "default": "",
+ "default_value": "",
+ "help": "不建议修改,checkpoint输出路径,建议为空,会默认使用{storage_root_path}/job_output/{job_name}/checkpoints,强烈建议保持空值",
+ "label": "CHECKPOINT_PATH",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_key": {
+ "default": "",
+ "default_value": "",
+ "help": "代码tar包地址,如果为空则使用code tar",
+ "label": "模型代码路径",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_tar": {
+ "default": "",
+ "default_value": "",
+ "help": "代码包,variable中请使用代码类型",
+ "label": "代码",
+ "reference": "self.variables.code_tar",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,求交任务的名字",
+ "label": "数据源",
+ "reference": "workflow.jobs['data-join-job'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_end_date": {
+ "default": "",
+ "default_value": null,
+ "help": "training data end date",
+ "label": "结束时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_epoch_num": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "number of epoch for training, not support in online training",
+ "label": "epoch数量",
+ "reference": "self.variables.epoch_num",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_export_path": {
+ "default": "",
+ "default_value": "",
+ "help": "使用默认空值,将把models保存到$OUTPUT_BASE_DIR/exported_models 路径下。",
+ "label": "EXPORT_PATH",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模板不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_load_checkpoint_filename": {
+ "default": "",
+ "default_value": "",
+ "help": "加载checkpoint_path下的相对路径的checkpoint, 默认会加载checkpoint_path下的latest checkpoint",
+ "label": "LOAD_CHECKPOINT_FILENAME",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_filename_with_path": {
+ "default": "",
+ "default_value": "",
+ "help": "加载绝对路径下的checkpoint,需要细致到文件名",
+ "label": "从绝对路径加载checkpoint",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_from_job": {
+ "default": "",
+ "default_value": "",
+ "help": "指定任务名job_output下的latest checkpoint",
+ "label": "以任务名加载checkpoint",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_mode": {
+ "default": "",
+ "default_value": "train",
+ "help": "choices:['train','eval'] 训练还是验证",
+ "label": "模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_online_training": {
+ "default": "",
+ "default_value": "",
+ "help": "['','--online_training'] 否 是,the train master run for online training",
+ "label": "是否在线训练",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "PS的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,ps pod额外的环境变量",
+ "label": "PS额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_ps_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "PS的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的PS Pods数量",
+ "label": "PS的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_save_checkpoint_secs": {
+ "default": "",
+ "default_value": null,
+ "help": "int,Number of secs between checkpoints.",
+ "label": "SAVE_CHECKPOINT_SECS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_save_checkpoint_steps": {
+ "default": "",
+ "default_value": 1000.0,
+ "help": "int, Number of steps between checkpoints.",
+ "label": "SAVE_CHECKPOINT_STEPS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_sparse_estimator": {
+ "default": "",
+ "default_value": false,
+ "help": "bool,default False Whether using sparse estimator.",
+ "label": "SPARSE_ESTIMATOR",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_start_date": {
+ "default": "",
+ "default_value": null,
+ "help": "training data start date",
+ "label": "开始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_suffle_data_block": {
+ "default": "",
+ "default_value": "",
+ "help": "['','--shuffle_data_block'] 否 是,shuffle the data block or not",
+ "label": "是否shuffle数据块",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_summary_save_steps": {
+ "default": "",
+ "default_value": null,
+ "help": "int, Number of steps to save summary files.",
+ "label": "SUMMARY_SAVE_STEPS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_verbosity": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "int, Logging level",
+ "label": "日志等级",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Worker Pods数量",
+ "label": "Worker的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ }
+ },
+ "variables": []
+ },
+ "psi-data-join-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(${Slot_start_time})\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(${Slot_end_time})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": ${Slot_batch_mode}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + 
\"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": ${Slot_rsa_key_pem}\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": ${Slot_rsa_key_path}\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": ${Slot_rsa_private_key_path}\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": ${Slot_kms_key_name}\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": ${Slot_kms_client}\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": ${Slot_psi_raw_data_iter}\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": ${Slot_data_block_builder}\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": ${Slot_psi_output_builder}\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(${Slot_data_block_dump_interval})\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(${Slot_data_block_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(${Slot_example_id_dump_interval})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(${Slot_example_id_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(${Slot_psi_read_ahead_size})\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(${Slot_run_merger_read_ahead_buffer})\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(${Slot_enable_negative_example_generator})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_partition_num})\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_mode": {
+ "default": "",
+ "default_value": "--batch_mode",
+ "help": "如果为空则为常驻求交",
+ "label": "是否为批处理模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "data block output数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次data block,小于0则无此限制",
+ "label": "数据dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_block_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多多少个样本就dump为一个data block,小于等于0则无此限制",
+ "label": "数据dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_enable_negative_example_generator": {
+ "default": "",
+ "default_value": false,
+ "help": "建议不修改,是否开启负采样,当follower求交时遇到无法匹配上的leader的example id,会以negative_sampling_rate为概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_end_time": {
+ "default": "",
+ "default_value": 999999999999.0,
+ "help": "建议不修改,使用自这个时间以前的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据末尾时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_kms_client": {
+ "default": "",
+ "default_value": "data.aml.fl",
+ "help": "kms client",
+ "label": "kms client",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_kms_key_name": {
+ "default": "",
+ "default_value": "",
+ "help": "kms中的密钥名称,站内镜像需使用KMS",
+ "label": "密钥名称",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_negative_sampling_rate": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,负采样比例,当follower求交时遇到无法匹配上的leader的example id,会以此概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "建议修改,求交后数据分区的数量,建议和raw_data一致",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_psi_output_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "PSI output数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_psi_raw_data_iter": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "raw data数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_psi_read_ahead_size": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, the read ahead size for raw data",
+ "label": "psi_read_ahead_size",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,原始数据的发布地址,根据参数内容在portal_publish_dir地址下寻找",
+ "label": "raw_data名字",
+ "reference": "workflow.jobs['raw-data-job-psi'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_path": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA公钥或私钥的地址,在无RSA_KEY_PEM时必填",
+ "label": "RSA钥匙地址",
+ "reference": "self.variables.rsa_key_path",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_pem": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA公钥,follower需提供",
+ "label": "RSA公钥",
+ "reference": "self.variables.rsa_key_pem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_private_key_path": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA私钥的地址, leader必填",
+ "label": "RSA私钥地址",
+ "reference": "self.variables.rsa_private_key_path",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_run_merger_read_ahead_buffer": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, sort run merger read ahead buffer",
+ "label": "run_merger_read_ahead_buffer",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_start_time": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,使用自这个时间起的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据起始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "raw-data-job-psi": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": 
\"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "PSI",
+ "help": "运行过一次后修改无效!! the type of data portal type ,choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+ "help": "文件名称的通配符, 将会读取input_base_dir下所以满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "workflow.variables.file_wildcard",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "workflow.variables.input_base_dir",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "workflow.variables.input_data_format",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\"Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "raw-data-job-streaming": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": 
\"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "Streaming",
+ "help": "运行过一次后修改无效!! the type of data portal type ,choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+ "help": "文件名称的通配符, 将会读取input_base_dir下所以满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "workflow.variables.file_wildcard",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "workflow.variables.input_base_dir",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "workflow.variables.input_data_format",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\"Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "tree-train": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"LOSS_TYPE\",\n \"value\": ${Slot_loss_type}\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": ${Slot_data_path}\n },\n {\n \"name\": \"VALIDATION_DATA_PATH\",\n \"value\": ${Slot_validation_data_path}\n },\n {\n \"name\": \"NO_DATA\",\n \"value\": str(${Slot_no_data})\n },\n {\n \"name\": \"FILE_EXT\",\n \"value\": ${Slot_file_ext}\n },\n {\n \"name\": \"FILE_TYPE\",\n \"value\": ${Slot_file_type}\n },\n {\n \"name\": \"LOAD_MODEL_PATH\",\n \"value\": ${Slot_load_model_path}\n },\n {\n \"name\": \"LOAD_MODEL_NAME\",\n \"value\": ${Slot_load_model_name}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"LEARNING_RATE\",\n \"value\": str(${Slot_learning_rate})\n },\n {\n \"name\": \"MAX_ITERS\",\n \"value\": str(${Slot_max_iters})\n },\n {\n \"name\": \"MAX_DEPTH\",\n \"value\": str(${Slot_max_depth})\n },\n {\n \"name\": \"MAX_BINS\",\n \"value\": str(${Slot_max_bins})\n },\n {\n \"name\": \"L2_REGULARIZATION\",\n \"value\": str(${Slot_l2_regularization})\n },\n {\n \"name\": \"NUM_PARALLEL\",\n \"value\": str(${Slot_num_parallel})\n },\n {\n \"name\": \"VERIFY_EXAMPLE_IDS\",\n \"value\": str(${Slot_verify_example_ids})\n },\n {\n \"name\": \"IGNORE_FIELDS\",\n \"value\": ${Slot_ignore_fields}\n },\n {\n \"name\": \"CAT_FIELDS\",\n \"value\": ${Slot_cat_fields}\n },\n {\n \"name\": \"LABEL_FIELD\",\n \"value\": ${Slot_label_field}\n },\n {\n \"name\": \"SEND_SCORES_TO_FOLLOWER\",\n \"value\": str(${Slot_send_scores_to_follower})\n },\n {\n \"name\": \"SEND_METRICS_TO_FOLLOWER\",\n \"value\": str(${Slot_send_metrics_to_follower})\n },\n {\n \"name\": \"ENABLE_PACKING\",\n \"value\": str(${Slot_enable_packing})\n },\n {\n \"name\": \"ES_BATCH_SIZE\",\n \"value\": str(${Slot_es_batch_size})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n 
\"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_tree_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_mem}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_mem}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": 1\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_cat_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "类别类型特征,特征的值需要是非负整数。以逗号分隔如:alive,country,sex",
+ "label": "类别类型特征",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_path": {
+ "default": "",
+ "default_value": "",
+ "help": "数据存放位置",
+ "label": "数据存放位置",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "default": "",
+ "default_value": "",
+ "help": "求交数据集名称",
+ "label": "求交数据集名称",
+ "reference": "workflow.jobs['psi-data-join-job'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_enable_packing": {
+ "default": "",
+ "default_value": true,
+ "help": "是否开启优化",
+ "label": "是否开启优化",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_es_batch_size": {
+ "default": "",
+ "default_value": 10.0,
+ "help": "ES_BATCH_SIZE",
+ "label": "ES_BATCH_SIZE",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_file_ext": {
+ "default": "",
+ "default_value": ".csv",
+ "help": "文件后缀",
+ "label": "文件后缀",
+ "reference": "self.variables.file_ext",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_file_type": {
+ "default": "",
+ "default_value": "csv",
+ "help": "文件类型,csv或tfrecord",
+ "label": "文件类型,csv或tfrecord",
+ "reference": "self.variables.file_type",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_ignore_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "以逗号分隔如:name,age,sex",
+ "label": "不入模的特征",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_l2_regularization": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "L2惩罚系数",
+ "label": "L2惩罚系数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_label_field": {
+ "default": "",
+ "default_value": "label",
+ "help": "label特征名",
+ "label": "label特征名",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_learning_rate": {
+ "default": "",
+ "default_value": 0.3,
+ "help": "学习率",
+ "label": "学习率",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_load_model_name": {
+ "default": "",
+ "default_value": "",
+ "help": "按任务名称加载模型,{STORAGE_ROOT_PATH}/job_output/{LOAD_MODEL_NAME}/exported_models",
+ "label": "模型任务名称",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_model_path": {
+ "default": "",
+ "default_value": "",
+ "help": "模型文件地址",
+ "label": "模型文件地址",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_loss_type": {
+ "default": "",
+ "default_value": "logistic",
+ "help": "损失函数类型,logistic或mse,默认logistic",
+ "label": "损失函数类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_max_bins": {
+ "default": "",
+ "default_value": 33.0,
+ "help": "最大分箱数",
+ "label": "最大分箱数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_max_depth": {
+ "default": "",
+ "default_value": 3.0,
+ "help": "最大深度",
+ "label": "最大深度",
+ "reference": "self.variables.max_depth",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_max_iters": {
+ "default": "",
+ "default_value": 5.0,
+ "help": "树的数量",
+ "label": "迭代数",
+ "reference": "self.variables.max_iters",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_mode": {
+ "default": "",
+ "default_value": "train",
+ "help": "任务类型,train或eval",
+ "label": "任务类型,train或eval",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_no_data": {
+ "default": "",
+ "default_value": false,
+ "help": "Leader是否没数据",
+ "label": "Leader是否没数据",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_num_parallel": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "进程数量",
+ "label": "进程数量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_send_metrics_to_follower": {
+ "default": "",
+ "default_value": false,
+ "help": "是否发送指标到follower",
+ "label": "是否发送指标到follower",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_send_scores_to_follower": {
+ "default": "",
+ "default_value": false,
+ "help": "是否发送结果到follower",
+ "label": "是否发送结果到follower",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_validation_data_path": {
+ "default": "",
+ "default_value": "",
+ "help": "验证数据集地址",
+ "label": "验证数据集地址",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_verbosity": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "日志输出等级",
+ "label": "日志输出等级",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_verify_example_ids": {
+ "default": "",
+ "default_value": false,
+ "help": "是否检查example_id对齐 If set to true, the first column of the data will be treated as example ids that must match between leader and follower",
+ "label": "是否检查example_id对齐",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "所需CPU",
+ "label": "所需CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_mem": {
+ "default": "",
+ "default_value": "4Gi",
+ "help": "所需内存",
+ "label": "所需内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ }
+ }
+ },
+ "group_alias": "e2e-test",
+ "name": "e2e-fed-right"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-local.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-local.json
new file mode 100644
index 000000000..fa5294234
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-local.json
@@ -0,0 +1,285 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "e2e-test",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": true,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": \"Streaming\"\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(4)\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": \"/app/deploy/integrated_test/tfrecord_raw_data\"\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": \"*.rd\"\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"--check_success_tag\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(1024)\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": 
\"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 4\n }\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "image",
+ "typed_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ }
+ ]
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "raw-data-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": 
\"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "--check_success_tag",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "Streaming",
+ "help": "运行过一次后修改无效!! the type of data portal type ,choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+ "help": "文件名称的通配符, 将会读取input_base_dir下所以满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\"Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ }
+ }
+ },
+ "group_alias": "e2e-test",
+ "name": "e2e-local"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-sparse-estimator-test-right.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-sparse-estimator-test-right.json
new file mode 100644
index 000000000..742d28904
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/e2e-sparse-estimator-test-right.json
@@ -0,0 +1,1127 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "e2e-test",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": true,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job-psi",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": \"PSI\"\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": str(workflow.variables.input_base_dir)\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(workflow.variables.file_wildcard)\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(1024)\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": str(workflow.variables.input_data_format)\n 
},\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "raw-data-job-psi"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "PSI_DATA_JOIN",
+ "name": "psi-data-join-job",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_pem",
+ "tag": "",
+ "typed_value": "-----BEGIN RSA PUBLIC KEY-----\nMIGJAoGBAMZYpBzYDnROmrqC8LhDXhgW13E/JuTUHkHKsGwPScnp5TAueqo53ayu\nYzSlLrI+yQp206Kb/C+w/VdWJcLLAjAUBGqfZvCnsmpfOMt+s3JrNH24RCg282m/\nnIdpoVqb7SEDFlJPq3s0g/oZ5v0c74Yy5J/DuuaWcuU7URuYRbbnAgMBAAE=\n-----END RSA PUBLIC KEY-----\n",
+ "value": "-----BEGIN RSA PUBLIC KEY-----\nMIGJAoGBAMZYpBzYDnROmrqC8LhDXhgW13E/JuTUHkHKsGwPScnp5TAueqo53ayu\nYzSlLrI+yQp206Kb/C+w/VdWJcLLAjAUBGqfZvCnsmpfOMt+s3JrNH24RCg282m/\nnIdpoVqb7SEDFlJPq3s0g/oZ5v0c74Yy5J/DuuaWcuU7URuYRbbnAgMBAAE=\n-----END RSA PUBLIC KEY-----\n",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"TextArea\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_private_key_path",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_path",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_cpu",
+ "tag": "",
+ "typed_value": "2000m",
+ "value": "2000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_mem",
+ "tag": "",
+ "typed_value": "3Gi",
+ "value": "3Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(workflow.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(workflow.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(0)\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(999999999999)\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-psi'].name)\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": \"2000m\",\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": 
\"--batch_mode\"\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job-psi'].name)\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": str(self.variables.rsa_key_pem)\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": str(self.variables.rsa_key_path)\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": str(self.variables.rsa_private_key_path)\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": \"data.aml.fl\"\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(None)\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(None)\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(False)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": int(int(workflow.variables.num_partitions))\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "psi-data-join-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": true,
+ "job_type": "NN_MODEL_TRANINING",
+ "name": "nn-train-job",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "code_tar",
+ "tag": "",
+ "typed_value": {
+ "follower/main.py": "# encoding=utf8\nimport os\nimport logging\nimport datetime\n\nimport tensorflow.compat.v1 as tf\n\nimport fedlearner.trainer as flt\n# from byted_deepinsight import DeepInsight2Hook\n\nROLE = 'follower'\n\nparser = flt.trainer_worker.create_argument_parser()\nparser.add_argument('--batch-size', type=int, default=16,\n help='Training batch size.')\nparser.add_argument('--clean-model', type=bool, default=False,\n help='clean checkpoint and saved_model')\nargs = parser.parse_args()\n\n\ndef apply_clean():\n if args.worker_rank == 0 and args.clean_model and tf.io.gfile.exists(args.checkpoint_path):\n tf.logging.info(\"--clean_model flag set. Removing existing checkpoint_path dir:\"\n \" {}\".format(args.checkpoint_path))\n tf.io.gfile.rmtree(args.checkpoint_path)\n\n if args.worker_rank == 0 and args.clean_model and args.export_path and tf.io.gfile.exists(args.export_path):\n tf.logging.info(\"--clean_model flag set. Removing existing savedmodel dir:\"\n \" {}\".format(args.export_path))\n tf.io.gfile.rmtree(args.export_path)\n\n\ndef input_fn(bridge, trainer_master=None):\n dataset = flt.data.DataBlockLoader(\n args.batch_size, ROLE, bridge, trainer_master).make_dataset()\n \n def parse_fn(example):\n feature_map = {}\n feature_map[\"example_id\"] = tf.FixedLenFeature([], tf.string)\n feature_map[\"label\"] = tf.FixedLenFeature([], tf.int64)\n features = tf.parse_example(example, features=feature_map)\n labels = {'label': features.pop('label')}\n return features, labels\n\n dataset = dataset.map(map_func=parse_fn,\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\n return dataset\n\n\ndef raw_serving_input_receiver_fn():\n features = {}\n features['logits'] = tf.placeholder(dtype=tf.float32, name='logits')\n return tf.estimator.export.build_raw_serving_input_receiver_fn(features)()\n\n\ndef model_fn(model, features, labels, mode):\n global_step = tf.train.get_or_create_global_step()\n \n if mode == tf.estimator.ModeKeys.TRAIN:\n logits = model.recv('logits', tf.float32, require_grad=True)\n else:\n logits = features['logits']\n\n if mode == tf.estimator.ModeKeys.TRAIN:\n y = tf.dtypes.cast(labels['label'], tf.float32)\n logits = tf.reshape(logits, y.shape)\n loss = tf.nn.sigmoid_cross_entropy_with_logits(\n labels=y, logits=logits)\n loss = tf.math.reduce_mean(loss)\n\n # cala auc\n pred = tf.math.sigmoid(logits)\n _, auc = tf.metrics.auc(labels=y, predictions=pred)\n\n logging_hook = tf.train.LoggingTensorHook(\n {\"loss\": loss, \"auc\": auc}, every_n_iter=10)\n\n # send auc back to leader\n model.send('auc', auc, require_grad=False)\n model.send('loss', loss, require_grad=False)\n \n ## visulization with tensorboard\n # current_time = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n # train_log_dir = os.path.join(args.tensorboard_log, current_time)\n # loss_op = tf.summary.scalar('train_loss', loss)\n # auc_op = tf.summary.scalar('train_auc', auc)\n # summary_hook = tf.train.SummarySaverHook(\n # save_steps=5,\n # output_dir=train_log_dir,\n # summary_op=[loss_op, auc_op])\n \n ## visulization with deepinsight\n # uid_tensor = tf.reshape(features['uid'], shape=[-1])\n # req_time_tensor = tf.reshape(features['req_time'], shape=[-1])\n # score_tensor = tf.reshape(pred, shape=[args.batch_size])\n # label_tensor = tf.reshape(y, shape=[-1])\n # # logging.info(\"==> uid tensor : %s, req_time_tensor: %s, score_tensor: %s, label_tensor: %s\" % (uid_tensor, req_time_tensor, score_tensor, label_tensor))\n # deep_insight_hook = DeepInsight2Hook(uid_tensor, 
req_time_tensor, score_tensor, label_tensor)\n \n train_hooks = [logging_hook] #, summary_hook, deep_insight_hook]\n\n # optimizer = tf.train.GradientDescentOptimizer(0.1)\n # optimizer = tf.train.AdagradOptimizer(0.1)\n # optimizer = tf.train.AdamOptimizer()\n optimizer = tf.train.FtrlOptimizer(learning_rate=0.16921544485102483, \n l1_regularization_strength=1e-05, l2_regularization_strength=0.0005945795938393141,\n initial_accumulator_value=0.44352,\n learning_rate_power=-0.59496)\n train_op = model.minimize(optimizer, loss, global_step=global_step)\n return model.make_spec(mode, loss=loss, train_op=train_op,\n training_hooks=train_hooks)\n\n if mode == tf.estimator.ModeKeys.PREDICT:\n return model.make_spec(mode, predictions=logits)\n\n\n\nif __name__ == '__main__':\n logging.basicConfig(\n level=logging.INFO,\n format='%(asctime)-15s [%(filename)s:%(lineno)d] %(levelname)s %(message)s'\n )\n apply_clean()\n try:\n flt.trainer_worker.train(\n ROLE, args, input_fn,\n model_fn, raw_serving_input_receiver_fn)\n except ValueError as err:\n logging.info('cannot save model as there is no parameters: details:{}'.format(err))\n\n\n",
+ "main.py": ""
+ },
+ "value": "{\"main.py\":\"\",\"follower/main.py\":\"# encoding=utf8\\nimport os\\nimport logging\\nimport datetime\\n\\nimport tensorflow.compat.v1 as tf\\n\\nimport fedlearner.trainer as flt\\n# from byted_deepinsight import DeepInsight2Hook\\n\\nROLE = 'follower'\\n\\nparser = flt.trainer_worker.create_argument_parser()\\nparser.add_argument('--batch-size', type=int, default=16,\\n help='Training batch size.')\\nparser.add_argument('--clean-model', type=bool, default=False,\\n help='clean checkpoint and saved_model')\\nargs = parser.parse_args()\\n\\n\\ndef apply_clean():\\n if args.worker_rank == 0 and args.clean_model and tf.io.gfile.exists(args.checkpoint_path):\\n tf.logging.info(\\\"--clean_model flag set. Removing existing checkpoint_path dir:\\\"\\n \\\" {}\\\".format(args.checkpoint_path))\\n tf.io.gfile.rmtree(args.checkpoint_path)\\n\\n if args.worker_rank == 0 and args.clean_model and args.export_path and tf.io.gfile.exists(args.export_path):\\n tf.logging.info(\\\"--clean_model flag set. Removing existing savedmodel dir:\\\"\\n \\\" {}\\\".format(args.export_path))\\n tf.io.gfile.rmtree(args.export_path)\\n\\n\\ndef input_fn(bridge, trainer_master=None):\\n dataset = flt.data.DataBlockLoader(\\n args.batch_size, ROLE, bridge, trainer_master).make_dataset()\\n \\n def parse_fn(example):\\n feature_map = {}\\n feature_map[\\\"example_id\\\"] = tf.FixedLenFeature([], tf.string)\\n feature_map[\\\"label\\\"] = tf.FixedLenFeature([], tf.int64)\\n features = tf.parse_example(example, features=feature_map)\\n labels = {'label': features.pop('label')}\\n return features, labels\\n\\n dataset = dataset.map(map_func=parse_fn,\\n num_parallel_calls=tf.data.experimental.AUTOTUNE)\\n return dataset\\n\\n\\ndef raw_serving_input_receiver_fn():\\n features = {}\\n features['logits'] = tf.placeholder(dtype=tf.float32, name='logits')\\n return tf.estimator.export.build_raw_serving_input_receiver_fn(features)()\\n\\n\\ndef model_fn(model, features, labels, mode):\\n global_step = tf.train.get_or_create_global_step()\\n \\n if mode == tf.estimator.ModeKeys.TRAIN:\\n logits = model.recv('logits', tf.float32, require_grad=True)\\n else:\\n logits = features['logits']\\n\\n if mode == tf.estimator.ModeKeys.TRAIN:\\n y = tf.dtypes.cast(labels['label'], tf.float32)\\n logits = tf.reshape(logits, y.shape)\\n loss = tf.nn.sigmoid_cross_entropy_with_logits(\\n labels=y, logits=logits)\\n loss = tf.math.reduce_mean(loss)\\n\\n # cala auc\\n pred = tf.math.sigmoid(logits)\\n _, auc = tf.metrics.auc(labels=y, predictions=pred)\\n\\n logging_hook = tf.train.LoggingTensorHook(\\n {\\\"loss\\\": loss, \\\"auc\\\": auc}, every_n_iter=10)\\n\\n # send auc back to leader\\n model.send('auc', auc, require_grad=False)\\n model.send('loss', loss, require_grad=False)\\n \\n ## visulization with tensorboard\\n # current_time = datetime.datetime.now().strftime(\\\"%Y%m%d_%H%M%S\\\")\\n # train_log_dir = os.path.join(args.tensorboard_log, current_time)\\n # loss_op = tf.summary.scalar('train_loss', loss)\\n # auc_op = tf.summary.scalar('train_auc', auc)\\n # summary_hook = tf.train.SummarySaverHook(\\n # save_steps=5,\\n # output_dir=train_log_dir,\\n # summary_op=[loss_op, auc_op])\\n \\n ## visulization with deepinsight\\n # uid_tensor = tf.reshape(features['uid'], shape=[-1])\\n # req_time_tensor = tf.reshape(features['req_time'], shape=[-1])\\n # score_tensor = tf.reshape(pred, shape=[args.batch_size])\\n # label_tensor = tf.reshape(y, shape=[-1])\\n # # logging.info(\\\"==> uid tensor : %s, req_time_tensor: %s, 
score_tensor: %s, label_tensor: %s\\\" % (uid_tensor, req_time_tensor, score_tensor, label_tensor))\\n # deep_insight_hook = DeepInsight2Hook(uid_tensor, req_time_tensor, score_tensor, label_tensor)\\n \\n train_hooks = [logging_hook] #, summary_hook, deep_insight_hook]\\n\\n # optimizer = tf.train.GradientDescentOptimizer(0.1)\\n # optimizer = tf.train.AdagradOptimizer(0.1)\\n # optimizer = tf.train.AdamOptimizer()\\n optimizer = tf.train.FtrlOptimizer(learning_rate=0.16921544485102483, \\n l1_regularization_strength=1e-05, l2_regularization_strength=0.0005945795938393141,\\n initial_accumulator_value=0.44352,\\n learning_rate_power=-0.59496)\\n train_op = model.minimize(optimizer, loss, global_step=global_step)\\n return model.make_spec(mode, loss=loss, train_op=train_op,\\n training_hooks=train_hooks)\\n\\n if mode == tf.estimator.ModeKeys.PREDICT:\\n return model.make_spec(mode, predictions=logits)\\n\\n\\n\\nif __name__ == '__main__':\\n logging.basicConfig(\\n level=logging.INFO,\\n format='%(asctime)-15s [%(filename)s:%(lineno)d] %(levelname)s %(message)s'\\n )\\n apply_clean()\\n try:\\n flt.trainer_worker.train(\\n ROLE, args, input_fn,\\n model_fn, raw_serving_input_receiver_fn)\\n except ValueError as err:\\n logging.info('cannot save model as there is no parameters: details:{}'.format(err))\\n\\n\\n\"}",
+ "value_type": "CODE",
+ "widget_schema": "{\"component\":\"Code\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "epoch_num",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "pod_cpu",
+ "tag": "",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "pod_mem",
+ "tag": "",
+ "typed_value": "8Gi",
+ "value": "8Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_pod_num",
+ "tag": "",
+ "typed_value": "1",
+ "value": "1",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "sparse_estimator",
+ "tag": "",
+ "typed_value": "false",
+ "value": "false",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"true\",\"false\"]}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"activeDeadlineSeconds\": 1200, \n \"fedReplicaSpecs\": {\n \"Master\": {\n \"backoffLimit\": 1,\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(int(self.variables.epoch_num))\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(None)\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(None)\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": workflow.jobs['psi-data-join-job'].name\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": \"\"\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(bool(self.variables.sparse_estimator))\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": \"\"\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": self.variables.code_tar\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": \"\" and project.variables.storage_root_path + \"/job_output/\" + \"\" + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": \"\"\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": self.variables.pod_cpu,\n \"memory\": self.variables.pod_mem\n },\n \"requests\": {\n \"cpu\": self.variables.pod_cpu,\n \"memory\": self.variables.pod_mem\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n\n }\n },\n \"replicas\": int(1)\n },\n \"PS\": {\n \"backoffLimit\": 1,\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + [],\n \"imagePullPolicy\": 
\"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": self.variables.pod_cpu,\n \"memory\": self.variables.pod_mem\n },\n \"requests\": {\n \"cpu\": self.variables.pod_cpu,\n \"memory\": self.variables.pod_mem\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(1)\n },\n \"Worker\": {\n \"backoffLimit\": 6,\n \"mustSuccess\": True,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": \"train\"\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(1)\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": \"\"\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": self.variables.code_tar\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(1000)\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(None)\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(bool(self.variables.sparse_estimator))\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(None)\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": workflow.variables.image,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\":[\"/bin/bash\",\"-c\"],\n \"args\": [\"export WORKER_RANK=$$INDEX && export PEER_ADDR=$$SERVICE_ID && /app/deploy/scripts/trainer/run_trainer_worker.sh\"],\n \"resources\": {\n \"limits\": {\n \"cpu\": self.variables.pod_cpu,\n \"memory\": \"3Gi\"\n },\n \"requests\": {\n \"cpu\": self.variables.pod_cpu,\n \"memory\": \"3Gi\"\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(int(self.variables.worker_pod_num))\n }\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "image",
+ "tag": "",
+ "typed_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "num_partitions",
+ "tag": "",
+ "typed_value": "4",
+ "value": "4",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "role",
+ "tag": "",
+ "typed_value": "Follower",
+ "value": "Follower",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"]}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "input_base_dir",
+ "tag": "",
+ "typed_value": "/data/upload/test_sparse_estimator",
+ "value": "/data/upload/test_sparse_estimator",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_wildcard",
+ "tag": "",
+ "typed_value": "*part-r-*",
+ "value": "*part-r-*",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "input_data_format",
+ "tag": "",
+ "typed_value": "TF_RECORD",
+ "value": "TF_RECORD",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"CSV_DICT\",\"TF_RECORD\"]}"
+ }
+ ]
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "nn-train-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(${Slot_epoch_num})\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(${Slot_start_date})\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(${Slot_end_date})\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": ${Slot_online_training}\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": ${Slot_checkpoint_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": ${Slot_load_checkpoint_filename}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": ${Slot_load_checkpoint_filename_with_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": ${Slot_load_checkpoint_from_job} and ${Slot_storage_root_path} + \"/job_output/\" + ${Slot_load_checkpoint_from_job} + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": ${Slot_export_path}\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_master_replicas})\n },\n \"PS\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n 
\"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + ${Slot_ps_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_ps_replicas})\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(${Slot_save_checkpoint_steps})\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(${Slot_save_checkpoint_secs})\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(${Slot_summary_save_steps})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_trainer_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_worker_replicas})\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_checkpoint_path": {
+ "default": "",
+ "default_value": "",
+ "help": "不建议修改,checkpoint输出路径,建议为空,会默认使用{storage_root_path}/job_output/{job_name}/checkpoints,强烈建议保持空值",
+ "label": "CHECKPOINT_PATH",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_key": {
+ "default": "",
+ "default_value": "",
+ "help": "代码tar包地址,如果为空则使用code tar",
+ "label": "模型代码路径",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_tar": {
+ "default": "",
+ "default_value": "",
+ "help": "代码包,variable中请使用代码类型",
+ "label": "代码",
+ "reference": "self.variables.code_tar",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,求交任务的名字",
+ "label": "数据源",
+ "reference": "workflow.jobs['psi-data-join-job'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_end_date": {
+ "default": "",
+ "default_value": null,
+ "help": "training data end date",
+ "label": "结束时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_epoch_num": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "number of epoch for training, not support in online training",
+ "label": "epoch数量",
+ "reference": "self.variables.epoch_num",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_export_path": {
+ "default": "",
+ "default_value": "",
+ "help": "使用默认空值,将把models保存到$OUTPUT_BASE_DIR/exported_models 路径下。",
+ "label": "EXPORT_PATH",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模板不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_load_checkpoint_filename": {
+ "default": "",
+ "default_value": "",
+ "help": "加载checkpoint_path下的相对路径的checkpoint, 默认会加载checkpoint_path下的latest checkpoint",
+ "label": "LOAD_CHECKPOINT_FILENAME",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_filename_with_path": {
+ "default": "",
+ "default_value": "",
+ "help": "加载绝对路径下的checkpoint,需要细致到文件名",
+ "label": "从绝对路径加载checkpoint",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_from_job": {
+ "default": "",
+ "default_value": "",
+ "help": "指定任务名job_output下的latest checkpoint",
+ "label": "以任务名加载checkpoint",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "self.variables.pod_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "self.variables.pod_mem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_mode": {
+ "default": "",
+ "default_value": "train",
+ "help": "choices:['train','eval'] 训练还是验证",
+ "label": "模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_online_training": {
+ "default": "",
+ "default_value": "",
+ "help": "['','--online_training'] 否 是,the train master run for online training",
+ "label": "是否在线训练",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "PS的CPU",
+ "reference": "self.variables.pod_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_ps_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,ps pod额外的环境变量",
+ "label": "PS额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_ps_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "PS的内存",
+ "reference": "self.variables.pod_mem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_ps_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的PS Pods数量",
+ "label": "PS的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_save_checkpoint_secs": {
+ "default": "",
+ "default_value": null,
+ "help": "int,Number of secs between checkpoints.",
+ "label": "SAVE_CHECKPOINT_SECS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_save_checkpoint_steps": {
+ "default": "",
+ "default_value": 1000.0,
+ "help": "int, Number of steps between checkpoints.",
+ "label": "SAVE_CHECKPOINT_STEPS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_sparse_estimator": {
+ "default": "",
+ "default_value": false,
+ "help": "bool,default False Whether using sparse estimator.",
+ "label": "SPARSE_ESTIMATOR",
+ "reference": "self.variables.sparse_estimator",
+ "reference_type": "SELF",
+ "value_type": "BOOL"
+ },
+ "Slot_start_date": {
+ "default": "",
+ "default_value": null,
+ "help": "training data start date",
+ "label": "开始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_suffle_data_block": {
+ "default": "",
+ "default_value": "",
+ "help": "['','--shuffle_data_block'] 否 是,shuffle the data block or not",
+ "label": "是否shuffle数据块",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_summary_save_steps": {
+ "default": "",
+ "default_value": null,
+ "help": "int, Number of steps to save summary files.",
+ "label": "SUMMARY_SAVE_STEPS",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_verbosity": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "int, Logging level",
+ "label": "日志等级",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "self.variables.pod_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Worker Pods数量",
+ "label": "Worker的Pod个数",
+ "reference": "self.variables.worker_pod_num",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ }
+ },
+ "variables": []
+ },
+ "psi-data-join-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(${Slot_start_time})\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(${Slot_end_time})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": ${Slot_batch_mode}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + 
\"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": ${Slot_rsa_key_pem}\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": ${Slot_rsa_key_path}\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": ${Slot_rsa_private_key_path}\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": ${Slot_kms_key_name}\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": ${Slot_kms_client}\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": ${Slot_psi_raw_data_iter}\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": ${Slot_data_block_builder}\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": ${Slot_psi_output_builder}\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(${Slot_data_block_dump_interval})\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(${Slot_data_block_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(${Slot_example_id_dump_interval})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(${Slot_example_id_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(${Slot_psi_read_ahead_size})\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(${Slot_run_merger_read_ahead_buffer})\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(${Slot_enable_negative_example_generator})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_partition_num})\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_mode": {
+ "default": "",
+ "default_value": "--batch_mode",
+ "help": "如果为空则为常驻求交",
+ "label": "是否为批处理模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "data block output数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次data block,小于0则无此限制",
+ "label": "数据dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_block_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多多少个样本就dump为一个data block,小于等于0则无此限制",
+ "label": "数据dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_enable_negative_example_generator": {
+ "default": "",
+ "default_value": false,
+ "help": "建议不修改,是否开启负采样,当follower求交时遇到无法匹配上的leader的example id,会以negative_sampling_rate为概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_end_time": {
+ "default": "",
+ "default_value": 999999999999.0,
+ "help": "建议不修改,使用自这个时间以前的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据末尾时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_kms_client": {
+ "default": "",
+ "default_value": "data.aml.fl",
+ "help": "kms client",
+ "label": "kms client",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_kms_key_name": {
+ "default": "",
+ "default_value": "",
+ "help": "kms中的密钥名称,站内镜像需使用KMS",
+ "label": "密钥名称",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_negative_sampling_rate": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,负采样比例,当follower求交时遇到无法匹配上的leader的example id,会以此概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "建议修改,求交后数据分区的数量,建议和raw_data一致",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_psi_output_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "PSI output数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_psi_raw_data_iter": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "raw data数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_psi_read_ahead_size": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, the read ahead size for raw data",
+ "label": "psi_read_ahead_size",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,原始数据的发布地址,根据参数内容在portal_publish_dir地址下寻找",
+ "label": "raw_data名字",
+ "reference": "workflow.jobs['raw-data-job-psi'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "workflow.variables.role",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_path": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA公钥或私钥的地址,在无RSA_KEY_PEM时必填",
+ "label": "RSA钥匙地址",
+ "reference": "self.variables.rsa_key_path",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_pem": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA公钥,follower需提供",
+ "label": "RSA公钥",
+ "reference": "self.variables.rsa_key_pem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_private_key_path": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA私钥的地址, leader必填",
+ "label": "RSA私钥地址",
+ "reference": "self.variables.rsa_private_key_path",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_run_merger_read_ahead_buffer": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, sort run merger read ahead buffer",
+ "label": "run_merger_read_ahead_buffer",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_start_time": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,使用自这个时间起的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据起始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "self.variables.worker_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "self.variables.worker_mem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "raw-data-job-psi": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": 
\"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "PSI",
+ "help": "运行过一次后修改无效!! the type of data portal type ,choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+ "help": "文件名称的通配符, 将会读取input_base_dir下所以满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "workflow.variables.file_wildcard",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模版不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "workflow.variables.input_base_dir",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "workflow.variables.input_data_format",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\"Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ }
+ }
+ },
+ "group_alias": "e2e-test",
+ "name": "e2e-sparse-estimator-test-right"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-alignment-task.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-alignment-task.json
new file mode 100644
index 000000000..faebee076
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-alignment-task.json
@@ -0,0 +1,196 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "sys_preset_data_alignment",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "alignment-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/dataset_alignment.py\",\n \"arguments\": [\n \"--input_dataset_path=\" + workflow.variables.input_dataset_path,\n \"--input_batch_path=\" + workflow.variables.input_batch_path,\n \"--json_schema=\" + workflow.variables.json_schema,\n \"--wildcard=\" + workflow.variables.wildcard,\n \"--output_dataset_path=\" + workflow.variables.output_dataset_path,\n \"--output_batch_path=\" + workflow.variables.output_batch_path,\n \"--output_error_path=\" + workflow.variables.output_dataset_path + \"/errors\",\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": int(workflow.variables.driver_cores),\n \"coreLimit\": workflow.variables.driver_cores_limit,\n \"memory\": workflow.variables.driver_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": int(workflow.variables.initial_executors),\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executors),\n \"maxExecutors\": int(workflow.variables.max_executors),\n \"minExecutors\": int(workflow.variables.min_executors)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "alignment-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n workflow.variables.data_type,\n \"--data_path=\" + workflow.variables.output_dataset_path,\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--buckets_num=\" + workflow.variables.buckets_num,\n \"--thumbnail_path=\" + workflow.variables.thumbnail_path,\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 1,\n \"coreLimit\": \"1200m\",\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list\n },\n \"executor\": {\n \"cores\": 2,\n \"instances\": 2,\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": 2,\n \"maxExecutors\": 64,\n \"minExecutors\": 2,\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "buckets_num",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "thumbnail_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_dataset_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_batch_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_dataset_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_batch_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "data_type",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "tabular",
+ "value": "tabular",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "wildcard",
+ "tag": "INPUT_PATH",
+ "typed_value": "**",
+ "value": "**",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "json_schema",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "driver_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "1",
+ "value": "1",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "driver_cores_limit",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "driver_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "initial_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "128",
+ "value": "128",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "min_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OUTPUT_PATH\",\"hidden\":true}"
+ }
+ ]
+ },
+ "editor_info": {},
+ "name": "sys-preset-alignment-task"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-analyzer.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-analyzer.json
new file mode 100644
index 000000000..2bee220ba
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-analyzer.json
@@ -0,0 +1,148 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "sys_preset_analyzer",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n workflow.variables.data_type,\n \"--data_path=\" + workflow.variables.input_batch_path,\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--buckets_num=\" + workflow.variables.buckets_num,\n \"--thumbnail_path=\" + workflow.variables.thumbnail_path,\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": int(workflow.variables.driver_cores),\n \"coreLimit\": workflow.variables.driver_cores_limit,\n \"memory\": workflow.variables.driver_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": int(workflow.variables.executor_nums),\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executors),\n \"maxExecutors\": int(workflow.variables.max_executors),\n \"minExecutors\": int(workflow.variables.min_executors)\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_batch_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PATH\",\"tooltip\":\"输入batch路径\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "data_type",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "tabular",
+ "value": "tabular",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "buckets_num",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"the number of buckets for hist\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "thumbnail_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"dir path to save the thumbnails\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "driver_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "1",
+ "value": "1",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "driver_cores_limit",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "driver_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "executor_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "executor_nums",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "executor_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "initial_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "max_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "64",
+ "value": "64",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "min_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ }
+ ]
+ },
+ "editor_info": {},
+ "group_alias": "sys_preset_analyzer",
+ "name": "sys-preset-analyzer"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-converter-analyzer.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-converter-analyzer.json
new file mode 100644
index 000000000..dc3387407
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-converter-analyzer.json
@@ -0,0 +1,246 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "sys_preset_converter_analyzer",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "format-checker",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/dataset_format_checker.py\",\n \"arguments\": [\n workflow.variables.data_type,\n \"--input_batch_path=\" + workflow.variables.input_batch_path,\n \"--format=\" + workflow.variables.file_format,\n \"--checkers=\" + workflow.variables.checkers,\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": int(workflow.variables.driver_cores),\n \"coreLimit\": workflow.variables.driver_cores_limit,\n \"memory\": workflow.variables.driver_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": int(workflow.variables.executor_nums),\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executors),\n \"maxExecutors\": int(workflow.variables.max_executors),\n \"minExecutors\": int(workflow.variables.min_executors)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "format-checker"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "converter",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/converter_v2.py\",\n \"arguments\": [\n workflow.variables.data_type,\n \"--output_dataset_path=\" + workflow.variables.dataset_path, \n \"--output_batch_path=\" + workflow.variables.batch_path,\n \"--input_batch_path=\" + workflow.variables.input_batch_path,\n \"--format=\" + workflow.variables.file_format,\n \"--manifest_name=\" + workflow.variables.manifest_name,\n \"--images_dir_name=\" + workflow.variables.images_dir_name,\n \"--import_type=\" + workflow.variables.import_type,\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": int(workflow.variables.driver_cores),\n \"coreLimit\": workflow.variables.driver_cores_limit,\n \"memory\": workflow.variables.driver_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": int(workflow.variables.executor_nums),\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executors),\n \"maxExecutors\": int(workflow.variables.max_executors),\n \"minExecutors\": int(workflow.variables.min_executors)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "converter"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n workflow.variables.data_type,\n \"--data_path=\" + workflow.variables.dataset_path,\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--buckets_num=\" + workflow.variables.buckets_num,\n \"--thumbnail_path=\" + workflow.variables.thumbnail_path,\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n \"--skip\" if str(workflow.variables.skip_analyzer)==\"true\" else ''\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": int(workflow.variables.driver_cores),\n \"coreLimit\": workflow.variables.driver_cores_limit,\n \"memory\": workflow.variables.driver_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": int(workflow.variables.executor_nums),\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executors),\n \"maxExecutors\": int(workflow.variables.max_executors),\n \"minExecutors\": int(workflow.variables.min_executors)\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_format",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "csv",
+ "value": "csv",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "dataset_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "batch_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_batch_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "data_type",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "tabular",
+ "value": "tabular",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "manifest_name",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "manifest.json",
+ "value": "manifest.json",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"manifest file name in image dataset directory\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "images_dir_name",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "images",
+ "value": "images",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"images directory name in image dataset directory\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "buckets_num",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"the number of buckets for hist\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "thumbnail_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"dir path to save the thumbnails\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "driver_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "1",
+ "value": "1",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "driver_cores_limit",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "driver_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "executor_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "executor_nums",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4",
+ "value": "4",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "executor_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "initial_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "max_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "64",
+ "value": "64",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "min_executors",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "checkers",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"OPERATING_PARAM\",\"tooltip\":\"数据集导入检查项\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "skip_analyzer",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "false",
+ "value": "false",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OPERATING_PARAM\",\"hidden\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "import_type",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "COPY",
+ "value": "COPY",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OUTPUT_PATH\",\"hidden\":true}"
+ }
+ ]
+ },
+ "editor_info": {},
+ "group_alias": "sys_preset_converter_analyzer",
+ "name": "sys-preset-converter-analyzer"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-export-dataset.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-export-dataset.json
new file mode 100644
index 000000000..6e57e1fc3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-export-dataset.json
@@ -0,0 +1,277 @@
+{
+ "name": "sys-preset-export-dataset",
+ "group_alias": "sys_preset_export_dataset",
+ "config": {
+ "group_alias": "sys_preset_export_dataset",
+ "variables": [
+ {
+ "name": "dataset_path",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"待倒出数据集路径\",\"hidden\":true,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "export_path",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"导出路径\",\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "OUTPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "file_wildcard",
+ "value": "batch/**/**",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "batch/**/**",
+ "tag": "INPUT_PATH",
+ "value_type": "STRING"
+ },
+ {
+ "name": "driver_cores",
+ "value": "1",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "1",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "driver_cores_limit",
+ "value": "4000m",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4000m",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "driver_mem",
+ "value": "4g",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4g",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "executor_cores",
+ "value": "2",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "2",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "executor_mem",
+ "value": "4g",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4g",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "initial_executors",
+ "value": "2",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "2",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "max_executors",
+ "value": "128",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "128",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "min_executors",
+ "value": "2",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "2",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "file_format",
+ "value": "tfrecords",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OPERATING_PARAM\",\"tooltip\":\"文件存储格式\"}",
+ "typed_value": "tfrecords",
+ "tag": "OPERATING_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "batch_name",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"数据批次名\",\"tag\":\"OPERATING_PARAM\"}",
+ "typed_value": "",
+ "tag": "OPERATING_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ }
+ ],
+ "job_definitions": [
+ {
+ "name": "export-dataset",
+ "job_type": "TRANSFORMER",
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/export_dataset.py\",\n \"arguments\": [\n \"--data_path=\" + str(workflow.variables.dataset_path),\n \"--file_wildcard=\" + str(workflow.variables.file_wildcard),\n \"--export_path=\" + str(workflow.variables.export_path),\n \"--batch_name=\" + str(workflow.variables.batch_name),\n \"--file_format=\" + str(workflow.variables.file_format)\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": int(workflow.variables.driver_cores),\n \"coreLimit\": workflow.variables.driver_cores_limit,\n \"memory\": workflow.variables.driver_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": int(workflow.variables.initial_executors),\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executors),\n \"maxExecutors\": int(workflow.variables.max_executors),\n \"minExecutors\": int(workflow.variables.min_executors)\n }\n }\n}\n",
+ "is_federated": false,
+ "variables": [],
+ "dependencies": [],
+ "easy_mode": false
+ }
+ ]
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "export-dataset": {
+ "slots": {
+ "Slot_labels": {
+ "reference": "system.variables.labels",
+ "help": "建议不修改,格式: {}",
+ "reference_type": "SYSTEM",
+ "label": "FLAPP额外元信息",
+ "default_value": {},
+ "value_type": "OBJECT",
+ "default": ""
+ },
+ "Slot_volumes": {
+ "reference": "system.variables.volumes_list",
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "reference_type": "SYSTEM",
+ "label": "为Pod提供的卷",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_driver_cores": {
+ "reference": "self.variables.undefined",
+ "help": "driver核心数",
+ "reference_type": "SELF",
+ "label": "driver核心数",
+ "default_value": "1000m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_configs": {
+ "help": "使用特征选择组件",
+ "label": "配置",
+ "default_value": {},
+ "value_type": "OBJECT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_executor_instances": {
+ "reference": "self.variables.undefined",
+ "help": "excutor实例数",
+ "reference_type": "SELF",
+ "label": "excutor实例数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_executor_cores": {
+ "reference": "self.variables.undefined",
+ "help": "excutor核心数",
+ "reference_type": "SELF",
+ "label": "excutor核心数",
+ "default_value": "1000m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "reference": "system.variables.spark_image",
+ "help": "特征工程时选用的镜像",
+ "reference_type": "SYSTEM",
+ "label": "镜像",
+ "default_value": "artifact.bytedance.com/tce/spark_tfrecords_base:a3b2965430074bce316b13ec98ba8856",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_spark_transformer_file": {
+ "label": "特征工程脚本文件",
+ "default_value": "aaaaaa",
+ "reference": "",
+ "default": "",
+ "help": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_driver_core_limit": {
+ "reference": "self.variables.undefined",
+ "help": "driver核心数限制",
+ "reference_type": "SELF",
+ "label": "driver核心数限制",
+ "default_value": "1200m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "reference": "system.variables.volume_mounts_list",
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "reference_type": "SYSTEM",
+ "label": "卷挂载位置",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_dataset": {
+ "label": "输入数据集",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "help": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_driver_memory": {
+ "reference": "self.variables.undefined",
+ "help": "driver内存",
+ "reference_type": "SELF",
+ "label": "driver内存",
+ "default_value": "1024m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_executor_memory": {
+ "reference": "self.variables.undefined",
+ "help": "excutor内存",
+ "reference_type": "SELF",
+ "label": "excutor内存",
+ "default_value": "512m",
+ "default": "",
+ "value_type": "STRING"
+ }
+ },
+ "meta_yaml": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": ${Slot_labels},\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": ${Slot_image},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": ${Slot_volumes},\n \"mainApplicationFile\": ${Slot_spark_transformer_file},\n \"arguments\": [\n ${Slot_dataset},\n \"rds/**\",\n str(${Slot_configs})\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": ${Slot_driver_cores},\n \"coreLimit\": ${Slot_driver_core_limit},\n \"memory\": ${Slot_driver_memory},\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": ${Slot_volume_mounts}\n },\n \"executor\": {\n \"cores\": ${Slot_executor_cores},\n \"instances\": ${Slot_executor_instances},\n \"memory\": ${Slot_executor_memory},\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": ${Slot_volume_mounts}\n }\n }\n}\n",
+ "variables": []
+ }
+ }
+ },
+ "comment": ""
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-hash-data-join-analyzer.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-hash-data-join-analyzer.json
new file mode 100644
index 000000000..10a073159
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-hash-data-join-analyzer.json
@@ -0,0 +1,205 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "hash-data-join",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "partition-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/partition.py\",\n \"arguments\": [\n \"--input_path=\" + workflow.variables.input_batch_path + '/' + 'part*',\n \"--file_format=\" + 'tfrecords',\n \"--part_key=\" + workflow.variables.part_key,\n \"--part_num=\" + workflow.variables.part_num,\n \"--output_file_format=\" + 'tfrecords',\n \"--output_dir=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1]\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 2,\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "partition-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": true,
+ "job_type": "PSI_DATA_JOIN",
+ "name": "hash-data-join",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"fedReplicaSpecs\": {\n \"Worker\": {\n \"backoffLimit\": 5,\n \"port\": {\n \"containerPort\": 32443,\n \"name\": \"flapp-port\"\n },\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"name\": \"psi\",\n \"image\": system.variables.image_repo + \"/pp_lite:\" + system.version,\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role\n },\n {\n \"name\": \"JOB_TYPE\",\n \"value\": \"psi-hash\"\n },\n {\n \"name\": \"PEER_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"AUTHORITY\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"PEER_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"INPUT_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/ids'\n },\n {\n \"name\": \"OUTPUT_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined'\n },\n {\n \"name\": \"KEY_COLUMN\",\n \"value\": workflow.variables.part_key\n }\n ],\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"resources\": {\n \"limits\": {\n \"cpu\": workflow.variables.worker_cpu,\n \"memory\": workflow.variables.worker_mem\n },\n \"requests\": {\n \"cpu\": workflow.variables.worker_cpu,\n \"memory\": workflow.variables.worker_mem\n }\n },\n \"ports\": [\n {\n \"containerPort\": 32443,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50051,\n \"name\": \"server-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tunnel-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 1212,\n \"name\": \"joiner-port\",\n \"protocol\": \"TCP\"\n }\n ],\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list),\n }\n },\n \"pair\": True,\n \"replicas\": int(workflow.variables.replicas)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "hash-data-join"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "feature-extraction",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/feature_extraction_v2.py\",\n \"arguments\": [\n \"--original_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/raw',\n \"--joined_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined',\n \"--part_key=\" + workflow.variables.part_key,\n \"--part_num=\" + workflow.variables.part_num,\n \"--file_format=\" + 'tfrecords',\n \"--output_file_format=\" + 'tfrecords',\n \"--output_batch_name=\" + workflow.variables.output_batch_path.split('/')[-1],\n \"--output_dataset_path=\" + workflow.variables.output_dataset_path\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 2,\n \"memory\": '4g',\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "feature-extraction"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n \"tabular\",\n \"--data_path=\"+ (str(workflow.variables.output_dataset_path) or str(project.variables.storage_root_path) + \"/\" + \"dataset\" + \"/\" + \"\"),\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--buckets_num=\" + str(10),\n \"--thumbnail_path=\" + \"\",\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 1,\n \"coreLimit\": \"1200m\",\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.basic_envs_list + system.variables.envs_list + []\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_dataset_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "undefined",
+ "value": "undefined",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"DatasetPath\",\"required\":true,\"tooltip\":\"输入数据地址\",\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_key",
+ "tag": "INPUT_PARAM",
+ "typed_value": "raw_id",
+ "value": "raw_id",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交的key\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_num",
+ "tag": "INPUT_PARAM",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"分区数量\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_batch_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输出数据batch地址\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_dataset_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输出数据集地址\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "role",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "undefined",
+ "value": "undefined",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"tooltip\":\"OtPsi角色\",\"enum\":[\"client\",\"server\"],\"hidden\":false,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "replicas",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交worker数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_cpu",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2000m",
+ "value": "2000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4Gi",
+ "value": "4Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_batch_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输入batch地址\",\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor核数\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor内存\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "initial_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务初始化executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "min_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最小executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "64",
+ "value": "64",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最大executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ }
+ ]
+ },
+ "editor_info": {},
+ "group_alias": "hash-data-join",
+ "name": "sys-preset-hash-data-join-analyzer"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-light-ot-data-join.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-light-ot-data-join.json
new file mode 100644
index 000000000..a8a6c5d32
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-light-ot-data-join.json
@@ -0,0 +1,214 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "light-ot-psi",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "partition-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/partition.py\",\n \"arguments\": [\n \"--input_path=\" + workflow.variables.input_batch_path + '/' + 'part*',\n \"--file_format=\" + workflow.variables.file_format,\n \"--part_key=\" + workflow.variables.part_key,\n \"--part_num=\" + workflow.variables.part_num,\n \"--output_file_format=\" + workflow.variables.file_format,\n \"--output_dir=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1]\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 2,\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "partition-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "PSI_DATA_JOIN",
+ "name": "lc-start-server",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\",\n \"min-member\": \"1\",\n \"resource-cpu\": str(workflow.variables.worker_cpu),\n \"resource-mem\": str(workflow.variables.worker_mem),\n },\n },\n \"spec\": {\n \"fedReplicaSpecs\": {\n \"Worker\": {\n \"backoffLimit\": 5,\n \"port\": {\n \"containerPort\": 32443,\n \"name\": \"flapp-port\"\n },\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"name\": \"psi\",\n \"image\": system.variables.image_repo + \"/pp_lite:\" + system.version,\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"ROLE\",\n \"value\": \"server\"\n },\n {\n \"name\": \"JOB_TYPE\",\n \"value\": \"psi-ot\"\n },\n {\n \"name\": \"PEER_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"AUTHORITY\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"PEER_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"INPUT_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/ids'\n },\n {\n \"name\": \"OUTPUT_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined'\n },\n {\n \"name\": \"KEY_COLUMN\",\n \"value\": workflow.variables.part_key\n }\n ],\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"resources\": {\n \"limits\": {\n \"cpu\": workflow.variables.worker_cpu,\n \"memory\": workflow.variables.worker_mem\n },\n \"requests\": {\n \"cpu\": workflow.variables.worker_cpu,\n \"memory\": workflow.variables.worker_mem\n }\n },\n \"ports\": [\n {\n \"containerPort\": 32443,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50051,\n \"name\": \"server-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tunnel-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 1212,\n \"name\": \"joiner-port\",\n \"protocol\": \"TCP\"\n }\n ],\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list),\n }\n },\n \"pair\": True,\n \"replicas\": int(workflow.variables.replicas)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "lc-start-server"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "feature-extraction",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/feature_extraction_v2.py\",\n \"arguments\": [\n \"--original_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/raw',\n \"--joined_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined',\n \"--part_key=\" + workflow.variables.part_key,\n \"--part_num=\" + workflow.variables.part_num,\n \"--file_format=\" + workflow.variables.file_format,\n \"--output_file_format=\" + workflow.variables.output_file_format,\n \"--output_batch_name=\" + workflow.variables.output_batch_path.split('/')[-1],\n \"--output_dataset_path=\" + workflow.variables.output_dataset_path\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 2,\n \"memory\": '4g',\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "feature-extraction"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n \"tabular\",\n \"--data_path=\"+ (str(workflow.variables.output_dataset_path) or str(project.variables.storage_root_path) + \"/\" + \"dataset\" + \"/\" + \"\"),\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--buckets_num=\" + str(10),\n \"--thumbnail_path=\" + \"\",\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 1,\n \"coreLimit\": \"1200m\",\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.basic_envs_list + system.variables.envs_list + []\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_dataset_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "undefined",
+ "value": "undefined",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"DatasetPath\",\"required\":true,\"tooltip\":\"输入数据地址\",\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_key",
+ "tag": "INPUT_PARAM",
+ "typed_value": "raw_id",
+ "value": "raw_id",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交的key\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_num",
+ "tag": "INPUT_PARAM",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"分区数量\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_batch_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "0",
+ "value": "0",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输出数据batch地址\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_dataset_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输出数据集地址\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "replicas",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交worker数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_cpu",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2000m",
+ "value": "2000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4Gi",
+ "value": "4Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_batch_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输入batch地址\",\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor核数\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor内存\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "initial_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务初始化executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "min_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最小executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "64",
+ "value": "64",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最大executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_format",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "tfrecords",
+ "value": "tfrecords",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"输入数据格式,支持csv或tfrecords\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_file_format",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "tfrecords",
+ "value": "tfrecords",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"输出数据格式,支持csv或tfrecords\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ }
+ ]
+ },
+ "editor_info": {},
+ "group_alias": "ot-psi",
+ "name": "sys-preset-light-ot-data-join"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-light-psi-data-join.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-light-psi-data-join.json
new file mode 100644
index 000000000..1fce367bc
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-light-psi-data-join.json
@@ -0,0 +1,1107 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "light-psi",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "lc-sign-raw-data",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"Always\",\n \"volumes\": list(system.variables.volumes_list),\n \"arguments\": ['/opt/spark/work-dir/psi.py'],\n \"sparkVersion\": \"3.0.0\",\n \"sparkConf\":{\n \"spark.shuffle.service.enabled\": \"false\"\n },\n \"restartPolicy\": {\n \"type\": \"Never\"\n },\n \"driver\": {\n \"cores\": 1,\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [\n { \n \"name\": \"INPUT_DIR\",\n \"value\": workflow.variables.input_batch_path\n },\n { \n \"name\": \"OUTPUT_DIR\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1]\n },\n { \n \"name\": \"RSA_KEY_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/rsa_private.key'\n },\n { \n \"name\": \"PART_NUM\",\n \"value\": workflow.variables.part_num\n },\n {\n \"name\": \"PART_KEY\",\n \"value\": workflow.variables.part_key\n },\n { \n \"name\": \"FILE_FORMAT\",\n \"value\": \"tfrecords\"\n },\n { \n \"name\": \"RSA_KEY_BITS\",\n \"value\": workflow.variables.rsa_key_bits\n },\n { \n \"name\": \"WILDCARD\",\n \"value\": \"part*\"\n }\n ]\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "lc-sign-raw-data"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "lc-start-server",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": system.variables.labels,\n \"annotations\": {\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\",\n \"min-member\": \"1\",\n \"resource-cpu\": str(workflow.variables.server_cpu),\n \"resource-mem\": str(workflow.variables.server_mem),\n },\n },\n \"spec\": {\n \"activeDeadlineSeconds\": int(workflow.variables.active_deadline_seconds),\n \"cleanPodPolicy\": \"All\",\n \"fedReplicaSpecs\": {\n \"Worker\": {\n \"replicas\": 1,\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"image\": system.variables.image_repo + \"/pp_lite:\" + system.version,\n \"ports\": [\n {\n \"containerPort\": 32443,\n \"name\": \"flapp-port\"\n },\n {\n \"containerPort\": 50051,\n \"name\": \"server-port\"\n },\n ],\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"JOB_TYPE\",\n \"value\": \"psi-rsa\"\n },\n {\n \"name\": \"ROLE\",\n \"value\": \"server\"\n },\n {\n \"name\": \"PRIVATE_KEY_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/rsa_private.key'\n },\n {\n \"name\": \"INPUT_DIR\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/signed'\n },\n {\n \"name\": \"OUTPUT_DIR\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1]\n },\n {\n \"name\": \"KEY_COLUMN\",\n \"value\": workflow.variables.part_key\n },\n {\n \"name\": \"SIGNED_COLUMN\",\n \"value\": \"signed\"\n },\n {\n \"name\": \"NUM_SIGN_PARALLEL\",\n \"value\": workflow.variables.sign_number_workers\n },\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": \"4096\"\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"METRIC_COLLECTOR_SERVICE_NAME\",\n \"value\": \"pplite_psi\"\n },\n ],\n \"volumeMounts\": system.variables.volume_mounts_list,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\"\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": system.variables.volumes_list,\n \"restartPolicy\": \"Never\"\n }\n }\n }\n },\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ]\n }\n}\n "
+ },
+ {
+ "dependencies": [
+ {
+ "source": "lc-start-server"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "DATA_JOIN",
+ "name": "lc-feature-extraction",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/feature_extraction_v2.py\",\n \"arguments\": [\n \"--original_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/raw',\n \"--joined_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined',\n \"--part_key=\" + workflow.variables.part_key,\n \"--part_num=\" + workflow.variables.part_num,\n \"--file_format=\" + 'csv',\n \"--output_file_format=\" + 'tfrecords',\n \"--output_batch_name=\" + workflow.variables.output_batch_path.split('/')[-1],\n \"--output_dataset_path=\" + workflow.variables.output_dataset_path\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 1,\n \"memory\": '4g',\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [\n { \n \"name\": \"ORIGINAL_DATA_PATH\",\n \"value\": workflow.variables.input_dataset_path\n },\n { \n \"name\": \"ORIGINAL_FILE_WILDCARD\",\n \"value\": \"batch/*/part*\"\n },\n { \n \"name\": \"JOINED_DATA_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined'\n },\n { \n \"name\": \"DATABLOCK_PATH\",\n \"value\": workflow.variables.output_batch_path\n },\n { \n \"name\": \"PART_KEY\",\n \"value\": workflow.variables.part_key\n },\n {\n \"name\": \"FILE_FORMAT\",\n \"value\": 'tfrecords'\n }\n ]\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "lc-feature-extraction"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n \"tabular\",\n \"--data_path=\"+ workflow.variables.output_dataset_path,\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--buckets_num=\" + str(10),\n \"--thumbnail_path=\" + \"\",\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 1,\n \"coreLimit\": \"1200m\",\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.basic_envs_list + system.variables.envs_list + []\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 2,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.basic_envs_list + system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_dataset_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "undefined",
+ "value": "undefined",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"DatasetPath\",\"required\":true,\"tooltip\":\"输入数据\",\"hidden\":false,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_num",
+ "tag": "INPUT_PARAM",
+ "typed_value": "4",
+ "value": "4",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"分区数量\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_key",
+ "tag": "INPUT_PARAM",
+ "typed_value": "raw_id",
+ "value": "raw_id",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"分区以及求交的key\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_bits",
+ "tag": "INPUT_PARAM",
+ "typed_value": "1024",
+ "value": "1024",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"RSA密钥长度\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "min_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最小executor数量\",\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "64",
+ "value": "64",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最大executor数量\",\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "initial_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务初始化executor数量\",\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor核数\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor内存\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "server_cpu",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "16000m",
+ "value": "16000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交服务端cpu数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "server_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "20Gi",
+ "value": "20Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交服务端内存容量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "sign_number_workers",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "30",
+ "value": "30",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交服务端进程数量\",\"hidden\":false,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "active_deadline_seconds",
+ "tag": "INPUT_PARAM",
+ "typed_value": "86400",
+ "value": "86400",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交任务最长运行时间,超过该运行时间求交任务会自动停止。\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_dataset_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_batch_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_batch_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输入batch路径\",\"hidden\":true,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OUTPUT_PATH\",\"hidden\":true}"
+ }
+ ]
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "analyzer": {
+ "meta_yaml": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": ${Slot_labels},\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": ${Slot_image} or system.variables.image_repo + \"/pp_data_inspection:\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": ${Slot_volumes},\n \"mainApplicationFile\": ${Slot_spark_main_file},\n \"arguments\": [\n \"--data_path=\"+ (${Slot_dataset_path} or ${Slot_storage_root_path} + \"/\" + ${Slot_inner_folder_name} + \"/\" + ${Slot_input_job_name}),\n \"--file_wildcard=\" + ${Slot_wildcard},\n \"--buckets_num=\" + str(${Slot_buckets_num}),\n \"--thumbnail_path=\" + ${Slot_thumbnail_path},\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": ${Slot_driver_cores},\n \"coreLimit\": ${Slot_driver_core_limit},\n \"memory\": ${Slot_driver_memory},\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"env\": system.basic_envs_list + system.variables.envs_list + ${Slot_drvier_envs}\n },\n \"executor\": {\n \"cores\": ${Slot_executor_cores},\n \"instances\": ${Slot_executor_instances},\n \"memory\": ${Slot_executor_memory},\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": ${Slot_volume_mounts},\n \"env\": system.basic_envs_list + system.variables.envs_list + ${Slot_executor_envs}\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": ${Slot_initial_executors},\n \"maxExecutors\": ${Slot_max_executors},\n \"minExecutors\": ${Slot_min_executors},\n }\n }\n}\n",
+ "slots": {
+ "Slot_buckets_num": {
+ "default": "",
+ "default_value": 10.0,
+ "help": "用于数据探查时统计直方图的分通数",
+ "label": "直方图分桶数",
+ "reference": "",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_dataset_path": {
+ "default": "",
+ "default_value": "",
+ "help": "用于数据集存储的路径",
+ "label": "数据集存储路径",
+ "reference": "",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_driver_core_limit": {
+ "default": "",
+ "default_value": "1200m",
+ "help": "driver核心数限制",
+ "label": "driver核心数限制",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_driver_cores": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "driver核心数",
+ "label": "driver核心数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_driver_memory": {
+ "default": "",
+ "default_value": "4g",
+ "help": "driver内存",
+ "label": "driver内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_drvier_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "driver环境变量",
+ "label": "driver环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_executor_cores": {
+ "default": "",
+ "default_value": 2.0,
+ "help": "executor核心数",
+ "label": "executor核心数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_executor_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "executor环境变量",
+ "label": "executor环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_executor_instances": {
+ "default": "",
+ "default_value": 2.0,
+ "help": "executor实例数",
+ "label": "executor实例数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_executor_memory": {
+ "default": "",
+ "default_value": "4g",
+ "help": "executor内存",
+ "label": "executor内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "",
+ "help": "镜像地址,建议不填写,默认会使用system.variables.image_repo + '/pp_data_inspection:' + system.version",
+ "label": "镜像",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_initial_executors": {
+ "default": "",
+ "default_value": 2.0,
+ "help": "初始化executor数量",
+ "label": "初始化executor数量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_inner_folder_name": {
+ "default": "",
+ "default_value": "dataset",
+ "help": "为了兼容老的路径的临时Slot,['dataset', 'datasource']",
+ "label": "中间文件夹名",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_input_job_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,求交任务名或数据集名称",
+ "label": "数据集名",
+ "reference": "",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_max_executors": {
+ "default": "",
+ "default_value": 64.0,
+ "help": "初始化executor数量",
+ "label": "最大executor数量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_min_executors": {
+ "default": "",
+ "default_value": 2.0,
+ "help": "初始化executor数量",
+ "label": "最小executor数量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_spark_main_file": {
+ "default": "",
+ "default_value": "/opt/spark/work-dir/analyzer_v2.py",
+ "help": "spark入口脚本",
+ "label": "入口脚本文件",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_thumbnail_path": {
+ "default": "",
+ "default_value": "",
+ "help": "用于存放预览图像的位置",
+ "label": "预览图像位置",
+ "reference": "",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_wildcard": {
+ "default": "",
+ "default_value": "batch/**/*.data",
+ "help": "文件通配符",
+ "label": "文件通配符",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "lc-feature-extraction": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": ${Slot_batch_mode}\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(${Slot_start_time})\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(${Slot_end_time})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n ${Slot_volume_mounts}\n ,\n \"image\": system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n ${Slot_volumes}\n\n }\n },\n \"pair\": true,\n \"replicas\": ${Slot_master_replicas}\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": 
project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(${Slot_data_block_dump_interval})\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(${Slot_data_block_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(${Slot_example_id_dump_interval})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(${Slot_example_id_dump_threshold})\n },\n {\n \"name\": \"MIN_MATCHING_WINDOW\",\n \"value\": str(${Slot_min_matching_window})\n },\n {\n \"name\": \"MAX_MATCHING_WINDOW\",\n \"value\": str(${Slot_max_matching_window})\n },\n {\n \"name\": \"RAW_DATA_ITER\",\n \"value\": ${Slot_raw_data_iter}\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": ${Slot_data_block_builder}\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(${Slot_enable_negative_example_generator})\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\":\n ${Slot_volume_mounts}\n ,\n \"image\": system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/data_join/run_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\":\n ${Slot_volumes}\n\n }\n },\n \"pair\": true,\n \"replicas\": ${Slot_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_mode": {
+ "default": "",
+ "default_value": "--batch_mode",
+ "help": "如果为空则为常驻求交",
+ "label": "是否为批处理模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "data block output数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次data block,小于0则无此限制",
+ "label": "数据dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_block_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多多少个样本就dump为一个data block,小于等于0则无此限制",
+ "label": "数据dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_enable_negative_example_generator": {
+ "default": "",
+ "default_value": false,
+ "help": "建议不修改,是否开启负采样,当follower求交时遇到无法匹配上的leader的example id,会以negative_sampling_rate为概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_end_time": {
+ "default": "",
+ "default_value": 999999999999.0,
+ "help": "建议不修改,使用自这个时间以前的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据末尾时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image_version": {
+ "default": "",
+ "default_value": "882310f",
+ "help": "建议不修改,指定Pod中运行的容器镜像版本,前缀为system.variables.image_repo + '/fedlearner:'",
+ "label": "容器镜像版本",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_max_matching_window": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,the max matching window for example join. <=0 means window size is infinite",
+ "label": "最大匹配滑窗",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_min_matching_window": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "建议不修改,the min matching window for example join ,<=0 means window size is infinite",
+ "label": "最小匹配滑窗",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_negative_sampling_rate": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,负采样比例,当follower求交时遇到无法匹配上的leader的example id,会以此概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "建议修改,求交后数据分区的数量,建议和raw_data一致",
+ "label": "数据分区的数量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_iter": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "raw_data文件类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_raw_data_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,原始数据的发布地址,根据参数内容在portal_publish_dir地址下寻找",
+ "label": "raw_data名字",
+ "reference": "",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_start_time": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,使用自这个时间起的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据起始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "lc-sign-raw-data": {
+ "meta_yaml": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": ${Slot_labels},\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": ${Slot_image} or system.variables.image_repo + \"/pp_data_inspection:\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": ${Slot_volumes},\n \"mainApplicationFile\": ${Slot_spark_transformer_file},\n \"arguments\": [\n ${Slot_dataset},\n \"rds/**\",\n str(${Slot_configs})\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": ${Slot_driver_cores},\n \"coreLimit\": ${Slot_driver_core_limit},\n \"memory\": ${Slot_driver_memory},\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": ${Slot_volume_mounts}\n },\n \"executor\": {\n \"cores\": ${Slot_executor_cores},\n \"instances\": ${Slot_executor_instances},\n \"memory\": ${Slot_executor_memory},\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": ${Slot_volume_mounts}\n }\n }\n}\n",
+ "slots": {
+ "Slot_configs": {
+ "default": "",
+ "default_value": {},
+ "help": "使用特征选择组件",
+ "label": "配置",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "OBJECT"
+ },
+ "Slot_dataset": {
+ "default": "",
+ "default_value": "",
+ "help": "",
+ "label": "输入数据集",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_driver_core_limit": {
+ "default": "",
+ "default_value": "1200m",
+ "help": "driver核心数限制",
+ "label": "driver核心数限制",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_driver_cores": {
+ "default": "",
+ "default_value": "1000m",
+ "help": "driver核心数",
+ "label": "driver核心数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_driver_memory": {
+ "default": "",
+ "default_value": "1024m",
+ "help": "driver内存",
+ "label": "driver内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_executor_cores": {
+ "default": "",
+ "default_value": "1000m",
+ "help": "excutor核心数",
+ "label": "excutor核心数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_executor_instances": {
+ "default": "",
+ "default_value": 1.0,
+ "help": "excutor实例数",
+ "label": "excutor实例数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_executor_memory": {
+ "default": "",
+ "default_value": "512m",
+ "help": "excutor内存",
+ "label": "excutor内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "",
+ "help": "镜像地址,建议不填写,默认会使用system.variables.image_repo + '/pp_data_inspection:' + system.version",
+ "label": "镜像",
+ "reference": "system.variables.spark_image",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_spark_transformer_file": {
+ "default": "",
+ "default_value": "transformer.py",
+ "help": "特征工程的脚本",
+ "label": "特征工程脚本文件",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ }
+ },
+ "variables": []
+ },
+ "lc-start-server": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(${Slot_raw_data_metrics_sample_rate})\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n 
\"name\": \"BATCH_SIZE\",\n \"value\": str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(${Slot_raw_data_metrics_sample_rate})\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "Streaming",
+ "help": "运行过一次后修改无效!! the type of data portal type ,choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+ "help": "文件名称的通配符, 将会读取input_base_dir下所以满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image_version": {
+ "default": "",
+ "default_value": "882310f",
+ "help": "建议不修改,指定Pod中运行的容器镜像版本,前缀为system.variables.image_repo + '/fedlearner:'",
+ "label": "容器镜像版本",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\"Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_metrics_sample_rate": {
+ "default": "",
+ "default_value": "1",
+ "help": "建议不修改,es metrics 取样比例",
+ "label": "metrics_sample_rate",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ }
+ }
+ },
+ "name": "sys-preset-light-psi-data-join"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-horizontal-eval-model.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-horizontal-eval-model.json
new file mode 100644
index 000000000..de56968ba
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-horizontal-eval-model.json
@@ -0,0 +1,451 @@
+{
+ "name": "sys-preset-nn-horizontal-eval-model",
+ "group_alias": "sys_preset_nn_horizontal_model",
+ "config": {
+ "group_alias": "sys_preset_nn_horizontal_model",
+ "job_definitions": [
+ {
+ "name": "train-job",
+ "job_type": "NN_MODEL_TRANINING",
+ "variables": [
+ {
+ "name": "algorithm",
+ "value": "{\"path\":[],\"config\":[]}",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"AlgorithmSelect\",\"required\":true,\"tag\":\"OPERATING_PARAM\"}",
+ "value_type": "OBJECT",
+ "typed_value": {
+ "config": [],
+ "path": []
+ },
+ "tag": "OPERATING_PARAM"
+ },
+ {
+ "name": "image_version",
+ "value": "50a6945",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "50a6945",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "data_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "load_model_name",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "steps_per_sync",
+ "value": "10",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\",\"tooltip\":\"每隔几步同步一次\"}",
+ "typed_value": "10",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_cpu",
+ "value": "4000m",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4000m",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_mem",
+ "value": "8Gi",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "8Gi",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "load_checkpoint_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\",\n \"min-member\": \"1\",\n \"resource-cpu\": str(self.variables.worker_cpu),\n \"resource-mem\": str(self.variables.worker_mem),\n },\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": \"follower\"\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": self.variables.algorithm.path\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": self.variables.data_path\n },\n {\n \"name\": \"LOAD_MODEL_FROM\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.variables.load_model_name + \"/checkpoints\" if self.variables.load_model_name else self.variables.load_checkpoint_path\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name + \"/exported_models\"\n },\n {\n \"name\": \"MODE\",\n \"value\": \"eval\"\n },\n {\n \"name\": \"FL_STEPS_PER_SYNC\",\n \"value\": self.variables.steps_per_sync\n }\n ] + list(self.variables.algorithm.config),\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + str(self.variables.image_version),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": self.variables.worker_cpu,\n \"memory\": self.variables.worker_mem\n },\n \"requests\": {\n \"cpu\": self.variables.worker_cpu,\n \"memory\": self.variables.worker_mem\n }\n },\n \"command\": [\n \"/app/deploy/scripts/trainer/run_fedavg.sh\"\n ],\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n\n }\n },\n \"pair\": False,\n \"replicas\": int(1)\n }\n }\n }\n}\n",
+ "is_federated": false,
+ "dependencies": [],
+ "easy_mode": false
+ }
+ ],
+ "variables": []
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "train-job": {
+ "slots": {
+ "Slot_worker_cpu": {
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "default_value": "2000m",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_memory": {
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "PS的内存",
+ "default_value": "3Gi",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_mode": {
+ "help": "choices:['train','eval'] 训练还是验证",
+ "label": "模式",
+ "default_value": "train",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "reference": "system.variables.volume_mounts_list",
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "reference_type": "SYSTEM",
+ "label": "卷挂载位置",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_summary_save_steps": {
+ "help": "int, Number of steps to save summary files.",
+ "label": "SUMMARY_SAVE_STEPS",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_start_date": {
+ "help": "training data start date",
+ "label": "开始时间",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_labels": {
+ "reference": "system.variables.labels",
+ "help": "建议不修改,格式: {}",
+ "reference_type": "SYSTEM",
+ "label": "FLAPP额外元信息",
+ "default_value": {},
+ "value_type": "OBJECT",
+ "default": ""
+ },
+ "Slot_epoch_num": {
+ "help": "number of epoch for training, not support in online training",
+ "label": "epoch数量",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_storage_root_path": {
+ "reference": "project.variables.storage_root_path",
+ "help": "联邦学习中任务存储根目录",
+ "reference_type": "PROJECT",
+ "label": "存储根目录",
+ "default_value": "/data",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_master_memory": {
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "default_value": "3Gi",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_volumes": {
+ "reference": "system.variables.volumes_list",
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "reference_type": "SYSTEM",
+ "label": "为Pod提供的卷",
+ "default_value": [
+ {
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ },
+ "name": "data"
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_load_checkpoint_from_job": {
+ "help": "指定任务名job_output下的latest checkpoint",
+ "label": "以任务名加载checkpoint",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_tar": {
+ "help": "代码包,variable中请使用代码类型",
+ "label": "代码",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_end_date": {
+ "help": "training data end date",
+ "label": "结束时间",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_load_checkpoint_filename": {
+ "help": "加载checkpoint_path下的相对路径的checkpoint, 默认会加载checkpoint_path下的latest checkpoint",
+ "label": "LOAD_CHECKPOINT_FILENAME",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_checkpoint_path": {
+ "help": "不建议修改,checkpoint输出路径,建议为空,会默认使用{storage_root_path}/job_output/{job_name}/checkpoints,强烈建议保持空值",
+ "label": "CHECKPOINT_PATH",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_suffle_data_block": {
+ "help": "['','--shuffle_data_block'] 否 是,shuffle the data block or not",
+ "label": "是否shuffle数据块",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_code_key": {
+ "reference": "self.variables.algorithm.path",
+ "help": "代码tar包地址,如果为空则使用code tar",
+ "reference_type": "SELF",
+ "label": "模型代码路径",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_export_path": {
+ "help": "使用默认空值,将把models保存到$OUTPUT_BASE_DIR/exported_models 路径下。",
+ "label": "EXPORT_PATH",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_envs": {
+ "help": "数组类型,ps pod额外的环境变量",
+ "label": "PS额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_worker_memory": {
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "default_value": "3Gi",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_filename_with_path": {
+ "help": "加载绝对路径下的checkpoint,需要细致到文件名",
+ "label": "从绝对路径加载checkpoint",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_replicas": {
+ "help": "同时运行的完全相同的PS Pods数量",
+ "label": "PS的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_save_checkpoint_secs": {
+ "help": "int,Number of secs between checkpoints.",
+ "label": "SAVE_CHECKPOINT_SECS",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_ps_cpu": {
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "PS的CPU",
+ "default_value": "2000m",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "help": "必须修改,求交任务的名字",
+ "label": "数据源",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "default_value": "2000m",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "reference": "self.variables.algorithm.config",
+ "help": "数组类型,worker pod额外的环境变量",
+ "reference_type": "SELF",
+ "label": "Worker额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_worker_replicas": {
+ "help": "同时运行的完全相同的Worker Pods数量",
+ "label": "Worker的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_sparse_estimator": {
+ "help": "bool,default False Whether using sparse estimator.",
+ "label": "SPARSE_ESTIMATOR",
+ "default_value": false,
+ "value_type": "BOOL",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_verbosity": {
+ "help": "int, Logging level",
+ "label": "日志等级",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_role": {
+ "reference": "self.variables.undefined",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "reference_type": "SELF",
+ "label": "Flapp通讯时角色",
+ "default_value": "Leader",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_image": {
+ "reference": "self.variables.image_version",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模板不适用",
+ "reference_type": "SELF",
+ "label": "容器镜像",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_online_training": {
+ "help": "['','--online_training'] 否 是,the train master run for online training",
+ "label": "是否在线训练",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_save_checkpoint_steps": {
+ "help": "int, Number of steps between checkpoints.",
+ "label": "SAVE_CHECKPOINT_STEPS",
+ "default_value": 1000.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ }
+ },
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(${Slot_epoch_num})\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(${Slot_start_date})\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(${Slot_end_date})\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": ${Slot_online_training}\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": ${Slot_checkpoint_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": ${Slot_load_checkpoint_filename}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": ${Slot_load_checkpoint_filename_with_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": ${Slot_load_checkpoint_from_job} and ${Slot_storage_root_path} + \"/job_output/\" + ${Slot_load_checkpoint_from_job} + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": ${Slot_export_path}\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_master_replicas})\n },\n \"PS\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n 
\"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + ${Slot_ps_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_ps_replicas})\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(${Slot_save_checkpoint_steps})\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(${Slot_save_checkpoint_secs})\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(${Slot_summary_save_steps})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_trainer_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_worker_replicas})\n }\n }\n }\n}\n",
+ "variables": []
+ }
+ }
+ },
+ "comment": ""
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-horizontal-model.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-horizontal-model.json
new file mode 100644
index 000000000..fa318b43f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-horizontal-model.json
@@ -0,0 +1,460 @@
+{
+ "name": "sys-preset-nn-horizontal-model",
+ "group_alias": "sys_preset_nn_horizontal_model",
+ "config": {
+ "group_alias": "sys_preset_nn_horizontal_model",
+ "job_definitions": [
+ {
+ "name": "train-job",
+ "job_type": "NN_MODEL_TRANINING",
+ "variables": [
+ {
+ "name": "algorithm",
+ "value": "{\"path\":[],\"config\":[]}",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"AlgorithmSelect\",\"required\":true,\"tag\":\"OPERATING_PARAM\"}",
+ "value_type": "OBJECT",
+ "typed_value": {
+ "path": [],
+ "config": []
+ },
+ "tag": "OPERATING_PARAM"
+ },
+ {
+ "name": "image_version",
+ "value": "50a6945",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "50a6945",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "role",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"],\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "data_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "steps_per_sync",
+ "value": "10",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "10",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_cpu",
+ "value": "4000m",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4000m",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_mem",
+ "value": "8Gi",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "8Gi",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "epoch_num",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "1",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "verbosity",
+ "value": "0",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"enum\":[],\"tooltip\":\"\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "0",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\",\n \"min-member\": \"1\",\n \"resource-cpu\": str(self.variables.worker_cpu),\n \"resource-mem\": str(self.variables.worker_mem),\n },\n },\n \"spec\": {\n \"role\": self.variables.role,\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if self.variables.role==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"ROLE\",\n \"value\": self.variables.role.lower()\n },\n {\n \"name\": \"MODE\",\n \"value\": \"train\"\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": self.variables.algorithm.path\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": self.variables.data_path\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name + \"/exported_models\"\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": self.variables.epoch_num\n },\n {\n \"name\": \"FL_STEPS_PER_SYNC\",\n \"value\": self.variables.steps_per_sync\n }\n ] + list(self.variables.algorithm.config),\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + str(self.variables.image_version),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": self.variables.worker_cpu,\n \"memory\": self.variables.worker_mem\n },\n \"requests\": {\n \"cpu\": self.variables.worker_cpu,\n \"memory\": self.variables.worker_mem\n }\n },\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_fedavg.sh\"\n ]\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n\n }\n },\n \"pair\": True,\n \"replicas\": int(1)\n }\n }\n }\n}\n",
+ "is_federated": false,
+ "dependencies": [],
+ "easy_mode": false
+ }
+ ],
+ "variables": []
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "train-job": {
+ "slots": {
+ "Slot_end_date": {
+ "help": "training data end date",
+ "label": "结束时间",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_role": {
+ "reference": "self.variables.role",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "reference_type": "SELF",
+ "label": "Flapp通讯时角色",
+ "default_value": "Leader",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "reference": "system.variables.labels",
+ "help": "建议不修改,格式: {}",
+ "reference_type": "SYSTEM",
+ "label": "FLAPP额外元信息",
+ "default_value": {},
+ "value_type": "OBJECT",
+ "default": ""
+ },
+ "Slot_save_checkpoint_steps": {
+ "help": "int, Number of steps between checkpoints.",
+ "label": "SAVE_CHECKPOINT_STEPS",
+ "default_value": 1000.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_code_tar": {
+ "help": "代码包,variable中请使用代码类型",
+ "label": "代码",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_filename": {
+ "help": "加载checkpoint_path下的相对路径的checkpoint, 默认会加载checkpoint_path下的latest checkpoint",
+ "label": "LOAD_CHECKPOINT_FILENAME",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_image": {
+ "reference": "self.variables.image_version",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模板不适用",
+ "reference_type": "SELF",
+ "label": "容器镜像",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "help": "必须修改,求交任务的名字",
+ "label": "数据源",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_filename_with_path": {
+ "help": "加载绝对路径下的checkpoint,需要细致到文件名",
+ "label": "从绝对路径加载checkpoint",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_mode": {
+ "help": "choices:['train','eval'] 训练还是验证",
+ "label": "模式",
+ "default_value": "train",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_code_key": {
+ "reference": "self.variables.algorithm.path",
+ "help": "代码tar包地址,如果为空则使用code tar",
+ "reference_type": "SELF",
+ "label": "模型代码路径",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "help": "同时运行的完全相同的Master Pods数量",
+ "label": "Master的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_worker_envs": {
+ "reference": "self.variables.algorithm.config",
+ "help": "数组类型,worker pod额外的环境变量",
+ "reference_type": "SELF",
+ "label": "Worker额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_storage_root_path": {
+ "reference": "project.variables.storage_root_path",
+ "help": "联邦学习中任务存储根目录",
+ "reference_type": "PROJECT",
+ "label": "存储根目录",
+ "default_value": "/data",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "default_value": "2000m",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_suffle_data_block": {
+ "help": "['','--shuffle_data_block'] 否 是,shuffle the data block or not",
+ "label": "是否shuffle数据块",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_cpu": {
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "default_value": "2000m",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_replicas": {
+ "help": "同时运行的完全相同的Worker Pods数量",
+ "label": "Worker的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_volume_mounts": {
+ "reference": "system.variables.volume_mounts_list",
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "reference_type": "SYSTEM",
+ "label": "卷挂载位置",
+ "default_value": [
+ {
+ "name": "data",
+ "mountPath": "/data"
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_ps_replicas": {
+ "help": "同时运行的完全相同的PS Pods数量",
+ "label": "PS的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_start_date": {
+ "help": "training data start date",
+ "label": "开始时间",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_volumes": {
+ "reference": "system.variables.volumes_list",
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "reference_type": "SYSTEM",
+ "label": "为Pod提供的卷",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_load_checkpoint_from_job": {
+ "help": "指定任务名job_output下的latest checkpoint",
+ "label": "以任务名加载checkpoint",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_sparse_estimator": {
+ "help": "bool,default False Whether using sparse estimator.",
+ "label": "SPARSE_ESTIMATOR",
+ "default_value": false,
+ "value_type": "BOOL",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_verbosity": {
+ "help": "int, Logging level",
+ "label": "日志等级",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_ps_envs": {
+ "help": "数组类型,ps pod额外的环境变量",
+ "label": "PS额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_export_path": {
+ "help": "使用默认空值,将把models保存到$OUTPUT_BASE_DIR/exported_models 路径下。",
+ "label": "EXPORT_PATH",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_memory": {
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "PS的内存",
+ "default_value": "3Gi",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_memory": {
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "default_value": "3Gi",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_online_training": {
+ "help": "['','--online_training'] 否 是,the train master run for online training",
+ "label": "是否在线训练",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_cpu": {
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "PS的CPU",
+ "default_value": "2000m",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_summary_save_steps": {
+ "help": "int, Number of steps to save summary files.",
+ "label": "SUMMARY_SAVE_STEPS",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_save_checkpoint_secs": {
+ "help": "int,Number of secs between checkpoints.",
+ "label": "SAVE_CHECKPOINT_SECS",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_worker_memory": {
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "default_value": "3Gi",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_checkpoint_path": {
+ "help": "不建议修改,checkpoint输出路径,建议为空,会默认使用{storage_root_path}/job_output/{job_name}/checkpoints,强烈建议保持空值",
+ "label": "CHECKPOINT_PATH",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_epoch_num": {
+ "help": "number of epoch for training, not support in online training",
+ "label": "epoch数量",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ }
+ },
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(${Slot_epoch_num})\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(${Slot_start_date})\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(${Slot_end_date})\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": ${Slot_online_training}\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": ${Slot_checkpoint_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": ${Slot_load_checkpoint_filename}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": ${Slot_load_checkpoint_filename_with_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": ${Slot_load_checkpoint_from_job} and ${Slot_storage_root_path} + \"/job_output/\" + ${Slot_load_checkpoint_from_job} + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": ${Slot_export_path}\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_master_replicas})\n },\n \"PS\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n 
\"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + ${Slot_ps_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_ps_replicas})\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(${Slot_save_checkpoint_steps})\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(${Slot_save_checkpoint_secs})\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(${Slot_summary_save_steps})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_trainer_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_worker_replicas})\n }\n }\n }\n}\n",
+ "variables": []
+ }
+ }
+ },
+ "comment": ""
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-model.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-model.json
new file mode 100644
index 000000000..2d5b5a582
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-nn-model.json
@@ -0,0 +1,667 @@
+{
+ "name": "sys-preset-nn-model",
+ "group_alias": "sys_preset_nn_model",
+ "config": {
+ "group_alias": "sys_preset_nn_model",
+ "job_definitions": [
+ {
+ "name": "nn-model",
+ "job_type": "NN_MODEL_TRANINING",
+ "is_federated": true,
+ "variables": [
+ {
+ "name": "master_cpu",
+ "value": "3000m",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "3000m",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "master_mem",
+ "value": "4Gi",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4Gi",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_cpu",
+ "value": "2000m",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "2000m",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_mem",
+ "value": "4Gi",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4Gi",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "ps_replicas",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "1",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "master_replicas",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "1",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "ps_cpu",
+ "value": "2000m",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "2000m",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "ps_mem",
+ "value": "4Gi",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "4Gi",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_replicas",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "1",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "data_source",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "epoch_num",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "1",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "shuffle_data_block",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "verbosity",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":false,\"enum\":[\"0\",\"1\",\"2\"],\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "1",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "mode",
+ "value": "train",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"train\",\"eval\"],\"tag\":\"OPERATING_PARAM\"}",
+ "typed_value": "train",
+ "tag": "OPERATING_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "save_checkpoint_secs",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "save_checkpoint_steps",
+ "value": "1000",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "1000",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "load_checkpoint_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "load_checkpoint_filename_with_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "role",
+ "value": "Leader",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"],\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "Leader",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "load_model_name",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "algorithm",
+ "value": "{\"config\":[],\"path\":\"\"}",
+ "access_mode": "PEER_READABLE",
+ "widget_schema": "{\"component\":\"AlgorithmSelect\",\"required\":true,\"tag\":\"OPERATING_PARAM\"}",
+ "value_type": "OBJECT",
+ "typed_value": {
+ "config": [],
+ "path": ""
+ },
+ "tag": "OPERATING_PARAM"
+ },
+ {
+ "name": "image_version",
+ "value": "50a6945",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"镜像版本\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "50a6945",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "start_date",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"数据开始时间\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "end_date",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"数据结束时间\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "export_model",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"可不填,或输入true或false。如果不填,那么训练任务会export模型,如果填,则根据输入值决定是否export\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "data_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"和 data_source 二选一,输入输入数据集所在的路径。\",\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "data_path_wildcard",
+ "value": "**/part*",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"配合 data_path 使用,输入数据集文件的 wildcard,默认是 **/part*。\",\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "**/part*",
+ "tag": "INPUT_PATH",
+ "value_type": "STRING"
+ },
+ {
+ "name": "fedapp_active_ttl",
+ "value": "86400",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OPERATING_PARAM\",\"tooltip\":\"运行最大时长,单位秒\"}",
+ "typed_value": "86400",
+ "tag": "OPERATING_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "local_worker_replicas",
+ "value": "0",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\",\"tooltip\":\"local worker数量\"}",
+ "typed_value": "0",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "local_data_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "sparse_estimator",
+ "value": "false",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"是否使用sparse_estimator;true为使用,false为不使用。\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "false",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_backoff_limit",
+ "value": "6",
+ "access_mode": "PEER_WRITABLE",
+            "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"worker最大重试次数\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "6",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\",\n \"min-member\": \"1\",\n \"resource-cpu\": str(self.variables.worker_cpu),\n \"resource-mem\": str(self.variables.worker_mem),\n },\n },\n \"spec\": {\n \"activeDeadlineSeconds\": int(self.variables.fedapp_active_ttl),\n \"fedReplicaSpecs\": {\n \"Master\": {\n \"backoffLimit\": 1,\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(self.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": str(self.variables.mode)\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(int(self.variables.epoch_num))\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(int(self.variables.start_date))\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(int(self.variables.end_date))\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": str(self.variables.data_source)\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": str(self.variables.data_path)\n },\n {\n \"name\": \"DATA_PATH_WILDCARD\",\n \"value\": str(self.variables.data_path_wildcard)\n },\n {\n \"name\": \"EXPORT_MODEL\",\n \"value\": str(self.variables.export_model)\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": \"\"\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(bool(self.variables.sparse_estimator))\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": str(self.variables.algorithm.path)\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": str(self.variables.load_checkpoint_filename_with_path)\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": str(project.variables.storage_root_path) + \"/job_output/\" + str(self.variables.load_model_name) + \"/checkpoints\" if str(self.variables.load_model_name) else str(self.variables.load_checkpoint_path)\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": \"\"\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(int(self.variables.save_checkpoint_steps))\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(int(self.variables.save_checkpoint_secs))\n },\n {\n \"name\": \"METRIC_COLLECTOR_ENABLE\",\n \"value\": str(True)\n },\n {\n \"name\": \"METRIC_COLLECTOR_SERVICE_NAME\",\n \"value\": \"fedlearner_model\"\n },\n {\n \"name\": \"LOCAL_DATA_PATH\",\n \"value\": self.variables.local_data_path\n },\n {\n \"name\": \"SHUFFLE_IN_DAY\",\n \"value\": \"true\"\n }\n ] + list(self.variables.algorithm.config),\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": 
list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + str(self.variables.image_version),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.master_cpu),\n \"memory\": str(self.variables.master_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.master_cpu),\n \"memory\": str(self.variables.master_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n\n }\n },\n \"replicas\": int(int(self.variables.master_replicas))\n },\n \"PS\": {\n \"mustSuccess\": False,\n \"backoffLimit\": 1,\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"METRIC_COLLECTOR_ENABLE\",\n \"value\": str(True)\n },\n {\n \"name\": \"METRIC_COLLECTOR_SERVICE_NAME\",\n \"value\": \"fedlearner_model\"\n },\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + str(self.variables.image_version),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.ps_cpu),\n \"memory\": str(self.variables.ps_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.ps_cpu),\n \"memory\": str(self.variables.ps_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(int(self.variables.ps_replicas))\n },\n \"Worker\": {\n \"mustSuccess\": True,\n \"backoffLimit\": int(self.variables.worker_backoff_limit),\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(self.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": str(self.variables.mode)\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": 
str(int(self.variables.verbosity))\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": str(self.variables.algorithm.path)\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": \"\"\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(bool(self.variables.sparse_estimator))\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(None)\n },\n {\n \"name\": \"METRIC_COLLECTOR_ENABLE\",\n \"value\": str(True)\n },\n {\n \"name\": \"METRIC_COLLECTOR_SERVICE_NAME\",\n \"value\": \"fedlearner_model\"\n },\n ] + list(self.variables.algorithm.config),\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + str(self.variables.image_version),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\":[\"/bin/bash\",\"-c\"],\n \"args\": [\"export WORKER_RANK=$$INDEX && export PEER_ADDR=$$SERVICE_ID && /app/deploy/scripts/trainer/run_trainer_worker.sh\"],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(int(self.variables.worker_replicas))\n },\n \"LocalWorker\": {\n \"backoffLimit\": 6,\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"name\": \"localworker\",\n \"image\": system.variables.image_repo + \"/fedlearner:\" + str(self.variables.image_version),\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"command\": [\"/app/deploy/scripts/trainer/run_trainer_local_worker.sh\"],\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\":\"FL_STATS_URL\",\n \"value\":\"udp://statsd-v1-service.fedlearner:8125\"\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(self.variables.role).lower()\n },\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": project.variables.storage_root_path\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": project.variables.storage_root_path + \"/job_output/\" + self.name\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": str(self.variables.algorithm.path)\n },\n {\n \"name\": \"MODE\",\n \"value\": self.variables.mode\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": self.variables.sparse_estimator\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(bool(self.variables.sparse_estimator))\n },\n {\n \"name\": \"METRIC_COLLECTOR_ENABLE\",\n \"value\": str(True)\n },\n {\n \"name\": \"METRIC_COLLECTOR_SERVICE_NAME\",\n \"value\": \"fedlearner_model\"\n },\n ] + system.variables.envs_list + list(self.variables.algorithm.config),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"resources\": {\n \"limits\": {\n 
\"cpu\": self.variables.worker_cpu,\n \"memory\": self.variables.worker_mem\n },\n \"requests\": {\n \"cpu\": self.variables.worker_cpu,\n \"memory\": self.variables.worker_mem\n }\n }\n },\n ],\n \"imagePullSecrets\": [{\"name\": \"regcred\"}],\n \"volumes\": list(system.variables.volumes_list) + []\n }\n },\n \"replicas\": int(self.variables.local_worker_replicas)\n }\n }\n }\n}\n",
+ "dependencies": [],
+ "easy_mode": false
+ }
+ ],
+ "variables": []
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "nn-model": {
+ "slots": {
+ "Slot_start_date": {
+ "reference": "self.variables.start_date",
+ "help": "training data start date",
+ "reference_type": "SELF",
+ "label": "开始时间",
+ "default_value": null,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_worker_replicas": {
+ "reference": "self.variables.worker_replicas",
+ "help": "同时运行的完全相同的Worker Pods数量",
+ "reference_type": "SELF",
+ "label": "Worker的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_load_checkpoint_filename": {
+ "reference": "self.variables.load_checkpoint_path",
+ "help": "加载checkpoint_path下的相对路径的checkpoint, 默认会加载checkpoint_path下的latest checkpoint",
+ "reference_type": "SELF",
+ "label": "LOAD_CHECKPOINT_FILENAME",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_checkpoint_path": {
+ "help": "不建议修改,checkpoint输出路径,建议为空,会默认使用{storage_root_path}/job_output/{job_name}/checkpoints,强烈建议保持空值",
+ "label": "CHECKPOINT_PATH",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "reference": "system.variables.volume_mounts_list",
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "reference_type": "SYSTEM",
+ "label": "卷挂载位置",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_ps_envs": {
+ "help": "数组类型,ps pod额外的环境变量",
+ "label": "PS额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_role": {
+ "reference": "self.variables.role",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "reference_type": "SELF",
+ "label": "Flapp通讯时角色",
+ "default_value": "Leader",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "reference": "system.variables.labels",
+ "help": "建议不修改,格式: {}",
+ "reference_type": "SYSTEM",
+ "label": "FLAPP额外元信息",
+ "default_value": {},
+ "value_type": "OBJECT",
+ "default": ""
+ },
+ "Slot_mode": {
+ "reference": "self.variables.mode",
+ "help": "choices:['train','eval'] 训练还是验证",
+ "reference_type": "SELF",
+ "label": "模式",
+ "default_value": "train",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_ps_cpu": {
+ "reference": "self.variables.ps_cpu",
+            "help": "PS Pod 所分配的CPU资源(request和limit一致)",
+ "reference_type": "SELF",
+ "label": "PS的CPU",
+ "default_value": "2000m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_epoch_num": {
+ "reference": "self.variables.epoch_num",
+            "help": "number of epochs for training, not supported in online training",
+ "reference_type": "SELF",
+ "label": "epoch数量",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_code_tar": {
+ "help": "代码包,variable中请使用代码类型",
+ "label": "代码",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_worker_memory": {
+ "reference": "self.variables.worker_mem",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "reference_type": "SELF",
+ "label": "Worker的内存",
+ "default_value": "3Gi",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "reference": "project.variables.storage_root_path",
+ "help": "联邦学习中任务存储根目录",
+ "reference_type": "PROJECT",
+ "label": "存储根目录",
+ "default_value": "/data",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_ps_memory": {
+ "reference": "self.variables.ps_mem",
+            "help": "PS Pod 所分配的内存资源(request和limit一致)",
+ "reference_type": "SELF",
+ "label": "PS的内存",
+ "default_value": "3Gi",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_save_checkpoint_secs": {
+ "reference": "self.variables.save_checkpoint_secs",
+            "help": "int, Number of secs between checkpoints.",
+ "reference_type": "SELF",
+ "label": "SAVE_CHECKPOINT_SECS",
+ "default_value": null,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_online_training": {
+ "help": "['','--online_training'] 否 是,the train master run for online training",
+ "label": "是否在线训练",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "reference": "self.variables.master_cpu",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "reference_type": "SELF",
+ "label": "Master的CPU",
+ "default_value": "2000m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_master_replicas": {
+ "reference": "self.variables.master_replicas",
+ "help": "同时运行的完全相同的Master Pods数量",
+ "reference_type": "SELF",
+ "label": "Master的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_export_path": {
+ "help": "使用默认空值,将把models保存到$OUTPUT_BASE_DIR/exported_models 路径下。",
+ "label": "EXPORT_PATH",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_ps_replicas": {
+ "reference": "self.variables.ps_replicas",
+ "help": "同时运行的完全相同的PS Pods数量",
+ "reference_type": "SELF",
+ "label": "PS的Pod个数",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_verbosity": {
+ "reference": "self.variables.verbosity",
+ "help": "int, Logging level",
+ "reference_type": "SELF",
+ "label": "日志等级",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_summary_save_steps": {
+ "help": "int, Number of steps to save summary files.",
+ "label": "SUMMARY_SAVE_STEPS",
+ "default_value": null,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_load_checkpoint_filename_with_path": {
+ "reference": "self.variables.load_checkpoint_filename_with_path",
+ "help": "加载绝对路径下的checkpoint,需要细致到文件名",
+ "reference_type": "SELF",
+ "label": "从绝对路径加载checkpoint",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_load_checkpoint_from_job": {
+ "reference": "self.variables.load_model_name",
+ "help": "指定任务名job_output下的latest checkpoint",
+ "reference_type": "SELF",
+ "label": "以任务名加载checkpoint",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_worker_cpu": {
+ "reference": "self.variables.worker_cpu",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "reference_type": "SELF",
+ "label": "Worker的CPU",
+ "default_value": "2000m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_sparse_estimator": {
+ "reference": "self.variables.sparse_estimator",
+            "help": "bool, default False. Whether to use sparse estimator.",
+ "reference_type": "SELF",
+ "label": "SPARSE_ESTIMATOR",
+ "default_value": false,
+ "value_type": "BOOL",
+ "default": ""
+ },
+ "Slot_image_version": {
+ "reference": "self.variables.image_version",
+ "help": "建议不修改,指定Pod中运行的容器镜像版本,前缀为system.variables.image_repo + '/fedlearner:'",
+ "reference_type": "SELF",
+ "label": "容器镜像版本",
+ "default_value": "882310f",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_data_source": {
+ "reference": "self.variables.data_source",
+ "help": "必须修改,求交任务的名字",
+ "reference_type": "SELF",
+ "label": "数据源",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_suffle_data_block": {
+ "reference": "self.variables.shuffle_data_block",
+ "help": "['','--shuffle_data_block'] 否 是,shuffle the data block or not",
+ "reference_type": "SELF",
+ "label": "是否shuffle数据块",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_code_key": {
+ "reference": "self.variables.algorithm.path",
+ "help": "代码tar包地址,如果为空则使用code tar",
+ "reference_type": "SELF",
+ "label": "模型代码路径",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "reference": "self.variables.algorithm.config",
+ "help": "数组类型,worker pod额外的环境变量",
+ "reference_type": "SELF",
+ "label": "Worker额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_end_date": {
+ "reference": "self.variables.end_date",
+ "help": "training data end date",
+ "reference_type": "SELF",
+ "label": "结束时间",
+ "default_value": null,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_master_memory": {
+ "reference": "self.variables.master_mem",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "reference_type": "SELF",
+ "label": "Master的内存",
+ "default_value": "3Gi",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_save_checkpoint_steps": {
+ "reference": "self.variables.save_checkpoint_steps",
+ "help": "int, Number of steps between checkpoints.",
+ "reference_type": "SELF",
+ "label": "SAVE_CHECKPOINT_STEPS",
+ "default_value": 1000.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_volumes": {
+ "reference": "system.variables.volumes_list",
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "reference_type": "SYSTEM",
+ "label": "为Pod提供的卷",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_master_envs": {
+ "reference": "self.variables.algorithm.config",
+ "help": "数组类型,master pod额外的环境变量",
+ "reference_type": "SELF",
+ "label": "Master额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "default": ""
+ }
+ },
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"EPOCH_NUM\",\n \"value\": str(${Slot_epoch_num})\n },\n {\n \"name\": \"START_DATE\",\n \"value\": str(${Slot_start_date})\n },\n {\n \"name\": \"END_DATE\",\n \"value\": str(${Slot_end_date})\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"ONLINE_TRAINING\",\n \"value\": ${Slot_online_training}\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"CHECKPOINT_PATH\",\n \"value\": ${Slot_checkpoint_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME\",\n \"value\": ${Slot_load_checkpoint_filename}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_FILENAME_WITH_PATH\",\n \"value\": ${Slot_load_checkpoint_filename_with_path}\n },\n {\n \"name\": \"LOAD_CHECKPOINT_PATH\",\n \"value\": ${Slot_load_checkpoint_from_job} and ${Slot_storage_root_path} + \"/job_output/\" + ${Slot_load_checkpoint_from_job} + \"/checkpoints\"\n },\n {\n \"name\": \"EXPORT_PATH\",\n \"value\": ${Slot_export_path}\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_master_replicas})\n },\n \"PS\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n 
{\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n }\n\n ] + ${Slot_ps_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/trainer/run_trainer_ps.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_ps_cpu},\n \"memory\": ${Slot_ps_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": int(${Slot_ps_replicas})\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"CODE_KEY\",\n \"value\": ${Slot_code_key}\n },\n {\n \"name\": \"CODE_TAR\",\n \"value\": ${Slot_code_tar}\n },\n {\n \"name\": \"SAVE_CHECKPOINT_STEPS\",\n \"value\": str(${Slot_save_checkpoint_steps})\n },\n {\n \"name\": \"SAVE_CHECKPOINT_SECS\",\n \"value\": str(${Slot_save_checkpoint_secs})\n },\n {\n \"name\": \"SPARSE_ESTIMATOR\",\n \"value\": str(${Slot_sparse_estimator})\n },\n {\n \"name\": \"SUMMARY_SAVE_STEPS\",\n \"value\": str(${Slot_summary_save_steps})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_trainer_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_worker_replicas})\n }\n }\n }\n}\n",
+ "variables": []
+ }
+ }
+ },
+ "comment": ""
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-ot-psi-analyzer.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-ot-psi-analyzer.json
new file mode 100644
index 000000000..d60d6626f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-ot-psi-analyzer.json
@@ -0,0 +1,223 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "ot-psi",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "partition-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/partition.py\",\n \"arguments\": [\n \"--input_path=\" + workflow.variables.input_batch_path + '/' + 'part*',\n \"--file_format=\" + workflow.variables.file_format,\n \"--part_key=\" + workflow.variables.part_key,\n \"--part_num=\" + workflow.variables.part_num,\n \"--output_file_format=\" + workflow.variables.file_format,\n \"--output_dir=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1]\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 2,\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "partition-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": true,
+ "job_type": "PSI_DATA_JOIN",
+ "name": "ot-psi",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"fedReplicaSpecs\": {\n \"Worker\": {\n \"backoffLimit\": 5,\n \"port\": {\n \"containerPort\": 32443,\n \"name\": \"flapp-port\"\n },\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"name\": \"psi\",\n \"image\": system.variables.image_repo + \"/pp_lite:\" + system.version,\n \"env\": system.variables.envs_list + system.basic_envs_list + [\n {\n \"name\": \"ROLE\",\n \"value\": workflow.variables.role\n },\n {\n \"name\": \"JOB_TYPE\",\n \"value\": \"psi-ot\"\n },\n {\n \"name\": \"PEER_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"AUTHORITY\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"PEER_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"INPUT_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/ids'\n },\n {\n \"name\": \"OUTPUT_PATH\",\n \"value\": workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined'\n },\n {\n \"name\": \"KEY_COLUMN\",\n \"value\": workflow.variables.part_key\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"METRIC_COLLECTOR_SERVICE_NAME\",\n \"value\": \"pplite_psi\"\n }\n ],\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"resources\": {\n \"limits\": {\n \"cpu\": workflow.variables.worker_cpu,\n \"memory\": workflow.variables.worker_mem\n },\n \"requests\": {\n \"cpu\": workflow.variables.worker_cpu,\n \"memory\": workflow.variables.worker_mem\n }\n },\n \"ports\": [\n {\n \"containerPort\": 32443,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50051,\n \"name\": \"server-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tunnel-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 1212,\n \"name\": \"joiner-port\",\n \"protocol\": \"TCP\"\n }\n ],\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list),\n }\n },\n \"pair\": True,\n \"replicas\": int(workflow.variables.replicas)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "ot-psi"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "feature-extraction",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/feature_extraction_v2.py\",\n \"arguments\": [\n \"--original_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/raw',\n \"--joined_data_path=\" + workflow.variables.output_dataset_path + '/side_output/' + workflow.variables.output_batch_path.split('/')[-1] + '/joined',\n \"--part_key=\" + workflow.variables.part_key,\n \"--part_num=\" + workflow.variables.part_num,\n \"--file_format=\" + workflow.variables.file_format,\n \"--output_file_format=\" + workflow.variables.output_file_format,\n \"--output_batch_name=\" + workflow.variables.output_batch_path.split('/')[-1],\n \"--output_dataset_path=\" + workflow.variables.output_dataset_path\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 2,\n \"memory\": '4g',\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "feature-extraction"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n \"tabular\",\n \"--data_path=\"+ (str(workflow.variables.output_dataset_path) or str(project.variables.storage_root_path) + \"/\" + \"dataset\" + \"/\" + \"\"),\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--buckets_num=\" + str(10),\n \"--thumbnail_path=\" + \"\",\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 1,\n \"coreLimit\": \"1200m\",\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.basic_envs_list + system.variables.envs_list + []\n },\n \"executor\": {\n \"cores\": int(workflow.variables.executor_cores),\n \"instances\": 1,\n \"memory\": workflow.variables.executor_mem,\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": int(workflow.variables.initial_executor_num),\n \"minExecutors\": int(workflow.variables.min_executor_num),\n \"maxExecutors\": int(workflow.variables.max_executor_num)\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_dataset_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "undefined",
+ "value": "undefined",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"DatasetPath\",\"required\":true,\"tooltip\":\"输入数据地址\",\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_key",
+ "tag": "INPUT_PARAM",
+ "typed_value": "raw_id",
+ "value": "raw_id",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交的key\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "part_num",
+ "tag": "INPUT_PARAM",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"分区数量\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_batch_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输出数据batch地址\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "output_dataset_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输出数据集地址\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "role",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "undefined",
+ "value": "undefined",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"tooltip\":\"OtPsi角色\",\"enum\":[\"client\",\"server\"],\"hidden\":false,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "replicas",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "10",
+ "value": "10",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"求交worker数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_cpu",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2000m",
+ "value": "2000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4Gi",
+ "value": "4Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "input_batch_path",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"输入batch地址\",\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_cores",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor核数\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "executor_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4g",
+ "value": "4g",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"executor内存\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "initial_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务初始化executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "min_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2",
+ "value": "2",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最小executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "max_executor_num",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "64",
+ "value": "64",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"spark任务最大executor数量\",\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_format",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "tfrecords",
+ "value": "tfrecords",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"输入数据格式,支持csv或tfrecords\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_file_format",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "tfrecords",
+ "value": "tfrecords",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"输出数据格式,支持csv或tfrecords\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true,\"tag\":\"OUTPUT_PATH\"}"
+ }
+ ]
+ },
+ "editor_info": {},
+ "name": "sys-preset-ot-psi-analyzer",
+ "revision_index": 32
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-psi-data-join-analyzer.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-psi-data-join-analyzer.json
new file mode 100644
index 000000000..e7fc038be
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-psi-data-join-analyzer.json
@@ -0,0 +1,240 @@
+{
+ "comment": "",
+ "config": {
+ "group_alias": "sys_preset_psi_data_join",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"activeDeadlineSeconds\": int(workflow.variables.fedapp_active_ttl),\n \"fedReplicaSpecs\": {\n \"Master\": {\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": str(workflow.variables.data_portal_type)\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": str(workflow.variables.dataset)\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(workflow.variables.file_wildcard)\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + workflow.variables.fedlearner_image_version,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(workflow.variables.master_cpu),\n \"memory\": str(workflow.variables.master_mem)\n },\n \"requests\": {\n \"cpu\": str(workflow.variables.master_cpu),\n \"memory\": str(workflow.variables.master_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": 1\n },\n \"Worker\": {\n \"backoffLimit\": 6,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n 
\"name\": \"BATCH_SIZE\",\n \"value\": str(int(workflow.variables.batch_size))\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + workflow.variables.fedlearner_image_version,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(workflow.variables.raw_worker_cpu),\n \"memory\": str(workflow.variables.raw_worker_mem)\n },\n \"requests\": {\n \"cpu\": str(workflow.variables.raw_worker_cpu),\n \"memory\": str(workflow.variables.raw_worker_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "raw-data-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": true,
+ "job_type": "TRANSFORMER",
+ "name": "psi-data-join-job",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"activeDeadlineSeconds\": int(workflow.variables.fedapp_active_ttl),\n \"fedReplicaSpecs\": {\n \"Master\": {\n \"mustSuccess\": False,\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(0)\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(999999999999)\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job'].name)\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n {\n \"name\": \"DATA_JOIN_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + workflow.variables.fedlearner_image_version,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\":[\"/bin/bash\",\"-c\"],\n \"args\": [\"export PEER_ADDR=$$SERVICE_ID && /app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"],\n \n \"resources\": {\n \"limits\": {\n \"cpu\": str(workflow.variables.master_cpu),\n \"memory\": str(workflow.variables.master_mem)\n },\n \"requests\": {\n \"cpu\": str(workflow.variables.master_cpu),\n \"memory\": str(workflow.variables.master_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": 1\n },\n \"Worker\": {\n \"backoffLimit\": 6,\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(workflow.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": \"--batch_mode\"\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n 
\"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job'].name)\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": str(workflow.variables.rsa_key_pem)\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": str(workflow.variables.rsa_key_path)\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": str(workflow.variables.rsa_key_path)\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": \"data.aml.fl\"\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": str(workflow.variables.output_type)\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": str(workflow.variables.output_type)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(None)\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(None)\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"DATA_JOIN_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + workflow.variables.fedlearner_image_version,\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\":[\"/bin/bash\",\"-c\"],\n \"args\": [\"export WORKER_RANK=$$INDEX && export PEER_ADDR=$$SERVICE_ID && /app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(workflow.variables.psi_worker_cpu),\n \"memory\": str(workflow.variables.psi_worker_mem)\n },\n \"requests\": {\n \"cpu\": str(workflow.variables.psi_worker_cpu),\n \"memory\": str(workflow.variables.psi_worker_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": int(int(workflow.variables.num_partitions))\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "psi-data-join-job"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "TRANSFORMER",
+ "name": "converter",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/converter_v2.py\",\n \"arguments\": [\n \"tabular\",\n \"--output_dataset_path=\" + workflow.variables.output_dataset_path, \n \"--output_batch_path=\" + workflow.variables.output_batch_path,\n \"--input_batch_path=\" + str(project.variables.storage_root_path) + \"/data_source/\" + workflow.jobs['psi-data-join-job'].name + \"/data_block/**/*.data\",\n \"--format=tfrecords\",\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20,\n },\n \"driver\": {\n \"cores\": 1,\n \"coreLimit\": \"1200m\",\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"executor\": {\n \"cores\": 2,\n \"instances\": 2,\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + [],\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": 2,\n \"maxExecutors\": 64,\n \"minExecutors\": 2,\n },\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "converter"
+ }
+ ],
+ "easy_mode": false,
+ "is_federated": false,
+ "job_type": "ANALYZER",
+ "name": "analyzer",
+ "variables": [],
+ "yaml_template": "{\n \"apiVersion\": \"sparkoperator.k8s.io/v1beta2\",\n \"kind\": \"SparkApplication\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner-spark\",\n \"schedulerName\": \"batch\",\n },\n },\n \"spec\": {\n \"type\": \"Python\",\n \"pythonVersion\": \"3\",\n \"mode\": \"cluster\",\n \"image\": (system.variables.get(\"spark_image_repo\") or str(system.variables.image_repo + \"/pp_data_inspection\")) + \":\" + system.version,\n \"imagePullPolicy\": \"IfNotPresent\",\n \"volumes\": list(system.variables.volumes_list),\n \"mainApplicationFile\": \"/opt/spark/work-dir/analyzer_v2.py\",\n \"arguments\": [\n \"tabular\",\n \"--data_path=\" + workflow.variables.output_dataset_path,\n \"--file_wildcard=\" + \"batch/**/**\",\n \"--batch_name=\" + str(workflow.variables.output_batch_name),\n ],\n \"sparkVersion\": \"3.0.0\",\n \"restartPolicy\": {\n \"type\": \"OnFailure\",\n \"onFailureRetries\": 3,\n \"onFailureRetryInterval\": 10,\n \"onSubmissionFailureRetries\": 5,\n \"onSubmissionFailureRetryInterval\": 20\n },\n \"driver\": {\n \"cores\": 1,\n \"coreLimit\": \"1200m\",\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"serviceAccount\": \"spark\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"executor\": {\n \"cores\": 2,\n \"instances\": 2,\n \"memory\": \"4g\",\n \"labels\": {\n \"version\": \"3.0.0\"\n },\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"env\": system.variables.envs_list + []\n },\n \"dynamicAllocation\": {\n \"enabled\": True,\n \"initialExecutors\": 2,\n \"maxExecutors\": 64,\n \"minExecutors\": 2,\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "fedlearner_image_version",
+ "tag": "INPUT_PARAM",
+ "typed_value": "50a6945",
+ "value": "50a6945",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"镜像版本不建议修改,如若修改请使用新于此版本的镜像\",\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "num_partitions",
+ "tag": "INPUT_PARAM",
+ "typed_value": 2.0,
+ "value": "2",
+ "value_type": "NUMBER",
+ "widget_schema": "{\"component\":\"NumberPicker\",\"required\":true,\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "dataset",
+ "tag": "INPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"DatasetPath\",\"required\":true,\"hidden\":false,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "raw_worker_cpu",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "raw_worker_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "8Gi",
+ "value": "8Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_wildcard",
+ "tag": "INPUT_PATH",
+ "typed_value": "*part*",
+ "value": "*part*",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "batch_size",
+ "tag": "INPUT_PARAM",
+ "typed_value": "102400",
+ "value": "102400",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "data_portal_type",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "PSI",
+ "value": "PSI",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"PSI\"],\"hidden\":true,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "role",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "Leader",
+ "value": "Leader",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"],\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_path",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "/app/deploy/integrated_test/rsa_private.key",
+ "value": "/app/deploy/integrated_test/rsa_private.key",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"RSA公钥或私钥的地址,在无RSA_KEY_PEM时必填,私钥需要同时填写rsa_key_path和rsa_private_key_path,且内容一致\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PRIVATE",
+ "name": "rsa_key_pem",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"TextArea\",\"required\":false,\"tooltip\":\"直接输入RSA公钥和私钥,Leader会从中读取私钥,Follower会从中读取公钥。如果为空会使用path读取。\",\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_type",
+ "tag": "OPERATING_PARAM",
+ "typed_value": "TF_RECORD",
+ "value": "TF_RECORD",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"TF_RECORD\"],\"tooltip\":\"输出的datablock的格式,支持csv和tfrecord两种格式\",\"hidden\":true,\"tag\":\"OPERATING_PARAM\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "psi_worker_cpu",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "psi_worker_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "8Gi",
+ "value": "8Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "master_cpu",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "2000m",
+ "value": "2000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "master_mem",
+ "tag": "RESOURCE_ALLOCATION",
+ "typed_value": "4Gi",
+ "value": "4Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_dataset_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"结果数据集的路径\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_path",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"结果数据集的 batch 的路径\",\"tag\":\"OUTPUT_PATH\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "fedapp_active_ttl",
+ "tag": "",
+ "typed_value": "259200",
+ "value": "259200",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"单个分片或求交任务运行最大时间\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_batch_name",
+ "tag": "OUTPUT_PATH",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"OUTPUT_PATH\",\"hidden\":true}"
+ }
+ ]
+ },
+ "editor_info": {},
+ "name": "sys-preset-psi-data-join-analyzer"
+}
\ No newline at end of file
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-psi-data-join.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-psi-data-join.json
new file mode 100644
index 000000000..3d862f073
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-psi-data-join.json
@@ -0,0 +1,778 @@
+{
+ "comment": null,
+ "config": {
+ "group_alias": "sys_preset_psi_data_join",
+ "job_definitions": [
+ {
+ "dependencies": [],
+ "easy_mode": true,
+ "is_federated": false,
+ "job_type": "RAW_DATA",
+ "name": "raw-data-job",
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "dataset",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"DatasetPath\",\"required\":true,\"hidden\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "raw_worker_cpu",
+ "tag": "",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "raw_worker_mem",
+ "tag": "",
+ "typed_value": "8Gi",
+ "value": "8Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "file_wildcard",
+ "tag": "",
+ "typed_value": "*part*",
+ "value": "*part*",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"hidden\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "batch_size",
+ "tag": "",
+ "typed_value": "102400",
+ "value": "102400",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "data_portal_type",
+ "tag": "",
+ "typed_value": "PSI",
+ "value": "PSI",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":false,\"enum\":[\"PSI\"],\"hidden\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "raw_master_cpu",
+ "tag": "",
+ "typed_value": "2000m",
+ "value": "2000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "raw_master_mem",
+ "tag": "",
+ "typed_value": "4Gi",
+ "value": "4Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": str(self.variables.data_portal_type)\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": str(self.variables.dataset)\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(self.variables.file_wildcard)\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": \"\"\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": \"\"\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(None)\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": \"\"\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.raw_master_cpu),\n \"memory\": str(self.variables.raw_master_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.raw_master_cpu),\n \"memory\": str(self.variables.raw_master_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": 
project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": str(int(self.variables.batch_size))\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": \"\"\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(70)\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": \"\"\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n\n\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.raw_worker_cpu),\n \"memory\": str(self.variables.raw_worker_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.raw_worker_cpu),\n \"memory\": str(self.variables.raw_worker_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": False,\n \"replicas\": int(workflow.variables.num_partitions)\n }\n }\n }\n}\n"
+ },
+ {
+ "dependencies": [
+ {
+ "source": "raw-data-job"
+ }
+ ],
+ "easy_mode": true,
+ "is_federated": true,
+ "job_type": "PSI_DATA_JOIN",
+ "name": "psi-data-join-job",
+ "variables": [
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "role",
+ "tag": "",
+ "typed_value": "Leader",
+ "value": "Leader",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"],\"tooltip\":\"\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_path",
+ "tag": "",
+ "typed_value": "/app/deploy/integrated_test/rsa_private.key",
+ "value": "/app/deploy/integrated_test/rsa_private.key",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"RSA公钥或私钥的地址,在无RSA_KEY_PEM时必填,私钥需要同时填写rsa_key_path和rsa_private_key_path,且内容一致\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "rsa_key_pem",
+ "tag": "",
+ "typed_value": "",
+ "value": "",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"TextArea\",\"required\":false,\"tooltip\":\"直接输入RSA公钥和私钥,Leader会从中读取私钥,Follower会从中读取公钥。如果为空会使用path读取。\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "output_type",
+ "tag": "",
+ "typed_value": "TF_RECORD",
+ "value": "TF_RECORD",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"TF_RECORD\"],\"tooltip\":\"输出的datablock的格式,支持csv和tfrecord两种格式\",\"hidden\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_cpu",
+ "tag": "",
+ "typed_value": "4000m",
+ "value": "4000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "worker_mem",
+ "tag": "",
+ "typed_value": "8Gi",
+ "value": "8Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "master_cpu",
+ "tag": "",
+ "typed_value": "2000m",
+ "value": "2000m",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "master_mem",
+ "tag": "",
+ "typed_value": "4Gi",
+ "value": "4Gi",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false}"
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": dict(system.variables.labels)\n },\n \"spec\": {\n \"role\": str(self.variables.role),\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if str(self.variables.role)==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(self.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(0)\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(999999999999)\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job'].name)\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n {\n \"name\": \"DATA_JOIN_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.master_cpu),\n \"memory\": str(self.variables.master_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.master_cpu),\n \"memory\": str(self.variables.master_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": 
str(self.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": \"--batch_mode\"\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(int(workflow.variables.num_partitions))\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + str(workflow.jobs['raw-data-job'].name)\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": str(self.variables.rsa_key_pem)\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": str(self.variables.rsa_key_path)\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": str(self.variables.rsa_key_path)\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": \"\"\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": \"data.aml.fl\"\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": \"TF_RECORD\"\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": str(self.variables.output_type)\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": str(self.variables.output_type)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(-1)\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(4096)\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(None)\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(None)\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(0.0)\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(False)\n },\n {\n \"name\": \"DATA_JOIN_METRICS_SAMPLE_RATE\",\n \"value\": str(\"0\")\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": str(workflow.variables.image),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"pair\": True,\n \"replicas\": int(int(workflow.variables.num_partitions))\n }\n }\n }\n}\n"
+ }
+ ],
+ "variables": [
+ {
+ "access_mode": "PEER_READABLE",
+ "name": "image",
+ "tag": "",
+ "typed_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "value_type": "STRING",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"镜像版本不建议修改,如若修改请使用新于此版本的镜像\"}"
+ },
+ {
+ "access_mode": "PEER_WRITABLE",
+ "name": "num_partitions",
+ "tag": "",
+ "typed_value": 2.0,
+ "value": "2",
+ "value_type": "NUMBER",
+ "widget_schema": "{\"component\":\"NumberPicker\",\"required\":true}"
+ }
+ ]
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "psi-data-join-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"START_TIME\",\n \"value\": str(${Slot_start_time})\n },\n {\n \"name\": \"END_TIME\",\n \"value\": str(${Slot_end_time})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n # not work, remove it after prepare_launch_data_join_cli been removed\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n {\n \"name\": \"DATA_JOIN_METRICS_SAMPLE_RATE\",\n \"value\": str(${Slot_data_join_metrics_sample_rate})\n }\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"BATCH_MODE\",\n \"value\": 
${Slot_batch_mode}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"PARTITION_NUM\",\n \"value\": str(${Slot_partition_num})\n },\n {\n \"name\": \"RAW_DATA_SUB_DIR\",\n \"value\": \"portal_publish_dir/\" + ${Slot_raw_data_name}\n },\n {\n \"name\": \"RSA_KEY_PEM\",\n \"value\": ${Slot_rsa_key_pem}\n },\n {\n \"name\": \"RSA_KEY_PATH\",\n \"value\": ${Slot_rsa_key_path}\n },\n {\n \"name\": \"RSA_PRIVATE_KEY_PATH\",\n \"value\": ${Slot_rsa_key_path}\n },\n {\n \"name\": \"KMS_KEY_NAME\",\n \"value\": ${Slot_kms_key_name}\n },\n {\n \"name\": \"KMS_CLIENT\",\n \"value\": ${Slot_kms_client}\n },\n {\n \"name\": \"PSI_RAW_DATA_ITER\",\n \"value\": ${Slot_psi_raw_data_iter}\n },\n {\n \"name\": \"DATA_BLOCK_BUILDER\",\n \"value\": ${Slot_data_block_builder}\n },\n {\n \"name\": \"PSI_OUTPUT_BUILDER\",\n \"value\": ${Slot_psi_output_builder}\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_INTERVAL\",\n \"value\": str(${Slot_data_block_dump_interval})\n },\n {\n \"name\": \"DATA_BLOCK_DUMP_THRESHOLD\",\n \"value\": str(${Slot_data_block_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_INTERVAL\",\n \"value\": str(${Slot_example_id_dump_interval})\n },\n {\n \"name\": \"EXAMPLE_ID_DUMP_THRESHOLD\",\n \"value\": str(${Slot_example_id_dump_threshold})\n },\n {\n \"name\": \"EXAMPLE_JOINER\",\n \"value\": \"SORT_RUN_JOINER\"\n },\n {\n \"name\": \"PSI_READ_AHEAD_SIZE\",\n \"value\": str(${Slot_psi_read_ahead_size})\n },\n {\n \"name\": \"SORT_RUN_MERGER_READ_AHEAD_BUFFER\",\n \"value\": str(${Slot_run_merger_read_ahead_buffer})\n },\n {\n \"name\": \"NEGATIVE_SAMPLING_RATE\",\n \"value\": str(${Slot_negative_sampling_rate})\n },\n {\n \"name\": \"ENABLE_NEGATIVE_EXAMPLE_GENERATOR\",\n \"value\": str(${Slot_enable_negative_example_generator})\n },\n {\n \"name\": \"DATA_JOIN_METRICS_SAMPLE_RATE\",\n \"value\": str(${Slot_data_join_metrics_sample_rate})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": int(${Slot_partition_num})\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_mode": {
+ "default": "",
+ "default_value": "--batch_mode",
+ "help": "如果为空则为常驻求交",
+ "label": "是否为批处理模式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "data block output数据类型",
+ "reference": "self.variables.output_type",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_data_block_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次data block,小于0则无此限制",
+ "label": "数据dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_block_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多多少个样本就dump为一个data block,小于等于0则无此限制",
+ "label": "数据dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_data_join_metrics_sample_rate": {
+ "default": "",
+ "default_value": "0",
+ "help": "建议不修改,es metrics 取样比例",
+ "label": "metrics_sample_rate",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_enable_negative_example_generator": {
+ "default": "",
+ "default_value": false,
+ "help": "建议不修改,是否开启负采样,当follower求交时遇到无法匹配上的leader的example id,会以negative_sampling_rate为概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "BOOL"
+ },
+ "Slot_end_time": {
+ "default": "",
+ "default_value": 999999999999.0,
+ "help": "建议不修改,使用自这个时间以前的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据末尾时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_interval": {
+ "default": "",
+ "default_value": -1.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump时间间隔",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_example_id_dump_threshold": {
+ "default": "",
+ "default_value": 4096.0,
+ "help": "建议不修改,最多每隔多少时间(实际时间,非样本时间)就dump一次example id,小于0则无此限制",
+ "label": "数据id dump临界点",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模板不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_kms_client": {
+ "default": "",
+ "default_value": "data.aml.fl",
+ "help": "kms client",
+ "label": "kms client",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_kms_key_name": {
+ "default": "",
+ "default_value": "",
+ "help": "kms中的密钥名称,站内镜像需使用KMS",
+ "label": "密钥名称",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "self.variables.master_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "self.variables.master_mem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_negative_sampling_rate": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,负采样比例,当follower求交时遇到无法匹配上的leader的example id,会以此概率生成一个新的样本。",
+ "label": "负采样比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "NUMBER"
+ },
+ "Slot_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "建议修改,求交后数据分区的数量,建议和raw_data一致",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_psi_output_builder": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "PSI output数据类型",
+ "reference": "self.variables.output_type",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_psi_raw_data_iter": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "建议不修改,choices=['TF_RECORD', 'CSV_DICT']",
+ "label": "raw data数据类型",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_psi_read_ahead_size": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, the read ahead size for raw data",
+ "label": "psi_read_ahead_size",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_name": {
+ "default": "",
+ "default_value": "",
+ "help": "必须修改,原始数据的发布地址,根据参数内容在portal_publish_dir地址下寻找",
+ "label": "raw_data名字",
+ "reference": "workflow.jobs['raw-data-job'].name",
+ "reference_type": "JOB_PROPERTY",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "default": "",
+ "default_value": "Leader",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "label": "Flapp通讯时角色",
+ "reference": "self.variables.role",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_path": {
+ "default": "",
+ "default_value": "",
+ "help": "RSA公钥或私钥的地址,在无RSA_KEY_PEM时必填",
+ "label": "RSA钥匙地址",
+ "reference": "self.variables.rsa_key_path",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_rsa_key_pem": {
+ "default": "",
+ "default_value": "",
+ "help": "直接输入RSA公钥和私钥,请使用Textarea,Leader会从中读取私钥,Follower会从中读取公钥。如果为空会使用path读取。",
+ "label": "RSA公钥",
+ "reference": "self.variables.rsa_key_pem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_run_merger_read_ahead_buffer": {
+ "default": "",
+ "default_value": null,
+ "help": "建议不填, sort run merger read ahead buffer",
+ "label": "run_merger_read_ahead_buffer",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_start_time": {
+ "default": "",
+ "default_value": 0.0,
+ "help": "建议不修改,使用自这个时间起的数据,仅从文件名筛选所以格式依据文件名(yyyymmdd或timestamp)",
+ "label": "数据起始时间",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "self.variables.worker_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "self.variables.worker_mem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ },
+ "raw-data-job": {
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": \"Follower\",\n \"peerSpecs\": {\n \"Leader\": {\n \"peerURL\": \"\",\n \"authority\": \"\"\n }\n },\n \"flReplicaSpecs\": {\n \"Master\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_NAME\",\n \"value\": self.name\n },\n {\n \"name\": \"DATA_PORTAL_TYPE\",\n \"value\": ${Slot_data_portal_type}\n },\n {\n \"name\": \"OUTPUT_PARTITION_NUM\",\n \"value\": str(${Slot_output_partition_num})\n },\n {\n \"name\": \"INPUT_BASE_DIR\",\n \"value\": ${Slot_input_base_dir}\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/raw_data/\" + self.name\n },\n {\n \"name\": \"RAW_DATA_PUBLISH_DIR\",\n \"value\": \"portal_publish_dir/\" + self.name\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": ${Slot_file_wildcard}\n },\n {\n \"name\": \"LONG_RUNNING\",\n \"value\": ${Slot_long_running}\n },\n {\n \"name\": \"CHECK_SUCCESS_TAG\",\n \"value\": ${Slot_check_success_tag}\n },\n {\n \"name\": \"FILES_PER_JOB_LIMIT\",\n \"value\": str(${Slot_files_per_job_limit})\n },\n {\n \"name\": \"SINGLE_SUBFOLDER\",\n \"value\": ${Slot_single_subfolder}\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(${Slot_raw_data_metrics_sample_rate})\n }\n\n ] + ${Slot_master_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_master.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_master_cpu},\n \"memory\": ${Slot_master_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": 1\n },\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/data_source/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n\n {\n \"name\": \"BATCH_SIZE\",\n \"value\": 
str(${Slot_batch_size})\n },\n {\n \"name\": \"INPUT_DATA_FORMAT\",\n \"value\": ${Slot_input_data_format}\n },\n {\n \"name\": \"COMPRESSED_TYPE\",\n \"value\": ${Slot_compressed_type}\n },\n {\n \"name\": \"OUTPUT_DATA_FORMAT\",\n \"value\": ${Slot_output_data_format}\n },\n {\n \"name\": \"BUILDER_COMPRESSED_TYPE\",\n \"value\": ${Slot_builder_compressed_type}\n },\n {\n \"name\": \"MEMORY_LIMIT_RATIO\",\n \"value\": str(${Slot_memory_limit_ratio})\n },\n {\n \"name\": \"OPTIONAL_FIELDS\",\n \"value\": ${Slot_optional_fields}\n },\n {\n \"name\": \"RAW_DATA_METRICS_SAMPLE_RATE\",\n \"value\": str(${Slot_raw_data_metrics_sample_rate})\n }\n\n\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": ${Slot_image},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/data_portal/run_data_portal_worker.sh\"\n ],\n \"args\": [\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_memory}\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": False,\n \"replicas\": ${Slot_output_partition_num}\n }\n }\n }\n}\n",
+ "slots": {
+ "Slot_batch_size": {
+ "default": "",
+ "default_value": 1024.0,
+ "help": "原始数据是一批一批的从文件系统中读出来,batch_size为batch的大小",
+ "label": "Batch大小",
+ "reference": "self.variables.batch_size",
+ "reference_type": "SELF",
+ "value_type": "INT"
+ },
+ "Slot_builder_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the format for output file",
+ "label": "输出压缩格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_check_success_tag": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--check_success_tag'] means false and true, Check that a _SUCCESS file exists before processing files in a subfolder",
+ "label": "是否检查成功标志",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_compressed_type": {
+ "default": "",
+ "default_value": "",
+ "help": "choices=['', 'ZLIB', 'GZIP'] the compressed type of input data file",
+ "label": "压缩方式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_data_portal_type": {
+ "default": "",
+ "default_value": "PSI",
+ "help": "运行过一次后修改无效!! the type of data portal type ,choices=['PSI', 'Streaming']",
+ "label": "数据入口类型",
+ "reference": "self.variables.data_portal_type",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_file_wildcard": {
+ "default": "",
+ "default_value": "*.rd",
+ "help": "文件名称的通配符, 将会读取input_base_dir下所以满足条件的文件,如\n1. *.csv,意为读取所有csv格式文件\n2. *.tfrecord,意为读取所有tfrecord格式文件\n3. xxx.txt,意为读取文件名为xxx.txt的文件",
+ "label": "文件名称的通配符",
+ "reference": "self.variables.file_wildcard",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_files_per_job_limit": {
+ "default": "",
+ "default_value": null,
+ "help": "空即不设限制,Max number of files in a job",
+ "label": "每个任务最多文件数",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_image": {
+ "default": "",
+ "default_value": "artifact.bytedance.com/fedlearner/fedlearner:50a6945",
+ "help": "建议不修改,指定Pod中运行的容器镜像地址,修改此项可能导致本基本模板不适用",
+ "label": "容器镜像",
+ "reference": "workflow.variables.image",
+ "reference_type": "WORKFLOW",
+ "value_type": "STRING"
+ },
+ "Slot_input_base_dir": {
+ "default": "",
+ "default_value": "/app/deploy/integrated_test/tfrecord_raw_data",
+ "help": "必须修改,运行过一次后修改无效!!the base dir of input directory",
+ "label": "输入路径",
+ "reference": "self.variables.dataset",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_input_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the type for input data iterator",
+ "label": "输入数据格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "default": "",
+ "default_value": {},
+ "help": "建议不修改,格式: {}",
+ "label": "FLAPP额外元信息",
+ "reference": "system.variables.labels",
+ "reference_type": "SYSTEM",
+ "value_type": "OBJECT"
+ },
+ "Slot_long_running": {
+ "default": "",
+ "default_value": "",
+ "help": "choices: ['','--long_running']否,是。是否为常驻上传原始数据",
+ "label": "是否常驻",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_master_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Master Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Master的CPU",
+ "reference": "self.variables.raw_master_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_master_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,master pod额外的环境变量",
+ "label": "Master额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_master_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Master Pod 所分配的内存资源(request和limit一致)",
+ "label": "Master的内存",
+ "reference": "self.variables.raw_master_mem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_memory_limit_ratio": {
+ "default": "",
+ "default_value": 70.0,
+ "help": "预测是否会OOM的时候用到,如果预测继续执行下去时占用内存会超过这个比例,就阻塞,直到尚未处理的任务处理完成。 注意这是个40-81之间的整数。",
+ "label": "内存限制比例",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "INT"
+ },
+ "Slot_optional_fields": {
+ "default": "",
+ "default_value": "",
+ "help": "optional stat fields used in joiner, separated by comma between fields, e.g. \"label,rit\"Each field will be stripped",
+ "label": "可选字段",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_data_format": {
+ "default": "",
+ "default_value": "TF_RECORD",
+ "help": "choices=['TF_RECORD', 'CSV_DICT'] the format for output file",
+ "label": "输出格式",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_output_partition_num": {
+ "default": "",
+ "default_value": 4.0,
+ "help": "运行过一次后修改无效!!输出数据的文件数量,对应Worker数量",
+ "label": "数据分区的数量",
+ "reference": "workflow.variables.num_partitions",
+ "reference_type": "WORKFLOW",
+ "value_type": "INT"
+ },
+ "Slot_raw_data_metrics_sample_rate": {
+ "default": "",
+ "default_value": "0",
+ "help": "建议不修改,es metrics 取样比例",
+ "label": "metrics_sample_rate",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_single_subfolder": {
+ "default": "",
+ "default_value": "",
+ "help": "choices:['','--single_subfolder'] 否 是,Only process one subfolder at a time",
+ "label": "是否单一子文件夹",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "default": "",
+ "default_value": "/data",
+ "help": "联邦学习中任务存储根目录",
+ "label": "存储根目录",
+ "reference": "project.variables.storage_root_path",
+ "reference_type": "PROJECT",
+ "value_type": "STRING"
+ },
+ "Slot_volume_mounts": {
+ "default": "",
+ "default_value": [
+ {
+ "mountPath": "/data",
+ "name": "data"
+ }
+ ],
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "label": "卷挂载位置",
+ "reference": "system.variables.volume_mounts_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_volumes": {
+ "default": "",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "label": "为Pod提供的卷",
+ "reference": "system.variables.volumes_list",
+ "reference_type": "SYSTEM",
+ "value_type": "LIST"
+ },
+ "Slot_worker_cpu": {
+ "default": "",
+ "default_value": "2000m",
+ "help": "Worker Pod 所分配的CPU资源(request和limit一致)",
+ "label": "Worker的CPU",
+ "reference": "self.variables.raw_worker_cpu",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "default": "",
+ "default_value": [],
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "reference": "",
+ "reference_type": "DEFAULT",
+ "value_type": "LIST"
+ },
+ "Slot_worker_memory": {
+ "default": "",
+ "default_value": "3Gi",
+ "help": "Worker Pod 所分配的内存资源(request和limit一致)",
+ "label": "Worker的内存",
+ "reference": "self.variables.raw_worker_mem",
+ "reference_type": "SELF",
+ "value_type": "STRING"
+ }
+ },
+ "variables": []
+ }
+ }
+ },
+ "group_alias": "sys_preset_psi_data_join",
+ "name": "sys-preset-psi-data-join"
+}
\ No newline at end of file
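Each slot in the preset template above pairs a `reference` path and a `reference_type` (SELF, WORKFLOW, SYSTEM, PROJECT or DEFAULT) with a `default_value` fallback. Purely as an illustration of that layout, the sketch below resolves a slot against a dictionary of those namespaces; `resolve_slot` and the context shape are assumptions for this note, not the web console's actual resolver.

```python
# Hypothetical helper: follow a slot's dotted "reference" path inside the
# matching namespace (self / workflow / system / project) and fall back to
# "default_value" when the slot is DEFAULT-typed or has no reference.
from typing import Any, Dict


def resolve_slot(slot: Dict[str, Any], context: Dict[str, Any]) -> Any:
    reference = slot.get('reference', '')
    if slot.get('reference_type') == 'DEFAULT' or not reference:
        return slot['default_value']
    node: Any = context
    # e.g. "self.variables.file_wildcard" -> context['self']['variables']['file_wildcard']
    for part in reference.split('.'):
        node = node[part]
    return node


# Usage sketch with the Slot_file_wildcard definition from the template above.
slot = {'reference': 'self.variables.file_wildcard', 'reference_type': 'SELF', 'default_value': '*.rd'}
context = {'self': {'variables': {'file_wildcard': '*.tfrecord'}}}
print(resolve_slot(slot, context))  # -> *.tfrecord
```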
diff --git a/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-tree-model.json b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-tree-model.json
new file mode 100644
index 000000000..9598943e7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/sys_preset_templates/sys-preset-tree-model.json
@@ -0,0 +1,592 @@
+{
+ "name": "sys-preset-tree-model",
+ "group_alias": "sys_preset_tree_model",
+ "config": {
+ "group_alias": "sys_preset_tree_model",
+ "job_definitions": [
+ {
+ "name": "tree-model",
+ "job_type": "TREE_MODEL_TRAINING",
+ "is_federated": true,
+ "variables": [
+ {
+ "name": "image_version",
+ "value": "50a6945",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tooltip\":\"镜像版本\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "50a6945",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "mode",
+ "value": "train",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"train\",\"eval\"],\"tag\":\"OPERATING_PARAM\"}",
+ "typed_value": "train",
+ "tag": "OPERATING_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "data_source",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"求交数据集名称\",\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "validation_data_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "file_type",
+ "value": "tfrecord",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"csv\",\"tfrecord\"],\"tooltip\":\"文件类型,csv或tfrecord\",\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "tfrecord",
+ "tag": "INPUT_PATH",
+ "value_type": "STRING"
+ },
+ {
+ "name": "load_model_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"模型文件地址\",\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "loss_type",
+ "value": "logistic",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":false,\"enum\":[\"logistic\",\"mse\"],\"tooltip\":\"损失函数类型,logistic或mse,默认logistic\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "logistic",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "learning_rate",
+ "value": "0.3",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "0.3",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "max_iters",
+ "value": "10",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"树的数量\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "10",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "max_depth",
+ "value": "5",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "5",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "max_bins",
+ "value": "33",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"最大分箱数\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "33",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "l2_regularization",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"L2惩罚系数\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "1",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "num_parallel",
+ "value": "5",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"进程数量\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "5",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "enable_packing",
+ "value": "true",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":false,\"enum\":[\"true\",\"false\"],\"tooltip\":\"是否开启优化\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "true",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "ignore_fields",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"不入模特征,以逗号分隔如:name,age,sex\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "cat_fields",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"类别类型特征,特征的值需要是非负整数。以逗号分隔如:alive,country,sex\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "verify_example_ids",
+ "value": "false",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":false,\"tooltip\":\"是否检查example_id对齐 If set to true, the first column of the data will be treated as example ids that must match between leader and follower\",\"enum\":[\"false\",\"true\"],\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "false",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "verbosity",
+ "value": "1",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":false,\"enum\":[\"0\",\"1\",\"2\"],\"tooltip\":\"日志输出等级\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "1",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "no_data",
+ "value": "false",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":false,\"tooltip\":\"Leader是否没数据,不建议乱用\",\"enum\":[\"false\",\"true\"],\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "false",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_cpu",
+ "value": "8000m",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "8000m",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "worker_mem",
+ "value": "16Gi",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"RESOURCE_ALLOCATION\"}",
+ "typed_value": "16Gi",
+ "tag": "RESOURCE_ALLOCATION",
+ "value_type": "STRING"
+ },
+ {
+ "name": "role",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Select\",\"required\":true,\"enum\":[\"Leader\",\"Follower\"],\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "",
+ "tag": "INPUT_PARAM",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "label_field",
+ "value": "label",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"label特征名\",\"tag\":\"INPUT_PARAM\"}",
+ "typed_value": "label",
+ "tag": "INPUT_PARAM",
+ "value_type": "STRING"
+ },
+ {
+ "name": "load_model_name",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tooltip\":\"按任务名称加载模型,{STORAGE_ROOT_PATH}/job_output/{LOAD_MODEL_NAME}/exported_models\",\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "data_path",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":false,\"tag\":\"INPUT_PATH\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "file_wildcard",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true,\"tag\":\"INPUT_PATH\",\"tooltip\":\"*.data或**/part*\"}",
+ "typed_value": "",
+ "tag": "INPUT_PATH",
+ "value": "",
+ "value_type": "STRING"
+ },
+ {
+ "name": "fedapp_active_ttl",
+ "value": "86400",
+ "access_mode": "PEER_WRITABLE",
+ "widget_schema": "{\"component\":\"Input\",\"required\":true}",
+ "typed_value": "86400",
+ "value_type": "STRING",
+ "tag": ""
+ }
+ ],
+ "yaml_template": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FedApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"labels\": dict(system.variables.labels),\n \"annotations\": {\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\",\n \"min-member\": \"1\",\n \"resource-cpu\": str(self.variables.worker_cpu),\n \"resource-mem\": str(self.variables.worker_mem),\n },\n },\n \"spec\": {\n \"activeDeadlineSeconds\": int(self.variables.fedapp_active_ttl),\n \"fedReplicaSpecs\": {\n \"Worker\": {\n \"backoffLimit\": 6,\n \"port\": # 可以没有,没有就是{containerPort: 50051, name: flapp-port, protocol: TCP}\n { \n \"containerPort\": 50051,\n \"name\": \"flapp-port\"\n },\n \"template\": {\n \"spec\": {\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": str(project.variables.storage_root_path)\n },\n {\n \"name\": \"ROLE\",\n \"value\": str(self.variables.role).lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": str(project.variables.storage_root_path) + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": str(self.variables.mode)\n },\n {\n \"name\": \"LOSS_TYPE\",\n \"value\": str(self.variables.loss_type)\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": str(self.variables.data_source)\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": str(self.variables.data_path)\n },\n {\n \"name\": \"VALIDATION_DATA_PATH\",\n \"value\": str(self.variables.validation_data_path)\n },\n {\n \"name\": \"NO_DATA\",\n \"value\": str(bool(self.variables.no_data))\n },\n {\n \"name\": \"FILE_WILDCARD\",\n \"value\": str(self.variables.file_wildcard)\n },\n {\n \"name\": \"FILE_TYPE\",\n \"value\": str(self.variables.file_type)\n },\n {\n \"name\": \"LOAD_MODEL_PATH\",\n \"value\": str(self.variables.load_model_path)\n },\n {\n \"name\": \"LOAD_MODEL_NAME\",\n \"value\": str(self.variables.load_model_name)\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(int(self.variables.verbosity))\n },\n {\n \"name\": \"LEARNING_RATE\",\n \"value\": str(float(self.variables.learning_rate))\n },\n {\n \"name\": \"MAX_ITERS\",\n \"value\": str(int(self.variables.max_iters))\n },\n {\n \"name\": \"MAX_DEPTH\",\n \"value\": str(int(self.variables.max_depth))\n },\n {\n \"name\": \"MAX_BINS\",\n \"value\": str(int(self.variables.max_bins))\n },\n {\n \"name\": \"L2_REGULARIZATION\",\n \"value\": str(float(self.variables.l2_regularization))\n },\n {\n \"name\": \"NUM_PARALLEL\",\n \"value\": str(int(self.variables.num_parallel))\n },\n {\n \"name\": \"VERIFY_EXAMPLE_IDS\",\n \"value\": str(bool(self.variables.verify_example_ids))\n },\n {\n \"name\": \"IGNORE_FIELDS\",\n \"value\": str(self.variables.ignore_fields)\n },\n {\n \"name\": \"CAT_FIELDS\",\n \"value\": str(self.variables.cat_fields)\n },\n {\n \"name\": \"LABEL_FIELD\",\n \"value\": str(self.variables.label_field)\n },\n {\n \"name\": \"SEND_SCORES_TO_FOLLOWER\",\n \"value\": str(False)\n },\n {\n \"name\": \"SEND_METRICS_TO_FOLLOWER\",\n \"value\": str(False)\n },\n {\n \"name\": \"ENABLE_PACKING\",\n \"value\": str(bool(self.variables.enable_packing))\n },\n {\n 
\"name\": \"ES_BATCH_SIZE\",\n \"value\": str(10)\n },\n {\n \"name\": \"METRIC_COLLECTOR_ENABLE\",\n \"value\": str(True)\n },\n {\n \"name\": \"METRIC_COLLECTOR_SERVICE_NAME\",\n \"value\": \"fedlearner_model\"\n }\n ] + [],\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": list(system.variables.volume_mounts_list),\n \"image\": system.variables.image_repo + \"/fedlearner:\" + str(self.variables.image_version),\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\":[\"/bin/bash\",\"-c\"],\n \"args\": [\"export WORKER_RANK=$$INDEX && export PEER_ADDR=$$SERVICE_ID && /app/deploy/scripts/trainer/run_tree_worker.sh\"],\n \"resources\": {\n \"limits\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n },\n \"requests\": {\n \"cpu\": str(self.variables.worker_cpu),\n \"memory\": str(self.variables.worker_mem)\n }\n }\n }\n ],\n \"imagePullSecrets\": [\n {\n \"name\": \"regcred\"\n }\n ],\n \"volumes\": list(system.variables.volumes_list)\n }\n },\n \"replicas\": 1\n }\n }\n }\n}\n",
+ "dependencies": [],
+ "easy_mode": false
+ }
+ ],
+ "variables": []
+ },
+ "editor_info": {
+ "yaml_editor_infos": {
+ "tree-model": {
+ "slots": {
+ "Slot_send_scores_to_follower": {
+ "help": "是否发送结果到follower",
+ "label": "是否发送结果到follower",
+ "default_value": false,
+ "value_type": "BOOL",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_l2_regularization": {
+ "reference": "self.variables.l2_regularization",
+ "help": "L2惩罚系数",
+ "reference_type": "SELF",
+ "label": "L2惩罚系数",
+ "default_value": 1.0,
+ "value_type": "NUMBER",
+ "default": ""
+ },
+ "Slot_verify_example_ids": {
+ "reference": "self.variables.verify_example_ids",
+ "help": "是否检查example_id对齐 If set to true, the first column of the data will be treated as example ids that must match between leader and follower",
+ "reference_type": "SELF",
+ "label": "是否检查example_id对齐",
+ "default_value": false,
+ "value_type": "BOOL",
+ "default": ""
+ },
+ "Slot_max_iters": {
+ "reference": "self.variables.max_iters",
+ "help": "树的数量",
+ "reference_type": "SELF",
+ "label": "迭代数",
+ "default_value": 5.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_file_type": {
+ "reference": "self.variables.file_type",
+ "help": "文件类型,csv或tfrecord",
+ "reference_type": "SELF",
+ "label": "文件类型,csv或tfrecord",
+ "default_value": "tfrecord",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_load_model_name": {
+ "reference": "self.variables.load_model_name",
+ "help": "按任务名称加载模型,{STORAGE_ROOT_PATH}/job_output/{LOAD_MODEL_NAME}/exported_models",
+ "reference_type": "SELF",
+ "label": "模型任务名称",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_verbosity": {
+ "reference": "self.variables.verbosity",
+ "help": "日志输出等级",
+ "reference_type": "SELF",
+ "label": "日志输出等级",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_load_model_path": {
+ "reference": "self.variables.load_model_path",
+ "help": "模型文件地址",
+ "reference_type": "SELF",
+ "label": "模型文件地址",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_max_bins": {
+ "reference": "self.variables.max_bins",
+ "help": "最大分箱数",
+ "reference_type": "SELF",
+ "label": "最大分箱数",
+ "default_value": 33.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_data_path": {
+ "help": "数据存放位置",
+ "label": "数据存放位置",
+ "default_value": "",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT",
+ "value_type": "STRING"
+ },
+ "Slot_no_data": {
+ "reference": "self.variables.no_data",
+ "help": "Leader是否没数据",
+ "reference_type": "SELF",
+ "label": "Leader是否没数据",
+ "default_value": false,
+ "value_type": "BOOL",
+ "default": ""
+ },
+ "Slot_file_ext": {
+ "reference": "self.variables.undefined",
+ "help": "文件后缀",
+ "reference_type": "SELF",
+ "label": "文件后缀",
+ "default_value": ".data",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_label_field": {
+ "reference": "self.variables.label_field",
+ "help": "label特征名",
+ "reference_type": "SELF",
+ "label": "label特征名",
+ "default_value": "label",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_labels": {
+ "reference": "system.variables.labels",
+ "help": "建议不修改,格式: {}",
+ "reference_type": "SYSTEM",
+ "label": "FLAPP额外元信息",
+ "default_value": {},
+ "value_type": "OBJECT",
+ "default": ""
+ },
+ "Slot_loss_type": {
+ "reference": "self.variables.loss_type",
+ "help": "损失函数类型,logistic或mse,默认logistic",
+ "reference_type": "SELF",
+ "label": "损失函数类型",
+ "default_value": "logistic",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_enable_packing": {
+ "reference": "self.variables.enable_packing",
+ "help": "是否开启优化",
+ "reference_type": "SELF",
+ "label": "是否开启优化",
+ "default_value": true,
+ "value_type": "BOOL",
+ "default": ""
+ },
+ "Slot_worker_cpu": {
+ "reference": "self.variables.worker_cpu",
+ "help": "所需CPU",
+ "reference_type": "SELF",
+ "label": "所需CPU",
+ "default_value": "8000m",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_ignore_fields": {
+ "reference": "self.variables.ignore_fields",
+ "help": "以逗号分隔如:name,age,sex",
+ "reference_type": "SELF",
+ "label": "不入模的特征",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_image_version": {
+ "reference": "self.variables.image_version",
+ "help": "建议不修改,指定Pod中运行的容器镜像版本,前缀为system.variables.image_repo + '/fedlearner:'",
+ "reference_type": "SELF",
+ "label": "容器镜像版本",
+ "default_value": "882310f",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_storage_root_path": {
+ "reference": "project.variables.storage_root_path",
+ "help": "联邦学习中任务存储根目录",
+ "reference_type": "PROJECT",
+ "label": "存储根目录",
+ "default_value": "/data",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_num_parallel": {
+ "reference": "self.variables.num_parallel",
+ "help": "进程数量",
+ "reference_type": "SELF",
+ "label": "进程数量",
+ "default_value": 1.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_validation_data_path": {
+ "reference": "self.variables.validation_data_path",
+ "help": "验证数据集地址",
+ "reference_type": "SELF",
+ "label": "验证数据集地址",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_worker_envs": {
+ "help": "数组类型,worker pod额外的环境变量",
+ "label": "Worker额外环境变量",
+ "default_value": [],
+ "value_type": "LIST",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_send_metrics_to_follower": {
+ "help": "是否发送指标到follower",
+ "label": "是否发送指标到follower",
+ "default_value": false,
+ "value_type": "BOOL",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_max_depth": {
+ "reference": "self.variables.max_depth",
+ "help": "最大深度",
+ "reference_type": "SELF",
+ "label": "最大深度",
+ "default_value": 3.0,
+ "value_type": "INT",
+ "default": ""
+ },
+ "Slot_volume_mounts": {
+ "reference": "system.variables.volume_mounts_list",
+ "help": "建议不修改,容器中卷挂载的位置,数组类型",
+ "reference_type": "SYSTEM",
+ "label": "卷挂载位置",
+ "default_value": [
+ {
+ "name": "data",
+ "mountPath": "/data"
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_mode": {
+ "reference": "self.variables.mode",
+ "help": "任务类型,train或eval",
+ "reference_type": "SELF",
+ "label": "任务类型,train或eval",
+ "default_value": "train",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_learning_rate": {
+ "reference": "self.variables.learning_rate",
+ "help": "学习率",
+ "reference_type": "SELF",
+ "label": "学习率",
+ "default_value": 0.3,
+ "value_type": "NUMBER",
+ "default": ""
+ },
+ "Slot_worker_mem": {
+ "reference": "self.variables.worker_mem",
+ "help": "所需内存",
+ "reference_type": "SELF",
+ "label": "所需内存",
+ "default_value": "16Gi",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_es_batch_size": {
+ "help": "ES_BATCH_SIZE",
+ "label": "ES_BATCH_SIZE",
+ "default_value": 10.0,
+ "value_type": "INT",
+ "reference": "",
+ "default": "",
+ "reference_type": "DEFAULT"
+ },
+ "Slot_data_source": {
+ "reference": "self.variables.data_source",
+ "help": "求交数据集名称",
+ "reference_type": "SELF",
+ "label": "求交数据集名称",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_role": {
+ "reference": "self.variables.role",
+ "help": "Flapp 通讯时的角色 Leader 或 Follower",
+ "reference_type": "SELF",
+ "label": "Flapp通讯时角色",
+ "default_value": "Leader",
+ "default": "",
+ "value_type": "STRING"
+ },
+ "Slot_volumes": {
+ "reference": "system.variables.volumes_list",
+ "help": "建议不修改,数组类型,和volume_mounts一一对应",
+ "reference_type": "SYSTEM",
+ "label": "为Pod提供的卷",
+ "default_value": [
+ {
+ "name": "data",
+ "persistentVolumeClaim": {
+ "claimName": "pvc-fedlearner-default"
+ }
+ }
+ ],
+ "value_type": "LIST",
+ "default": ""
+ },
+ "Slot_cat_fields": {
+ "reference": "self.variables.cat_fields",
+ "help": "类别类型特征,特征的值需要是非负整数。以逗号分隔如:alive,country,sex",
+ "reference_type": "SELF",
+ "label": "类别类型特征",
+ "default_value": "",
+ "default": "",
+ "value_type": "STRING"
+ }
+ },
+ "meta_yaml": "{\n \"apiVersion\": \"fedlearner.k8s.io/v1alpha1\",\n \"kind\": \"FLApp\",\n \"metadata\": {\n \"name\": self.name,\n \"namespace\": system.variables.namespace,\n \"annotations\":{\n \"queue\": \"fedlearner\",\n \"schedulerName\": \"batch\"\n },\n \"labels\": ${Slot_labels}\n },\n \"spec\": {\n \"role\": ${Slot_role},\n \"cleanPodPolicy\": \"All\",\n \"peerSpecs\": {\n \"Leader\" if ${Slot_role}==\"Follower\" else \"Follower\": {\n \"peerURL\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\",\n \"authority\": project.participants[0].egress_host,\n \"extraHeaders\": {\n \"x-host\": \"fedlearner-operator.\" + project.participants[0].egress_domain\n }\n }\n },\n \"flReplicaSpecs\": {\n \"Worker\": {\n \"template\": {\n \"spec\": {\n \"restartPolicy\": \"Never\",\n \"containers\": [\n {\n \"env\": system.basic_envs_list + [\n {\n \"name\": \"STORAGE_ROOT_PATH\",\n \"value\": ${Slot_storage_root_path}\n },\n {\n \"name\": \"ROLE\",\n \"value\": ${Slot_role}.lower()\n },\n {\n \"name\": \"APPLICATION_ID\",\n \"value\": self.name\n },\n {\n \"name\": \"OUTPUT_BASE_DIR\",\n \"value\": ${Slot_storage_root_path} + \"/job_output/\" + self.name\n },\n {\n \"name\": \"EGRESS_URL\",\n \"value\": \"fedlearner-stack-ingress-nginx-controller.default.svc:80\"\n },\n {\n \"name\": \"EGRESS_HOST\",\n \"value\": project.participants[0].egress_host\n },\n {\n \"name\": \"EGRESS_DOMAIN\",\n \"value\": project.participants[0].egress_domain\n },\n {\n \"name\": \"MODE\",\n \"value\": ${Slot_mode}\n },\n {\n \"name\": \"LOSS_TYPE\",\n \"value\": ${Slot_loss_type}\n },\n {\n \"name\": \"DATA_SOURCE\",\n \"value\": ${Slot_data_source}\n },\n {\n \"name\": \"DATA_PATH\",\n \"value\": ${Slot_data_path}\n },\n {\n \"name\": \"VALIDATION_DATA_PATH\",\n \"value\": ${Slot_validation_data_path}\n },\n {\n \"name\": \"NO_DATA\",\n \"value\": str(${Slot_no_data})\n },\n {\n \"name\": \"FILE_EXT\",\n \"value\": ${Slot_file_ext}\n },\n {\n \"name\": \"FILE_TYPE\",\n \"value\": ${Slot_file_type}\n },\n {\n \"name\": \"LOAD_MODEL_PATH\",\n \"value\": ${Slot_load_model_path}\n },\n {\n \"name\": \"LOAD_MODEL_NAME\",\n \"value\": ${Slot_load_model_name}\n },\n {\n \"name\": \"VERBOSITY\",\n \"value\": str(${Slot_verbosity})\n },\n {\n \"name\": \"LEARNING_RATE\",\n \"value\": str(${Slot_learning_rate})\n },\n {\n \"name\": \"MAX_ITERS\",\n \"value\": str(${Slot_max_iters})\n },\n {\n \"name\": \"MAX_DEPTH\",\n \"value\": str(${Slot_max_depth})\n },\n {\n \"name\": \"MAX_BINS\",\n \"value\": str(${Slot_max_bins})\n },\n {\n \"name\": \"L2_REGULARIZATION\",\n \"value\": str(${Slot_l2_regularization})\n },\n {\n \"name\": \"NUM_PARALLEL\",\n \"value\": str(${Slot_num_parallel})\n },\n {\n \"name\": \"VERIFY_EXAMPLE_IDS\",\n \"value\": str(${Slot_verify_example_ids})\n },\n {\n \"name\": \"IGNORE_FIELDS\",\n \"value\": ${Slot_ignore_fields}\n },\n {\n \"name\": \"CAT_FIELDS\",\n \"value\": ${Slot_cat_fields}\n },\n {\n \"name\": \"LABEL_FIELD\",\n \"value\": ${Slot_label_field}\n },\n {\n \"name\": \"SEND_SCORES_TO_FOLLOWER\",\n \"value\": str(${Slot_send_scores_to_follower})\n },\n {\n \"name\": \"SEND_METRICS_TO_FOLLOWER\",\n \"value\": str(${Slot_send_metrics_to_follower})\n },\n {\n \"name\": \"ENABLE_PACKING\",\n \"value\": str(${Slot_enable_packing})\n },\n {\n \"name\": \"ES_BATCH_SIZE\",\n \"value\": str(${Slot_es_batch_size})\n }\n ] + ${Slot_worker_envs},\n \"imagePullPolicy\": \"IfNotPresent\",\n \"name\": \"tensorflow\",\n \"volumeMounts\": ${Slot_volume_mounts},\n \"image\": 
system.variables.image_repo + \"/fedlearner:\" + ${Slot_image_version},\n \"ports\": [\n {\n \"containerPort\": 50051,\n \"name\": \"flapp-port\",\n \"protocol\": \"TCP\"\n },\n {\n \"containerPort\": 50052,\n \"name\": \"tf-port\",\n \"protocol\": \"TCP\"\n }\n ],\n \"command\": [\n \"/app/deploy/scripts/wait4pair_wrapper.sh\"\n ],\n \"args\": [\n \"/app/deploy/scripts/trainer/run_tree_worker.sh\"\n ],\n \"resources\": {\n \"limits\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_mem}\n },\n \"requests\": {\n \"cpu\": ${Slot_worker_cpu},\n \"memory\": ${Slot_worker_mem}\n }\n }\n }\n ],\n \"volumes\": ${Slot_volumes}\n }\n },\n \"pair\": True,\n \"replicas\": 1\n }\n }\n }\n}\n",
+ "variables": []
+ }
+ }
+ },
+ "comment": ""
+}
\ No newline at end of file
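In the tree-model template above, `meta_yaml` carries `${Slot_*}` placeholders that are filled from the slot table, whereas the job definition's `yaml_template` is the already-substituted form that reads `self.variables.*` directly. The snippet below is a minimal sketch of that substitution step, assuming placeholders are replaced textually before the template expression is evaluated; `render_meta_yaml` is a hypothetical helper, not the console's real rendering code.

```python
# Hypothetical sketch: replace every ${Slot_xxx} placeholder in a meta_yaml
# string with the slot's configured (or default) value, serialized as a JSON
# literal so strings stay quoted and lists/objects keep their structure.
import json
import re


def render_meta_yaml(meta_yaml: str, slot_values: dict) -> str:

    def _sub(match):
        return json.dumps(slot_values[match.group(1)])

    return re.sub(r'\$\{(Slot_\w+)\}', _sub, meta_yaml)


# Usage sketch with two slots taken from the template above.
snippet = '"MAX_DEPTH": str(${Slot_max_depth}), "MODE": ${Slot_mode}'
print(render_meta_yaml(snippet, {'Slot_max_depth': 3, 'Slot_mode': 'train'}))
# -> "MAX_DEPTH": str(3), "MODE": "train"
```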
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/tee/BUILD.bazel
new file mode 100644
index 000000000..e67137801
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/BUILD.bazel
@@ -0,0 +1,268 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "services_lib",
+ srcs = [
+ "services.py",
+ "tee_job_template.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:data_path_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = [
+ "utils.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "utils_lib_test",
+ size = "small",
+ srcs = [
+ "utils_test.py",
+ ],
+ imports = ["../.."],
+ main = "utils_test.py",
+ deps = [
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ ],
+)
+
+py_library(
+ name = "controller_lib",
+ srcs = [
+ "controller.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":services_lib",
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:job_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:system_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:transaction_manager_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ ],
+)
+
+py_test(
+ name = "controller_lib_test",
+ size = "small",
+ srcs = [
+ "controller_test.py",
+ ],
+ imports = ["../.."],
+ main = "controller_test.py",
+ deps = [
+ ":controller_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "runners_lib",
+ srcs = [
+ "runners.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":controller_lib",
+ ":models_lib",
+ ":utils_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "runners_lib_test",
+ size = "small",
+ srcs = [
+ "runners_test.py",
+ ],
+ imports = ["../.."],
+ main = "runners_test.py",
+ deps = [
+ ":models_lib",
+ ":runners_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":controller_lib",
+ ":models_lib",
+ ":services_lib",
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:data_path_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/flag:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/rpc/v2:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/__init__.py b/web_console_v2/api/fedlearner_webconsole/tee/__init__.py
new file mode 100644
index 000000000..e69de29bb
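The `apis_lib` target above bundles `apis.py`, whose `flask_restful` resources follow next. Purely as a hypothetical sketch of how such resources are typically mounted, the registration below uses URL rules inferred from the handlers' `project_id`/`group_id`/`trusted_job_id` arguments; the real routes are defined elsewhere in the web console and are not part of this diff.

```python
# Hypothetical wiring of the tee resources onto a flask_restful Api; the
# paths here are assumptions for illustration only.
from flask import Flask
from flask_restful import Api

from fedlearner_webconsole.tee.apis import (LaunchTrustedJobApi, TrustedJobApi, TrustedJobGroupApi,
                                            TrustedJobGroupsApi, TrustedJobsApi)

app = Flask(__name__)
api = Api(app)
api.add_resource(TrustedJobGroupsApi, '/api/v2/projects/<int:project_id>/trusted_job_groups')
api.add_resource(TrustedJobGroupApi, '/api/v2/projects/<int:project_id>/trusted_job_groups/<int:group_id>')
api.add_resource(LaunchTrustedJobApi,
                 '/api/v2/projects/<int:project_id>/trusted_job_groups/<int:group_id>:launch')
api.add_resource(TrustedJobsApi, '/api/v2/projects/<int:project_id>/trusted_jobs')
api.add_resource(TrustedJobApi, '/api/v2/projects/<int:project_id>/trusted_jobs/<int:trusted_job_id>')
```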
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/apis.py b/web_console_v2/api/fedlearner_webconsole/tee/apis.py
new file mode 100644
index 000000000..8c46b9fd6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/apis.py
@@ -0,0 +1,834 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Optional
+from flask_restful import Resource
+from http import HTTPStatus
+from marshmallow import Schema, fields, post_load, validate
+from webargs.flaskparser import use_kwargs
+from google.protobuf.json_format import ParseDict
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+from fedlearner_webconsole.utils.decorators.pp_flask import input_validator
+from fedlearner_webconsole.utils.flask_utils import make_flask_response, FilterExpField, get_current_user
+from fedlearner_webconsole.utils.filtering import SupportedField, FieldType, FilterBuilder
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.proto.audit_pb2 import Event
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterOp
+from fedlearner_webconsole.proto.review_pb2 import TicketType, TicketDetails
+from fedlearner_webconsole.exceptions import ResourceConflictException, InvalidArgumentException, NoAccessException, \
+ InternalException, NotFoundException
+from fedlearner_webconsole.tee.models import TrustedJobGroup, GroupCreateStatus, TrustedJob, \
+ TrustedJobStatus, TrustedJobType
+from fedlearner_webconsole.tee.controller import TrustedJobGroupController, launch_trusted_job, stop_trusted_job, \
+ get_tee_enabled_participants, TrustedJobController
+from fedlearner_webconsole.tee.services import TrustedJobGroupService, TrustedJobService
+from fedlearner_webconsole.tee.utils import get_project, get_algorithm, get_dataset, get_participant, \
+ get_trusted_job_group, get_trusted_job, get_algorithm_with_uuid
+from fedlearner_webconsole.proto.tee_pb2 import Resource as ResourcePb, ParticipantDatasetList
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+
+
+class ParticipantDatasetParams(Schema):
+ participant_id = fields.Integer(required=True)
+ uuid = fields.Str(required=True)
+ name = fields.Str(required=True)
+
+
+class ResourceParams(Schema):
+ cpu = fields.Integer(required=True)
+ memory = fields.Integer(required=True)
+ replicas = fields.Integer(required=True)
+
+
+class CreateTrustedJobGroupParams(Schema):
+ name = fields.Str(required=True)
+ comment = fields.Str(required=False, load_default=None)
+ # TODO(liuledian): remove algorithm_id after frontend completed
+ algorithm_id = fields.Integer(required=False, load_default=None)
+ algorithm_uuid = fields.Str(required=False, load_default=None)
+ dataset_id = fields.Integer(required=False, load_default=None)
+ participant_datasets = fields.List(fields.Nested(ParticipantDatasetParams), required=False, load_default=None)
+ resource = fields.Nested(ResourceParams, required=True)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['resource'] = ParseDict(data['resource'], ResourcePb())
+ data['participant_datasets'] = ParseDict({'items': data['participant_datasets']}, ParticipantDatasetList())
+ return data
+
+
+class ConfigTrustedJobGroupParams(Schema):
+ comment = fields.Str(required=False, load_default=None)
+ auth_status = fields.Str(required=False,
+ load_default=None,
+ validate=validate.OneOf([AuthStatus.PENDING.name, AuthStatus.AUTHORIZED.name]))
+ # TODO(liuledian): remove algorithm_id after frontend completed
+ algorithm_id = fields.Integer(required=False, load_default=None)
+ algorithm_uuid = fields.Str(required=False, load_default=None)
+ resource = fields.Nested(ResourceParams, required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ if data['resource'] is not None:
+ data['resource'] = ParseDict(data['resource'], ResourcePb())
+ if data['auth_status'] is not None:
+ data['auth_status'] = AuthStatus[data['auth_status']]
+ return data
+
+
+class TrustedJobGroupsApi(Resource):
+
+ FILTER_FIELDS = {
+ 'name': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ }
+
+ def __init__(self):
+ self._filter_builder = FilterBuilder(model_class=TrustedJobGroup, supported_fields=self.FILTER_FIELDS)
+
+ @credentials_required
+ @use_kwargs(
+ {
+ 'page': fields.Integer(required=False, load_default=None),
+ 'page_size': fields.Integer(required=False, load_default=None),
+ 'filter_exp': FilterExpField(data_key='filter', required=False, load_default=None),
+ },
+ location='query')
+ def get(
+ self,
+ page: Optional[int],
+ page_size: Optional[int],
+ filter_exp: Optional[FilterExpression],
+ project_id: int,
+ ):
+ """Get the list of trusted job groups
+ ---
+ tags:
+ - tee
+ description: get the list of trusted job groups
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: filter
+ schema:
+ type: string
+ responses:
+ 200:
+ description: the list of trusted job groups
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobGroupRef'
+ 400:
+ description: invalid argument
+ 403:
+ description: the trusted job group is forbidden to access
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ # TODO(liuledian): filter out groups in notification
+ query = session.query(TrustedJobGroup).filter(TrustedJobGroup.resource.isnot(None)).order_by(
+ TrustedJobGroup.created_at.desc())
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ if filter_exp:
+ try:
+ query = self._filter_builder.build_query(query, filter_exp)
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ pagination = paginate(query, page, page_size)
+ data = [d.to_ref() for d in pagination.get_items()]
+ session.commit()
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.TRUSTED_JOB_GROUP, op_type=Event.OperationType.CREATE)
+ @use_kwargs(CreateTrustedJobGroupParams(), location='json')
+ def post(self, name: str, comment: Optional[str], algorithm_id: Optional[int], algorithm_uuid: Optional[str],
+ dataset_id: Optional[int], participant_datasets: ParticipantDatasetList, resource: ResourcePb,
+ project_id: int):
+ """Create a trusted job group
+ ---
+ tags:
+ - tee
+ description: create a trusted job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/CreateTrustedJobGroupParams'
+ responses:
+ 201:
+ description: the detail of the trusted job group
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobGroupPb'
+ 400:
+ description: invalid argument
+ 403:
+ description: the trusted job group is forbidden to create
+ 409:
+ description: the trusted job group already exists
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ if dataset_id is None and not participant_datasets.items:
+ raise InvalidArgumentException('dataset_id and participant_datasets are both missing')
+ with db.session_scope() as session:
+ project = get_project(session, project_id)
+ # TODO(liuledian): remove algorithm_id logic after frontend completed
+ if not algorithm_uuid:
+ algorithm_uuid = get_algorithm(session, algorithm_id).uuid
+ algorithm = get_algorithm_with_uuid(project_id, algorithm_uuid)
+ if algorithm.type != AlgorithmType.TRUSTED_COMPUTING.name:
+ raise InvalidArgumentException(f'algorithm {algorithm_uuid} invalid type')
+ if dataset_id is not None:
+ dataset = get_dataset(session, dataset_id)
+ if not dataset.is_published:
+ raise InvalidArgumentException(f'dataset {dataset.id} not published')
+ for pd in participant_datasets.items:
+ get_participant(session, pd.participant_id)
+ group = session.query(TrustedJobGroup).filter_by(name=name, project_id=project_id).first()
+ if group is not None:
+ raise ResourceConflictException(f'trusted job group {name} in project {project_id} already exists')
+ # TODO(liuledian): let creator assign analyzer id
+ enabled_pids = get_tee_enabled_participants(session, project_id)
+ if len(enabled_pids) != 1:
+ raise InternalException('tee enabled participants not valid')
+ analyzer_id = enabled_pids[0]
+
+ with db.session_scope() as session:
+ group = TrustedJobGroup(
+ name=name,
+ uuid=resource_uuid(),
+ latest_version=0,
+ comment=comment,
+ project_id=project.id,
+ creator_username=get_current_user().username,
+ coordinator_id=0,
+ analyzer_id=analyzer_id,
+ auth_status=AuthStatus.AUTHORIZED,
+ algorithm_uuid=algorithm_uuid,
+ dataset_id=dataset_id,
+ )
+ participants = ParticipantService(session).get_participants_by_project(project.id)
+ group.set_unauth_participant_ids([p.id for p in participants])
+ group.set_resource(resource)
+ group.set_participant_datasets(participant_datasets)
+ session.add(group)
+ get_ticket_helper(session).create_ticket(TicketType.TK_CREATE_TRUSTED_JOB_GROUP,
+ TicketDetails(uuid=group.uuid))
+ session.commit()
+ return make_flask_response(data=group.to_proto(), status=HTTPStatus.CREATED)
+
+
+class TrustedJobGroupApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, group_id: int):
+ """Get the trusted job group
+ ---
+ tags:
+ - tee
+        description: get the trusted job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the trusted job group
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobGroupPb'
+ 403:
+            description: the trusted job group is forbidden to access
+ 404:
+ description: trusted job group is not found
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ group = get_trusted_job_group(session, project_id, group_id)
+ try:
+ TrustedJobGroupController(session, project_id).update_unauth_participant_ids(group)
+ data = group.to_proto()
+ algorithm = AlgorithmFetcher(project_id).get_algorithm(group.algorithm_uuid)
+ data.algorithm_project_uuid = algorithm.algorithm_project_uuid
+ data.algorithm_participant_id = algorithm.participant_id
+ except InternalException:
+ logging.warning(f'[trusted-job-group] group {group_id} update unauth_participant_ids failed')
+ except NotFoundException:
+ logging.warning(f'[trusted-job-group] group {group_id} fetch algorithm {group.algorithm_uuid} failed')
+ session.commit()
+ return make_flask_response(data)
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.TRUSTED_JOB_GROUP, op_type=Event.OperationType.UPDATE)
+ @use_kwargs(ConfigTrustedJobGroupParams(), location='json')
+ def put(self, comment: Optional[str], auth_status: Optional[AuthStatus], algorithm_id: Optional[int],
+ algorithm_uuid: Optional[str], resource: Optional[ResourcePb], project_id: int, group_id: int):
+ """Update the trusted job group
+ ---
+ tags:
+ - tee
+ description: update the trusted job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/ConfigTrustedJobGroupParams'
+ responses:
+ 200:
+ description: update the trusted job group successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobGroupPb'
+ 400:
+ description: invalid argument
+ 403:
+ description: the trusted job group is forbidden to update
+ 404:
+ description: trusted job group is not found
+ 409:
+ description: the trusted job group has not been fully created
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ group = get_trusted_job_group(session, project_id, group_id)
+ controller = TrustedJobGroupController(session, project_id)
+ if group.status != GroupCreateStatus.SUCCEEDED:
+ raise ResourceConflictException('the trusted job group has not been fully created')
+ if comment is not None:
+ group.comment = comment
+ if auth_status is not None and auth_status != group.auth_status:
+ controller.inform_trusted_job_group(group, auth_status)
+ if algorithm_uuid or algorithm_id:
+ if group.coordinator_id:
+ raise NoAccessException('only coordinator can update algorithm')
+ # TODO(liuledian): remove after frontend completed
+ if not algorithm_uuid:
+ algorithm_uuid = get_algorithm(session, algorithm_id).uuid
+ algorithm = get_algorithm_with_uuid(project_id, algorithm_uuid)
+ old_algorithm = get_algorithm_with_uuid(project_id, group.algorithm_uuid)
+ if algorithm.algorithm_project_uuid != old_algorithm.algorithm_project_uuid:
+ raise InvalidArgumentException('algorithm project mismatch between old and new algorithm')
+ controller.update_trusted_job_group(group, algorithm_uuid)
+ if resource is not None:
+ group.set_resource(resource)
+ data = group.to_proto()
+ session.commit()
+ return make_flask_response(data)
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.TRUSTED_JOB_GROUP, op_type=Event.OperationType.DELETE)
+ def delete(self, project_id: int, group_id: int):
+ """Delete the trusted job group
+ ---
+ tags:
+ - tee
+ description: delete the trusted job group
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 204:
+ description: delete the trusted job group successfully
+ 403:
+ description: the trusted job group is forbidden to delete
+ 409:
+ description: the trusted job group cannot be deleted
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).filter_by(project_id=project_id, id=group_id).first()
+ if group is not None:
+ if group.coordinator_id:
+ raise NoAccessException('only creator can delete the trusted job group')
+ if not group.is_deletable():
+ raise ResourceConflictException('the trusted job group cannot be deleted')
+ TrustedJobGroupController(session, project_id).delete_trusted_job_group(group)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class LaunchTrustedJobApi(Resource):
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.TRUSTED_JOB, op_type=Event.OperationType.LAUNCH)
+ @use_kwargs({'comment': fields.Str(required=False, load_default=None)}, location='json')
+ def post(self, comment: Optional[str], project_id: int, group_id: int):
+ """Launch the trusted job
+ ---
+ tags:
+ - tee
+ description: launch the trusted job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: group_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+          required: false
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ comment:
+ type: string
+ responses:
+ 201:
+ description: launch the trusted job successfully
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobPb'
+ 403:
+ description: the trusted job is forbidden to launch
+ 404:
+ description: trusted job group is not found
+ 409:
+ description: the trusted job group is not fully created or authorized
+ 500:
+ description: internal exception
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ group = get_trusted_job_group(session, project_id, group_id)
+ if (group.status != GroupCreateStatus.SUCCEEDED or group.get_unauth_participant_ids() or
+ group.auth_status != AuthStatus.AUTHORIZED):
+ raise ResourceConflictException('the trusted job group is not fully created or authorized')
+ group = TrustedJobGroupService(session).lock_and_update_version(group_id)
+ session.commit()
+ succeeded, msg = launch_trusted_job(project_id, group.uuid, group.latest_version)
+ if not succeeded:
+ raise InternalException(f'launching trusted job failed with message: {msg}')
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).filter_by(trusted_job_group_id=group_id,
+ version=group.latest_version).first()
+ trusted_job.comment = comment
+ session.commit()
+ return make_flask_response(trusted_job.to_proto(), status=HTTPStatus.CREATED)
+
+
+class GetTrustedJobsParams(Schema):
+ trusted_job_group_id = fields.Integer(required=True)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
+ trusted_job_type = fields.Str(required=False,
+ data_key='type',
+ load_default=TrustedJobType.ANALYZE.name,
+ validate=validate.OneOf([TrustedJobType.ANALYZE.name, TrustedJobType.EXPORT.name]))
+
+ @post_load()
+ def make(self, data, **kwargs):
+ if data['trusted_job_type'] is not None:
+ data['trusted_job_type'] = TrustedJobType[data['trusted_job_type']]
+ return data
+
+
+class TrustedJobsApi(Resource):
+
+ @credentials_required
+ @use_kwargs(GetTrustedJobsParams(), location='query')
+    def get(self, trusted_job_group_id: int, page: Optional[int], page_size: Optional[int], trusted_job_type: TrustedJobType,
+ project_id: int):
+ """Get the list of trusted jobs
+ ---
+ tags:
+ - tee
+ description: get the list of trusted jobs
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: query
+            name: trusted_job_group_id
+ schema:
+ type: integer
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: type
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of trusted jobs
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobRef'
+ 403:
+ description: trusted job list is forbidden to access
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ query = session.query(TrustedJob).filter_by(type=trusted_job_type)
+            # when listing export jobs, exclude those still pending authorization since they appear in notifications instead
+ if trusted_job_type == TrustedJobType.EXPORT:
+ query = query.filter(TrustedJob.auth_status != AuthStatus.PENDING)
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ if trusted_job_group_id:
+ query = query.filter_by(trusted_job_group_id=trusted_job_group_id)
+ if trusted_job_type == TrustedJobType.ANALYZE:
+ query = query.order_by(TrustedJob.version.desc())
+ else:
+                # a tee export job shares the version of its corresponding tee analyze job, so sort by creation time instead
+ query = query.order_by(TrustedJob.created_at.desc())
+ pagination = paginate(query, page, page_size)
+ data = [d.to_ref() for d in pagination.get_items()]
+ session.commit()
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+
+class UpdateTrustedJobParams(Schema):
+ comment = fields.Str(required=False, load_default=None)
+ auth_status = fields.Str(required=False,
+ load_default=None,
+ validate=validate.OneOf([AuthStatus.AUTHORIZED.name, AuthStatus.WITHDRAW.name]))
+
+ @post_load()
+ def make(self, data, **kwargs):
+ if data['auth_status'] is not None:
+ data['auth_status'] = AuthStatus[data['auth_status']]
+ return data
+
+
+class TrustedJobApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int, trusted_job_id: int):
+ """Get the trusted job by id
+ ---
+ tags:
+ - tee
+ description: get the trusted job by id
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: trusted_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of the trusted job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobPb'
+ 403:
+ description: the trusted job is forbidden to access
+ 404:
+ description: the trusted job is not found
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ trusted_job = get_trusted_job(session, project_id, trusted_job_id)
+ if trusted_job.type == TrustedJobType.EXPORT:
+ TrustedJobController(session, project_id).update_participants_info(trusted_job)
+ data = trusted_job.to_proto()
+ session.commit()
+ return make_flask_response(data)
+
+ @input_validator
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.TRUSTED_JOB, op_type=Event.OperationType.UPDATE)
+ @use_kwargs(UpdateTrustedJobParams(), location='json')
+ def put(self, comment: str, auth_status: AuthStatus, project_id: int, trusted_job_id: int):
+ """Update the trusted job
+ ---
+ tags:
+ - tee
+ description: update the trusted job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: trusted_job_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/UpdateTrustedJobParams'
+ responses:
+ 200:
+            description: detail of the trusted job
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedJobPb'
+ 403:
+ description: the trusted job is forbidden to update
+ 404:
+ description: the trusted job is not found
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ trusted_job = get_trusted_job(session, project_id, trusted_job_id)
+ if comment is not None:
+ trusted_job.comment = comment
+ if auth_status is not None:
+ TrustedJobController(session, project_id).inform_auth_status(trusted_job, auth_status)
+ data = trusted_job.to_proto()
+ session.commit()
+ return make_flask_response(data)
+
+
+class StopTrustedJobApi(Resource):
+
+ @credentials_required
+ def post(self, project_id: int, trusted_job_id: int):
+ """Stop the trusted job
+ ---
+ tags:
+ - tee
+ description: stop the trusted job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: trusted_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 204:
+ description: stop the trusted job successfully
+ 403:
+ description: the trusted job is forbidden to stop
+ 404:
+ description: the trusted job is not found
+ 409:
+ description: the trusted job is not running
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
+ trusted_job = get_trusted_job(session, project_id, trusted_job_id)
+ if trusted_job.get_status() != TrustedJobStatus.RUNNING:
+ raise ResourceConflictException(f'the trusted job {trusted_job.id} is not running')
+ succeeded, msg = stop_trusted_job(project_id, trusted_job.uuid)
+ if not succeeded:
+ raise InternalException(f'stop trusted job failed with msg {msg}')
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+class TrustedNotificationsApi(Resource):
+
+ @credentials_required
+ def get(self, project_id: int):
+ """Get the list of trusted notifications
+ ---
+ tags:
+ - tee
+ description: get the list of trusted notifications
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: list of trusted notifications
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.TrustedNotification'
+ 403:
+ description: trusted notification is forbidden to access
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
+ with db.session_scope() as session:
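+            # notifications come from two sources: groups whose resource has not been configured locally yet
+            # and export jobs that are still pending authorization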
+ query = session.query(TrustedJobGroup).filter_by(resource=None)
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ data = [d.to_notification() for d in query.all()]
+ query = session.query(TrustedJob).filter_by(auth_status=AuthStatus.PENDING, type=TrustedJobType.EXPORT)
+ if project_id:
+ query = query.filter_by(project_id=project_id)
+ data += [d.to_notification() for d in query.all()]
+ data.sort(key=lambda x: x.created_at, reverse=True)
+ return make_flask_response(data)
+
+
+class ExportTrustedJobApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.TRUSTED_EXPORT_JOB, op_type=Event.OperationType.CREATE)
+ def post(self, project_id: int, trusted_job_id: int):
+ """Export the trusted job
+ ---
+ tags:
+ - tee
+ description: export the trusted job
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ required: true
+ - in: path
+ name: trusted_job_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 204:
+ description: export the trusted job successfully
+ 403:
+ description: the trusted job is forbidden to export
+ 404:
+ description: the trusted job is not found
+ 409:
+ description: the trusted job is not succeeded
+ """
+ if not Flag.TRUSTED_COMPUTING_ENABLED.value:
+ raise NoAccessException('trusted computing is not enabled')
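+        # step 1: validate the tee analyze job and bump its export_count under a row lock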
+ with db.session_scope() as session:
+ trusted_job = get_trusted_job(session, project_id, trusted_job_id)
+ if trusted_job.type != TrustedJobType.ANALYZE or trusted_job.get_status() != TrustedJobStatus.SUCCEEDED:
+ raise ResourceConflictException(f'the trusted job {trusted_job.id} is not valid')
+ trusted_job = TrustedJobService(session).lock_and_update_export_count(trusted_job_id)
+ session.commit()
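+        # step 2: create the internal export job and its review ticket in a fresh transaction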
+ with db.session_scope() as session:
+ uuid = resource_uuid()
+ TrustedJobService(session).create_internal_export(uuid, trusted_job)
+ get_ticket_helper(session).create_ticket(TicketType.TK_CREATE_TRUSTED_EXPORT_JOB, TicketDetails(uuid=uuid))
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+
+def initialize_tee_apis(api):
+    api.add_resource(TrustedJobGroupsApi, '/projects/<int:project_id>/trusted_job_groups')
+    api.add_resource(TrustedJobGroupApi, '/projects/<int:project_id>/trusted_job_groups/<int:group_id>')
+    api.add_resource(LaunchTrustedJobApi, '/projects/<int:project_id>/trusted_job_groups/<int:group_id>:launch')
+    api.add_resource(TrustedJobsApi, '/projects/<int:project_id>/trusted_jobs')
+    api.add_resource(TrustedJobApi, '/projects/<int:project_id>/trusted_jobs/<int:trusted_job_id>')
+    api.add_resource(StopTrustedJobApi, '/projects/<int:project_id>/trusted_jobs/<int:trusted_job_id>:stop')
+    api.add_resource(TrustedNotificationsApi, '/projects/<int:project_id>/trusted_notifications')
+    api.add_resource(ExportTrustedJobApi, '/projects/<int:project_id>/trusted_jobs/<int:trusted_job_id>:export')
+
+ schema_manager.append(CreateTrustedJobGroupParams)
+ schema_manager.append(ConfigTrustedJobGroupParams)
+ schema_manager.append(GetTrustedJobsParams)
+ schema_manager.append(UpdateTrustedJobParams)
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/apis_test.py b/web_console_v2/api/fedlearner_webconsole/tee/apis_test.py
new file mode 100644
index 000000000..dc6a17869
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/apis_test.py
@@ -0,0 +1,832 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import urllib.parse
+from unittest.mock import patch, MagicMock
+from datetime import datetime
+
+import grpc
+from google.protobuf.text_format import MessageToString
+from google.protobuf.json_format import MessageToDict
+from google.protobuf.empty_pb2 import Empty
+from testing.common import BaseTestCase
+from testing.rpc.client import FakeRpcError
+from http import HTTPStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, AlgorithmType
+from fedlearner_webconsole.dataset.models import Dataset, DataBatch
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob, TrustedJobStatus, \
+ GroupCreateStatus, TrustedJobType
+from fedlearner_webconsole.proto.tee_pb2 import ParticipantDataset, ParticipantDatasetList, Resource
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.job.models import JobType, Job, JobState
+from fedlearner_webconsole.proto.rpc.v2 import system_service_pb2
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import GetTrustedJobGroupResponse
+from fedlearner_webconsole.setting.service import SettingService
+
+
+class TrustedJobGroupsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ participant2 = Participant(id=2, name='part3', domain_name='fl-domain3.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ proj_part2 = ProjectParticipant(project_id=1, participant_id=2)
+ dataset1 = Dataset(id=1, name='dataset-name1', uuid='dataset-uuid1', is_published=True)
+ dataset2 = Dataset(id=2, name='dataset-name2', uuid='dataset-uuid2', is_published=False)
+ algorithm = Algorithm(id=1,
+ uuid='algorithm-uuid1',
+ algorithm_project_id=1,
+ type=AlgorithmType.TRUSTED_COMPUTING)
+ algorithm_proj = AlgorithmProject(id=1, uuid='algorithm-proj-uuid')
+ resource = MessageToString(Resource(cpu=1, memory=1, replicas=1))
+ group1 = TrustedJobGroup(name='g1',
+ project_id=1,
+ coordinator_id=0,
+ created_at=datetime(2021, 1, 1, 0, 0, 1),
+ resource=resource)
+ group2 = TrustedJobGroup(name='g2-filter',
+ project_id=1,
+ coordinator_id=0,
+ created_at=datetime(2021, 1, 1, 0, 0, 2),
+ resource=resource)
+ group3 = TrustedJobGroup(name='g3',
+ project_id=2,
+ coordinator_id=0,
+ created_at=datetime(2021, 1, 1, 0, 0, 3),
+ resource=resource)
+ group4 = TrustedJobGroup(name='g4-filter',
+ project_id=1,
+ coordinator_id=0,
+ created_at=datetime(2021, 1, 1, 0, 0, 4),
+ resource=resource)
+ with db.session_scope() as session:
+ session.add_all([
+ project, participant1, participant2, proj_part1, proj_part2, dataset1, dataset2, algorithm,
+ algorithm_proj
+ ])
+ session.add_all([group1, group2, group3, group4])
+ session.commit()
+
+ def test_get_trusted_groups(self):
+ # get with project id 1
+ resp = self.get_helper('/api/v2/projects/1/trusted_job_groups')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g4-filter', 'g2-filter', 'g1'])
+ # get with project id 0
+ resp = self.get_helper('/api/v2/projects/0/trusted_job_groups')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g4-filter', 'g3', 'g2-filter', 'g1'])
+ # get with filter
+ filter_param = urllib.parse.quote('(name~="filter")')
+ resp = self.get_helper(f'/api/v2/projects/1/trusted_job_groups?filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g4-filter', 'g2-filter'])
+ # get with page
+ resp = self.get_helper(f'/api/v2/projects/1/trusted_job_groups?page=2&page_size=1&filter={filter_param}')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['g2-filter'])
+ # get nothing for invalid project id
+ resp = self.get_helper('/api/v2/projects/2/trusted_job_groups?page=2&page_size=1')
+ self.assertEqual(self.get_response_data(resp), [])
+
+ @patch('fedlearner_webconsole.rpc.v2.system_service_client.SystemServiceClient.check_tee_enabled')
+ def test_post_trusted_job_groups(self, mock_client: MagicMock):
+ mock_client.side_effect = [
+ system_service_pb2.CheckTeeEnabledResponse(tee_enabled=True),
+ system_service_pb2.CheckTeeEnabledResponse(tee_enabled=False),
+ ]
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups',
+ data={
+ 'name': 'group-name',
+ 'comment': 'This is a comment.',
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'dataset_id': 1,
+ 'participant_datasets': [{
+ 'participant_id': 1,
+ 'uuid': 'dataset-uuid3',
+ 'name': 'dataset-name3',
+ }],
+ 'resource': {
+ 'cpu': 2,
+ 'memory': 2,
+ 'replicas': 1,
+ },
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).filter_by(name='group-name', project_id=1).first()
+ self.assertEqual(group.name, 'group-name')
+ self.assertEqual(group.latest_version, 0)
+ self.assertEqual(group.comment, 'This is a comment.')
+ self.assertEqual(group.project_id, 1)
+ self.assertEqual(group.coordinator_id, 0)
+ self.assertEqual(group.analyzer_id, 1)
+ self.assertEqual(group.ticket_status, TicketStatus.APPROVED)
+ self.assertEqual(group.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(group.unauth_participant_ids, '1,2')
+ self.assertEqual(group.algorithm_uuid, 'algorithm-uuid1')
+ self.assertEqual(group.resource, MessageToString(Resource(cpu=2, memory=2, replicas=1)))
+ self.assertEqual(group.dataset_id, 1)
+ participant_datasets = ParticipantDatasetList(
+ items=[ParticipantDataset(
+ participant_id=1,
+ uuid='dataset-uuid3',
+ name='dataset-name3',
+ )])
+ self.assertEqual(group.participant_datasets, MessageToString(participant_datasets))
+
+ @patch('fedlearner_webconsole.algorithm.fetcher.AlgorithmFetcher.get_algorithm_from_participant')
+ @patch('fedlearner_webconsole.tee.apis.get_tee_enabled_participants')
+ def test_post_trusted_job_groups_failed(self, mock_get_tee_enabled_participants: MagicMock,
+ mock_get_algorithm: MagicMock):
+ mock_get_algorithm.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'not found')
+ mock_get_tee_enabled_participants.return_value = [0]
+ resource = {'cpu': 2, 'memory': 2, 'replicas': 1}
+ # fail due to no dataset is provided
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups',
+ data={
+ 'name': 'group-name',
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'resource': resource,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # fail due to dataset not found
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups',
+ data={
+ 'name': 'group-name',
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'dataset_id': 20,
+ 'resource': resource,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # fail due to dataset not published
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups',
+ data={
+ 'name': 'group-name',
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'dataset_id': 2,
+ 'resource': resource,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # fail due to participant not found
+ resp = self.post_helper(
+ '/api/v2/projects/1/trusted_job_groups',
+ data={
+ 'name': 'group-name',
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'resource': resource,
+ 'participant_datasets': [{
+ 'participant_id': 10,
+ 'uuid': 'dataset-uuid3',
+ 'name': 'dataset-name3',
+ }],
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # fail due to algorithm not found
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups',
+ data={
+ 'name': 'group-name',
+ 'algorithm_uuid': 'algorithm-uuid10',
+ 'resource': resource,
+ 'dataset_id': 2,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # fail due to duplicate name in project
+ with db.session_scope() as session:
+ group = TrustedJobGroup(name='group-name', project_id=1)
+ session.add(group)
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups',
+ data={
+ 'name': 'group-name',
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'resource': resource,
+ 'dataset_id': 1,
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+
+
+class TrustedJobGroupApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ algorithm_proj1 = AlgorithmProject(id=1, uuid='algo-proj-uuid1')
+ algorithm1 = Algorithm(id=1, uuid='algorithm-uuid1', algorithm_project_id=1)
+ participant = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ proj_part = ProjectParticipant(project_id=1, participant_id=1)
+ group1 = TrustedJobGroup(id=1,
+ name='group-name',
+ uuid='uuid',
+ comment='this is a comment',
+ project_id=1,
+ creator_username='admin',
+ coordinator_id=0,
+ created_at=datetime(2022, 7, 1, 0, 0, 0),
+ updated_at=datetime(2022, 7, 1, 0, 0, 0),
+ ticket_status=TicketStatus.APPROVED,
+ status=GroupCreateStatus.SUCCEEDED,
+ auth_status=AuthStatus.AUTHORIZED,
+ unauth_participant_ids='1,2',
+ algorithm_uuid='algorithm-uuid1')
+ group1.set_resource(Resource(cpu=2, memory=2, replicas=1))
+ group1.set_participant_datasets(
+ ParticipantDatasetList(
+ items=[ParticipantDataset(participant_id=1, name='dataset-name', uuid='dataset-uuid')]))
+ group2 = TrustedJobGroup(id=2,
+ name='group-name2',
+ uuid='uuid2',
+ project_id=1,
+ coordinator_id=1,
+ ticket_status=TicketStatus.APPROVED,
+ status=GroupCreateStatus.SUCCEEDED,
+ auth_status=AuthStatus.AUTHORIZED,
+ algorithm_uuid='algorithm-uuid1')
+ session.add_all([project, group1, group2, algorithm_proj1, algorithm1, participant, proj_part])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_trusted_job_group')
+ def test_get_trusted_job_group(self, mock_client: MagicMock):
+ mock_client.return_value = GetTrustedJobGroupResponse(auth_status='AUTHORIZED')
+ resp = self.get_helper('/api/v2/projects/1/trusted_job_groups/1')
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ data.pop('updated_at')
+ self.assertEqual(
+ data, {
+ 'id': 1,
+ 'name': 'group-name',
+ 'comment': 'this is a comment',
+ 'analyzer_id': 0,
+ 'coordinator_id': 0,
+ 'created_at': 1656633600,
+ 'creator_username': 'admin',
+ 'dataset_id': 0,
+ 'latest_job_status': 'NEW',
+ 'latest_version': 0,
+ 'project_id': 1,
+ 'resource': {
+ 'cpu': 2,
+ 'memory': 2,
+ 'replicas': 1
+ },
+ 'status': 'SUCCEEDED',
+ 'algorithm_id': 1,
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'algorithm_project_uuid': 'algo-proj-uuid1',
+ 'algorithm_participant_id': 0,
+ 'auth_status': 'AUTHORIZED',
+ 'ticket_auth_status': 'AUTH_PENDING',
+ 'ticket_status': 'APPROVED',
+ 'ticket_uuid': '',
+ 'unauth_participant_ids': [2],
+ 'uuid': 'uuid',
+ 'participant_datasets': {
+ 'items': [{
+ 'participant_id': 1,
+ 'name': 'dataset-name',
+ 'uuid': 'dataset-uuid'
+ }]
+ },
+ })
+ # failed due to not found
+ resp = self.get_helper('/api/v2/projects/1/trusted_job_groups/10')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ # get nothing due to project invalid
+ resp = self.get_helper('/api/v2/projects/10/trusted_job_groups/1')
+ self.assertIsNone(self.get_response_data(resp))
+
+ @patch('fedlearner_webconsole.algorithm.fetcher.AlgorithmFetcher.get_algorithm_from_participant')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.update_trusted_job_group')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_trusted_job_group')
+ def test_put_trusted_job_group(self, mock_inform: MagicMock, mock_update: MagicMock, mock_get_algorithm: MagicMock):
+ mock_get_algorithm.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'not found')
+ mock_inform.return_value = None
+ mock_update.return_value = None
+ with db.session_scope() as session:
+ algorithm_proj2 = AlgorithmProject(id=2, uuid='algo-proj-uuid2')
+ algorithm2 = Algorithm(id=2, uuid='algorithm-uuid2', algorithm_project_id=1)
+ algorithm3 = Algorithm(id=3, uuid='algorithm-uuid3', algorithm_project_id=2)
+ session.add_all([algorithm_proj2, algorithm2, algorithm3])
+ session.commit()
+ resp = self.put_helper('/api/v2/projects/1/trusted_job_groups/1',
+ data={
+ 'comment': 'new comment',
+ 'auth_status': 'PENDING',
+ 'algorithm_uuid': 'algorithm-uuid2',
+ 'resource': {
+ 'cpu': 4,
+ 'memory': 4,
+ 'replicas': 1
+ }
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ self.assertEqual(group.comment, 'new comment')
+ self.assertEqual(group.auth_status, AuthStatus.PENDING)
+ self.assertEqual(group.algorithm_uuid, 'algorithm-uuid2')
+ self.assertEqual(group.resource, MessageToString(Resource(cpu=4, memory=4, replicas=1)))
+ # failed due to group not found
+ resp = self.put_helper('/api/v2/projects/1/trusted_job_groups/10', data={'comment': 'new comment'})
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ # failed due to not creator but update algorithm
+ resp = self.put_helper('/api/v2/projects/1/trusted_job_groups/2', data={'algorithm_uuid': 'algorithm-uuid2'})
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+ # failed due to algorithm not found
+ resp = self.put_helper('/api/v2/projects/1/trusted_job_groups/1',
+ data={'algorithm_uuid': 'algorithm-not-exist'})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # failed due to algorithm project mismatch
+ resp = self.put_helper('/api/v2/projects/1/trusted_job_groups/1', data={'algorithm_uuid': 'algorithm-uuid3'})
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+ # failed due to group not fully created
+ with db.session_scope() as session:
+ group3 = TrustedJobGroup(id=3, project_id=1, status=GroupCreateStatus.PENDING)
+ session.add(group3)
+ session.commit()
+ resp = self.put_helper('/api/v2/projects/1/trusted_job_groups/3', data={'comment': 'new comment'})
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+ # failed due to grpc error, inconsistency in participants
+ mock_update.side_effect = FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT, 'mismatched algorithm project')
+ resp = self.put_helper('/api/v2/projects/1/trusted_job_groups/1', data={'algorithm_uuid': 'algorithm-uuid1'})
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.delete_trusted_job_group')
+ def test_delete_trusted_job_group(self, mock_delete: MagicMock):
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ mock_delete.return_value = None
+ # fail due to trusted job is running
+ with db.session_scope() as session:
+ trusted_job1 = TrustedJob(id=1,
+ name='V1',
+ trusted_job_group_id=1,
+ job_id=1,
+ status=TrustedJobStatus.RUNNING)
+ job1 = Job(id=1, name='job-name1', job_type=JobType.CUSTOMIZED, workflow_id=0, project_id=1)
+ trusted_job2 = TrustedJob(id=2,
+ name='V2',
+ trusted_job_group_id=1,
+ job_id=2,
+ status=TrustedJobStatus.SUCCEEDED)
+ job2 = Job(id=2, name='job-name2', job_type=JobType.CUSTOMIZED, workflow_id=0, project_id=1)
+ session.add_all([trusted_job1, job1, trusted_job2, job2])
+ session.commit()
+ resp = self.delete_helper('/api/v2/projects/1/trusted_job_groups/1')
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+ # fail due to grpc err
+ with db.session_scope() as session:
+ session.query(TrustedJob).filter_by(id=1).update({'status': TrustedJobStatus.FAILED})
+ session.commit()
+ mock_delete.side_effect = FakeRpcError(grpc.StatusCode.FAILED_PRECONDITION, 'trusted job is not deletable')
+ resp = self.delete_helper('/api/v2/projects/1/trusted_job_groups/1')
+ self.assertEqual(resp.status_code, HTTPStatus.INTERNAL_SERVER_ERROR)
+ # fail due to not creator
+ resp = self.delete_helper('/api/v2/projects/1/trusted_job_groups/2')
+ self.assertEqual(resp.status_code, HTTPStatus.FORBIDDEN)
+        # successfully delete a group that does not exist
+ resp = self.delete_helper('/api/v2/projects/1/trusted_job_groups/3')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ # successfully delete
+ mock_delete.side_effect = None
+ resp = self.delete_helper('/api/v2/projects/1/trusted_job_groups/1')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ self.assertIsNone(session.query(TrustedJobGroup).get(1))
+ self.assertIsNone(session.query(TrustedJob).get(1))
+ self.assertIsNone(session.query(TrustedJob).get(2))
+ self.assertIsNone(session.query(Job).get(1))
+ self.assertIsNone(session.query(Job).get(2))
+
+
+class LaunchTrustedJobApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-name')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ algorithm = Algorithm(id=1,
+ uuid='algorithm-uuid1',
+ path='file:///data/algorithm/test',
+ type=AlgorithmType.TRUSTED_COMPUTING)
+ dataset1 = Dataset(id=1, name='dataset-name1', uuid='dataset-uuid1', is_published=True)
+ data_batch1 = DataBatch(id=1, dataset_id=1)
+ group = TrustedJobGroup(id=1,
+ uuid='group-uuid',
+ project_id=1,
+ latest_version=1,
+ coordinator_id=0,
+ status=GroupCreateStatus.SUCCEEDED,
+ auth_status=AuthStatus.AUTHORIZED,
+ algorithm_uuid='algorithm-uuid1',
+ dataset_id=1,
+ resource=MessageToString(Resource(cpu=2000, memory=2, replicas=1)))
+ session.add_all([project, participant1, proj_part1, algorithm, dataset1, data_batch1, group])
+ sys_var = SettingService(session).get_system_variables_dict()
+ session.commit()
+ sys_var['sgx_image'] = 'artifact.bytedance.com/fedlearner/pp_bioinformatics:e13eb8a1d96ad046ca7354b8197d41fd'
+ self.sys_var = sys_var
+
+ @patch('fedlearner_webconsole.tee.services.get_batch_data_path')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_variables_dict')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_launch_trusted_job(self, mock_remote_do_two_pc: MagicMock, mock_get_system_info: MagicMock,
+ mock_get_system_variables_dict: MagicMock, mock_get_batch_data_path: MagicMock):
+ mock_remote_do_two_pc.return_value = True, ''
+ mock_get_system_info.return_value = SystemInfo(domain_name='domain1')
+ mock_get_system_variables_dict.return_value = self.sys_var
+ mock_get_batch_data_path.return_value = 'file:///data/test'
+ # successful
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups/1:launch', data={'comment': 'this is a comment'})
+ self.assertEqual(resp.status_code, HTTPStatus.CREATED)
+ with db.session_scope() as session:
+ trusted_job = session.query(TrustedJob).filter_by(trusted_job_group_id=1, version=2).first()
+ self.assertIsNotNone(trusted_job)
+ self.assertEqual(trusted_job.comment, 'this is a comment')
+ self.assertEqual(trusted_job.coordinator_id, 0)
+ # fail due to not found group
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups/10:launch', data={})
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ # fail due to not fully auth
+ with db.session_scope() as session:
+ session.query(TrustedJobGroup).filter_by(id=1).update({'coordinator_id': 0, 'unauth_participant_ids': '1'})
+ session.commit()
+ resp = self.post_helper('/api/v2/projects/1/trusted_job_groups/1:launch', data={})
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+
+
+class TrustedJobsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ with db.session_scope() as session:
+ trusted_job1 = TrustedJob(id=1,
+ name='V1',
+ version=1,
+ project_id=1,
+ trusted_job_group_id=1,
+ job_id=1,
+ status=TrustedJobStatus.RUNNING)
+ trusted_job2 = TrustedJob(id=2,
+ name='V2',
+ version=2,
+ project_id=1,
+ trusted_job_group_id=1,
+ job_id=2,
+ status=TrustedJobStatus.SUCCEEDED)
+ trusted_job3 = TrustedJob(id=3,
+ name='V1',
+ version=1,
+ project_id=1,
+ trusted_job_group_id=2,
+ job_id=3,
+ status=TrustedJobStatus.RUNNING)
+ trusted_job4 = TrustedJob(id=4,
+ name='V1',
+ version=1,
+ project_id=2,
+ trusted_job_group_id=3,
+ job_id=4,
+ status=TrustedJobStatus.RUNNING)
+ trusted_job5 = TrustedJob(id=5,
+ name='V1-1',
+ type=TrustedJobType.EXPORT,
+ auth_status=AuthStatus.AUTHORIZED,
+ version=1,
+ project_id=1,
+ trusted_job_group_id=1,
+ job_id=5,
+ status=TrustedJobStatus.NEW,
+ created_at=datetime(2022, 11, 23, 12, 0, 0))
+ trusted_job6 = TrustedJob(id=6,
+ name='V2-1',
+ type=TrustedJobType.EXPORT,
+ auth_status=AuthStatus.WITHDRAW,
+ version=2,
+ project_id=1,
+ trusted_job_group_id=1,
+ job_id=6,
+ status=TrustedJobStatus.CREATED,
+ created_at=datetime(2022, 11, 23, 12, 0, 1))
+ job1 = Job(id=1,
+ name='job-name1',
+ job_type=JobType.CUSTOMIZED,
+ workflow_id=0,
+ project_id=1,
+ state=JobState.FAILED)
+ session.add_all([trusted_job1, trusted_job2, trusted_job3, trusted_job4, trusted_job5, trusted_job6, job1])
+ session.commit()
+
+ def test_get_trusted_job(self):
+ # successful and trusted job status is refreshed when api is called
+ resp = self.get_helper('/api/v2/projects/1/trusted_jobs?trusted_job_group_id=1')
+ data = self.get_response_data(resp)
+ self.assertEqual([(d['name'], d['status']) for d in data], [('V2', 'SUCCEEDED'), ('V1', 'FAILED')])
+
+ def test_get_export_trusted_job(self):
+ resp = self.get_helper('/api/v2/projects/1/trusted_jobs?trusted_job_group_id=1&type=EXPORT')
+ data = self.get_response_data(resp)
+ self.assertEqual([(d['name'], d['status']) for d in data], [('V2-1', 'CREATED'), ('V1-1', 'NEW')])
+
+
+class TrustedJobApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ self.participants_info = ParticipantsInfo(participants_map={
+ 'domain1': ParticipantInfo(auth_status='AUTHORIZED'),
+ 'domain2': ParticipantInfo(auth_status='PENDING'),
+ })
+ trusted_job1 = TrustedJob(id=1,
+ name='V1',
+ type=TrustedJobType.EXPORT,
+ job_id=1,
+ uuid='uuid1',
+ version=1,
+ comment='this is a comment',
+ project_id=1,
+ trusted_job_group_id=1,
+ status=TrustedJobStatus.PENDING,
+ auth_status=AuthStatus.AUTHORIZED,
+ algorithm_uuid='algorithm-uuid1',
+ created_at=datetime(2022, 6, 14, 0, 0, 0),
+ updated_at=datetime(2022, 6, 14, 0, 0, 1),
+ participants_info=MessageToString(self.participants_info))
+ job1 = Job(id=1,
+ name='job-name1',
+ job_type=JobType.CUSTOMIZED,
+ workflow_id=0,
+ project_id=1,
+ state=JobState.STARTED)
+ session.add_all([project, participant1, proj_part1, trusted_job1, job1])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_trusted_job')
+ def test_get_trusted_job(self, mock_client: MagicMock):
+ mock_client.return_value = GetTrustedJobGroupResponse(auth_status='AUTHORIZED')
+ # successful
+ resp = self.get_helper('/api/v2/projects/1/trusted_jobs/1')
+ data = self.get_response_data(resp)
+ part_info_dict = MessageToDict(
+ self.participants_info,
+ preserving_proto_field_name=True,
+ including_default_value_fields=True,
+ )
+ part_info_dict['participants_map']['domain2']['auth_status'] = 'AUTHORIZED'
+ del data['updated_at']
+ self.assertEqual(
+ data, {
+ 'algorithm_id': 0,
+ 'algorithm_uuid': 'algorithm-uuid1',
+ 'comment': 'this is a comment',
+ 'auth_status': 'AUTHORIZED',
+ 'export_dataset_id': 0,
+ 'finished_at': 0,
+ 'id': 1,
+ 'job_id': 1,
+ 'name': 'V1',
+ 'project_id': 1,
+ 'started_at': 0,
+ 'status': 'RUNNING',
+ 'ticket_status': 'APPROVED',
+ 'ticket_uuid': '',
+ 'trusted_job_group_id': 1,
+ 'coordinator_id': 0,
+ 'type': 'EXPORT',
+ 'uuid': 'uuid1',
+ 'version': 1,
+ 'created_at': 1655164800,
+ 'participants_info': part_info_dict,
+ 'ticket_auth_status': 'AUTHORIZED',
+ })
+ # fail due to not found
+ resp = self.get_helper('/api/v2/projects/1/trusted_jobs/10')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_trusted_job')
+ def test_put_trusted_job(self, mock_client: MagicMock, mock_get_system_info: MagicMock):
+ mock_client.return_value = Empty()
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ # successful update comment
+ resp = self.put_helper('/api/v2/projects/1/trusted_jobs/1', data={'comment': 'new comment'})
+ data = self.get_response_data(resp)
+ self.assertEqual(data['comment'], 'new comment')
+ # successful update auth_status
+ resp = self.put_helper('/api/v2/projects/1/trusted_jobs/1', data={'auth_status': 'WITHDRAW'})
+ data = self.get_response_data(resp)
+ self.assertEqual(data['auth_status'], 'WITHDRAW')
+ self.assertEqual(data['participants_info']['participants_map']['domain1']['auth_status'], 'WITHDRAW')
+ # fail due to not found
+ resp = self.put_helper('/api/v2/projects/1/trusted_jobs/10', data={'comment': 'new comment'})
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+
+
+class StopTrustedJobApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-name')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ group = TrustedJobGroup(id=1,
+ uuid='group-uuid',
+ project_id=1,
+ latest_version=1,
+ coordinator_id=0,
+ status=GroupCreateStatus.SUCCEEDED,
+ auth_status=AuthStatus.AUTHORIZED,
+ algorithm_uuid='algorithm-uuid')
+ trusted_job1 = TrustedJob(id=1,
+ uuid='trusted-job-uuid1',
+ name='V1',
+ project_id=1,
+ trusted_job_group_id=1,
+ job_id=1,
+ status=TrustedJobStatus.PENDING)
+ job1 = Job(id=1,
+ name='job-name1',
+ job_type=JobType.CUSTOMIZED,
+ project_id=1,
+ workflow_id=0,
+ state=JobState.STARTED)
+ session.add_all([project, participant1, proj_part1, group, trusted_job1, job1])
+ session.commit()
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_stop_trusted_job(self, mock_remote_do_two_pc):
+ mock_remote_do_two_pc.return_value = True, ''
+ # successful and trusted job status is refreshed to RUNNING before STOPPED
+ resp = self.post_helper('/api/v2/projects/1/trusted_jobs/1:stop')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ trusted_job1: TrustedJob = session.query(TrustedJob).get(1)
+ self.assertEqual(trusted_job1.status, TrustedJobStatus.STOPPED)
+        # fail because the trusted job is no longer RUNNING after being stopped above
+ resp = self.post_helper('/api/v2/projects/1/trusted_jobs/1:stop')
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+ # fail due to trusted job not found
+ resp = self.post_helper('/api/v2/projects/10/trusted_jobs/1:stop')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+
+
+class TrustedNotificationsApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ resource = MessageToString(Resource(cpu=1, memory=2, replicas=1))
+ with db.session_scope() as session:
+ group1 = TrustedJobGroup(id=1, project_id=1, name='group1', resource=resource)
+ group2 = TrustedJobGroup(id=2, project_id=1, name='group2', created_at=datetime(2022, 10, 1, 0, 0, 0))
+ group3 = TrustedJobGroup(id=3, project_id=1, name='group3', resource=resource)
+ group4 = TrustedJobGroup(id=4, project_id=2, name='group4', created_at=datetime(2022, 10, 1, 0, 0, 1))
+ group5 = TrustedJobGroup(id=5, project_id=1, name='group5', created_at=datetime(2022, 10, 1, 0, 0, 4))
+ trusted_job1 = TrustedJob(id=1,
+ name='V10-1',
+ auth_status=AuthStatus.PENDING,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ trusted_job_group_id=1,
+ created_at=datetime(2022, 10, 1, 0, 0, 2))
+ trusted_job2 = TrustedJob(id=2,
+ name='V10-2',
+ auth_status=AuthStatus.WITHDRAW,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ trusted_job_group_id=1)
+ trusted_job3 = TrustedJob(id=3,
+ name='V10',
+ auth_status=AuthStatus.PENDING,
+ type=TrustedJobType.ANALYZE,
+ project_id=1,
+ trusted_job_group_id=1)
+ trusted_job4 = TrustedJob(id=4,
+ name='V9-1',
+ auth_status=AuthStatus.PENDING,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ trusted_job_group_id=1,
+ created_at=datetime(2022, 10, 1, 0, 0, 3))
+ trusted_job5 = TrustedJob(id=5,
+ name='V9-2',
+ auth_status=AuthStatus.PENDING,
+ type=TrustedJobType.EXPORT,
+ project_id=2,
+ trusted_job_group_id=4,
+ created_at=datetime(2022, 10, 1, 0, 0, 5))
+
+ session.add_all([
+ group1, group2, group3, group4, group5, trusted_job1, trusted_job2, trusted_job3, trusted_job4,
+ trusted_job5
+ ])
+ session.commit()
+
+ def test_get_trusted_notifications(self):
+ resp = self.get_helper('/api/v2/projects/1/trusted_notifications')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['group5', 'group1-V9-1', 'group1-V10-1', 'group2'])
+ resp = self.get_helper('/api/v2/projects/2/trusted_notifications')
+ data = self.get_response_data(resp)
+ self.assertEqual([d['name'] for d in data], ['group4-V9-2', 'group4'])
+
+
+class ExportTrustedJobApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ Flag.TRUSTED_COMPUTING_ENABLED.value = True
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-name')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ trusted_job1 = TrustedJob(id=1,
+ uuid='uuid1',
+ project_id=1,
+ status=TrustedJobStatus.SUCCEEDED,
+ version=1,
+ trusted_job_group_id=1,
+ resource=MessageToString(Resource(cpu=1, memory=1, replicas=1)))
+ trusted_job2 = TrustedJob(id=2,
+ uuid='uuid2',
+ project_id=1,
+ status=TrustedJobStatus.RUNNING,
+ version=2,
+ trusted_job_group_id=1)
+ session.add_all([project, participant1, proj_part1, trusted_job1, trusted_job2])
+ session.commit()
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_export_trusted_job(self, mock_get_system_info: MagicMock):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ # successful
+ resp = self.post_helper('/api/v2/projects/1/trusted_jobs/1:export')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ resp = self.post_helper('/api/v2/projects/1/trusted_jobs/1:export')
+ self.assertEqual(resp.status_code, HTTPStatus.NO_CONTENT)
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).filter_by(type=TrustedJobType.EXPORT,
+ version=1,
+ project_id=1,
+ trusted_job_group_id=1,
+ export_count=1).first()
+ self.assertEqual(tee_export_job.name, 'V1-domain1-1')
+ self.assertEqual(tee_export_job.coordinator_id, 0)
+ self.assertEqual(tee_export_job.status, TrustedJobStatus.NEW)
+ self.assertIsNotNone(tee_export_job.ticket_uuid)
+ tee_export_job = session.query(TrustedJob).filter_by(type=TrustedJobType.EXPORT,
+ version=1,
+ project_id=1,
+ trusted_job_group_id=1,
+ export_count=2).first()
+ self.assertEqual(tee_export_job.name, 'V1-domain1-2')
+ tee_analyze_job = session.query(TrustedJob).get(1)
+ self.assertEqual(tee_analyze_job.export_count, 2)
+ # not found
+ resp = self.post_helper('/api/v2/projects/1/trusted_jobs/10:export')
+ self.assertEqual(resp.status_code, HTTPStatus.NOT_FOUND)
+ # not succeeded
+ resp = self.post_helper('/api/v2/projects/1/trusted_jobs/2:export')
+ self.assertEqual(resp.status_code, HTTPStatus.CONFLICT)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/controller.py b/web_console_v2/api/fedlearner_webconsole/tee/controller.py
new file mode 100644
index 000000000..da5f3bb96
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/controller.py
@@ -0,0 +1,232 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Tuple, List
+import logging
+import grpc
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.two_pc.transaction_manager import TransactionManager
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TransactionData, CreateTrustedJobGroupData, \
+ LaunchTrustedJobData, StopTrustedJobData, LaunchTrustedExportJobData
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob
+from fedlearner_webconsole.tee.services import TrustedJobGroupService, check_tee_enabled
+from fedlearner_webconsole.tee.utils import get_participant
+from fedlearner_webconsole.proto.tee_pb2 import DomainNameDataset
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.rpc.v2.job_service_client import JobServiceClient
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+from fedlearner_webconsole.rpc.v2.system_service_client import SystemServiceClient
+
+
+def _get_transaction_manager(project_id: int, two_pc_type: TwoPcType) -> TransactionManager:
+ with db.session_scope() as session:
+ project = session.query(Project).get(project_id)
+ participants = ParticipantService(session).get_platform_participants_by_project(project_id)
+ tm = TransactionManager(project_name=project.name,
+ project_token=project.token,
+ two_pc_type=two_pc_type,
+ participants=[participant.domain_name for participant in participants])
+ return tm
+
+
+def create_trusted_job_group(group: TrustedJobGroup) -> Tuple[bool, str]:
+ coordinator_pure_domain_name = SettingService.get_system_info().pure_domain_name
+ with db.session_scope() as session:
+ project = group.project
+ if project is None:
+ raise InternalException(f'project {group.project_id} of group {group.id} not found')
+ domain_name_datasets = []
+ dataset_self = group.dataset
+ if dataset_self is not None:
+ domain_name_datasets.append(
+ DomainNameDataset(pure_domain_name=coordinator_pure_domain_name,
+ dataset_uuid=dataset_self.uuid,
+ dataset_name=dataset_self.name))
+ participant_datasets = group.get_participant_datasets()
+ if participant_datasets is not None:
+ for pd in participant_datasets.items:
+ participant = get_participant(session, pd.participant_id)
+ domain_name_datasets.append(
+ DomainNameDataset(pure_domain_name=participant.pure_domain_name(),
+ dataset_uuid=pd.uuid,
+ dataset_name=pd.name))
+ analyzer_id = group.analyzer_id
+ if analyzer_id:
+ analyzer_pure_domain_name = get_participant(session, analyzer_id).pure_domain_name()
+ else:
+ analyzer_pure_domain_name = coordinator_pure_domain_name
+ tm = _get_transaction_manager(project_id=project.id, two_pc_type=TwoPcType.CREATE_TRUSTED_JOB_GROUP)
+ create_trusted_job_group_data = CreateTrustedJobGroupData(
+ name=group.name,
+ uuid=group.uuid,
+ ticket_uuid=group.ticket_uuid,
+ project_name=project.name,
+ creator_username=group.creator_username,
+ algorithm_uuid=group.algorithm_uuid,
+ domain_name_datasets=domain_name_datasets,
+ coordinator_pure_domain_name=coordinator_pure_domain_name,
+ analyzer_pure_domain_name=analyzer_pure_domain_name,
+ )
+ return tm.run(data=TransactionData(create_trusted_job_group_data=create_trusted_job_group_data))
+
+
+def launch_trusted_job(project_id: int, group_uuid: str, version: int):
+ initiator_pure_domain_name = SettingService.get_system_info().pure_domain_name
+ tm = _get_transaction_manager(project_id=project_id, two_pc_type=TwoPcType.LAUNCH_TRUSTED_JOB)
+ data = TransactionData(
+ launch_trusted_job_data=LaunchTrustedJobData(uuid=resource_uuid(),
+ group_uuid=group_uuid,
+ version=version,
+ initiator_pure_domain_name=initiator_pure_domain_name))
+ return tm.run(data)
+
+
+def stop_trusted_job(project_id: int, uuid: str):
+ tm = _get_transaction_manager(project_id=project_id, two_pc_type=TwoPcType.STOP_TRUSTED_JOB)
+ data = TransactionData(stop_trusted_job_data=StopTrustedJobData(uuid=uuid))
+ return tm.run(data)
+
+
+def launch_trusted_export_job(project_id: int, uuid: str):
+ tm = _get_transaction_manager(project_id=project_id, two_pc_type=TwoPcType.LAUNCH_TRUSTED_EXPORT_JOB)
+ data = TransactionData(launch_trusted_export_job_data=LaunchTrustedExportJobData(uuid=uuid))
+ return tm.run(data)
+
+
+def get_tee_enabled_participants(session: Session, project_id: int) -> List[int]:
+ enabled_pids = []
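+    # participant id 0 stands for the current platform itself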
+ if check_tee_enabled():
+ enabled_pids.append(0)
+ participants = ParticipantService(session).get_platform_participants_by_project(project_id)
+ for p in participants:
+ client = SystemServiceClient.from_participant(p.domain_name)
+ try:
+ resp = client.check_tee_enabled()
+ if resp.tee_enabled:
+ enabled_pids.append(p.id)
+ except grpc.RpcError as e:
+ raise InternalException(f'failed to get participant {p.id}\'s tee enabled status '
+ f'with grpc code {e.code()} and details {e.details()}') from e
+ return enabled_pids
+
+
+class TrustedJobGroupController:
+
+ def __init__(self, session: Session, project_id: int):
+ self._session = session
+ self._clients = []
+ self._participant_ids = []
+ project = session.query(Project).get(project_id)
+ participants = ParticipantService(session).get_platform_participants_by_project(project_id)
+ for p in participants:
+ self._clients.append(JobServiceClient.from_project_and_participant(p.domain_name, project.name))
+ self._participant_ids.append(p.id)
+
+ def inform_trusted_job_group(self, group: TrustedJobGroup, auth_status: AuthStatus):
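+        # informing peers is best-effort: an RPC failure is only logged and the local status change still takes effect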
+ group.auth_status = auth_status
+ for client, pid in zip(self._clients, self._participant_ids):
+ try:
+ client.inform_trusted_job_group(group.uuid, auth_status)
+ except grpc.RpcError as e:
+ logging.warning(f'[trusted-job-group] failed to inform participant {pid}\'s '
+ f'trusted job group {group.uuid} with grpc code {e.code()} and details {e.details()}')
+
+ def update_trusted_job_group(self, group: TrustedJobGroup, algorithm_uuid: str):
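+        # the update must succeed on every participant; a failure raises before the local algorithm_uuid is changed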
+ for client, pid in zip(self._clients, self._participant_ids):
+ try:
+ client.update_trusted_job_group(group.uuid, algorithm_uuid)
+ except grpc.RpcError as e:
+ raise InternalException(f'failed to update participant {pid}\'s trusted job group {group.uuid} '
+ f'with grpc code {e.code()} and details {e.details()}') from e
+ group.algorithm_uuid = algorithm_uuid
+
+ def delete_trusted_job_group(self, group: TrustedJobGroup):
+ for client, pid in zip(self._clients, self._participant_ids):
+ try:
+ client.delete_trusted_job_group(group.uuid)
+ except grpc.RpcError as e:
+ raise InternalException(f'failed to delete participant {pid}\'s trusted job group {group.uuid} '
+ f'with grpc code {e.code()} and details {e.details()}') from e
+ TrustedJobGroupService(self._session).delete(group)
+
+ def update_unauth_participant_ids(self, group: TrustedJobGroup):
+ unauth_set = set(group.get_unauth_participant_ids())
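+        # poll each participant's latest auth status; if the RPC fails, the cached entry for that participant is kept as is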
+ for client, pid in zip(self._clients, self._participant_ids):
+ try:
+ resp = client.get_trusted_job_group(group.uuid)
+ status = AuthStatus[resp.auth_status]
+ if status == AuthStatus.AUTHORIZED:
+ unauth_set.discard(pid)
+ else:
+ unauth_set.add(pid)
+ except grpc.RpcError as e:
+ logging.warning(f'[trusted-job-group] failed to get participant {pid}\'s '
+ f'trusted job group {group.uuid} with grpc code {e.code()} and details {e.details()}')
+ group.set_unauth_participant_ids(list(unauth_set))
+
+
+class TrustedJobController:
+
+ def __init__(self, session: Session, project_id: int):
+ self._session = session
+ self._clients = []
+ self._participants = ParticipantService(session).get_platform_participants_by_project(project_id)
+ project = session.query(Project).get(project_id)
+ for p in self._participants:
+ self._clients.append(JobServiceClient.from_project_and_participant(p.domain_name, project.name))
+
+ def inform_auth_status(self, trusted_job: TrustedJob, auth_status: AuthStatus):
+ trusted_job.auth_status = auth_status
+ participants_info: ParticipantsInfo = trusted_job.get_participants_info()
+ self_pure_dn = SettingService.get_system_info().pure_domain_name
+ participants_info.participants_map[self_pure_dn].auth_status = auth_status.name
+ trusted_job.set_participants_info(participants_info)
+ for client, p in zip(self._clients, self._participants):
+ try:
+ client.inform_trusted_job(trusted_job.uuid, auth_status)
+ except grpc.RpcError as e:
+ logging.warning(f'[trusted-job] failed to inform participant {p.id}\'s '
+ f'trusted job {trusted_job.uuid} with grpc code {e.code()} and details {e.details()}')
+
+ def update_participants_info(self, trusted_job: TrustedJob):
+ participants_info = trusted_job.get_participants_info()
+ for client, p in zip(self._clients, self._participants):
+ try:
+ resp = client.get_trusted_job(trusted_job.uuid)
+ auth_status = AuthStatus[resp.auth_status]
+ participants_info.participants_map[p.pure_domain_name()].auth_status = auth_status.name
+ except grpc.RpcError as e:
+ logging.warning(f'[trusted-job] failed to get participant {p.id}\'s '
+ f'trusted job {trusted_job.uuid} with grpc code {e.code()} and details {e.details()}')
+ trusted_job.set_participants_info(participants_info)
+
+ def create_trusted_export_job(self, tee_export_job: TrustedJob, tee_analyze_job: TrustedJob):
+        # the local trusted export job is already created by the APIs; this function is only used by the runner
+ for client, p in zip(self._clients, self._participants):
+ try:
+ client.create_trusted_export_job(tee_export_job.uuid, tee_export_job.name, tee_export_job.export_count,
+ tee_analyze_job.uuid, tee_export_job.ticket_uuid)
+ except grpc.RpcError as e:
+ raise InternalException(
+ f'failed to create participant {p.id}\'s trusted export job {tee_export_job.uuid} '
+ f'with grpc code {e.code()} and details {e.details()}') from e
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/controller_test.py b/web_console_v2/api/fedlearner_webconsole/tee/controller_test.py
new file mode 100644
index 000000000..e1ba14752
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/controller_test.py
@@ -0,0 +1,213 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+import grpc
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.rpc.client import FakeRpcError
+from google.protobuf.empty_pb2 import Empty
+from google.protobuf.text_format import MessageToString
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob, TrustedJobType, TrustedJobStatus
+from fedlearner_webconsole.tee.controller import TrustedJobGroupController, TrustedJobController
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import GetTrustedJobGroupResponse
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+
+
+class TrustedJobGroupControllerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ participant2 = Participant(id=2, name='part3', domain_name='fl-domain3.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ proj_part2 = ProjectParticipant(project_id=1, participant_id=2)
+ group = TrustedJobGroup(id=1, uuid='uuid', auth_status=AuthStatus.AUTHORIZED, unauth_participant_ids='1,2')
+ session.add_all([project, participant1, participant2, proj_part1, proj_part2, group])
+ session.commit()
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_trusted_job_group')
+ def test_inform_trusted_job_group(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).inform_trusted_job_group(group, AuthStatus.AUTHORIZED)
+ self.assertEqual(group.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(mock_client.call_args_list, [(('uuid', AuthStatus.AUTHORIZED),),
+ (('uuid', AuthStatus.AUTHORIZED),)])
+ # grpc abort with error
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'trusted job group request.uuid not found')
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).inform_trusted_job_group(group, AuthStatus.AUTHORIZED)
+ self.assertEqual(group.auth_status, AuthStatus.AUTHORIZED)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.update_trusted_job_group')
+ def test_update_trusted_job_group(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).update_trusted_job_group(group, 'algorithm-uuid')
+ self.assertEqual(group.algorithm_uuid, 'algorithm-uuid')
+ self.assertEqual(mock_client.call_args_list, [(('uuid', 'algorithm-uuid'),), (('uuid', 'algorithm-uuid'),)])
+ # grpc abort with error
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT, 'mismatched algorithm project')
+ with self.assertRaises(InternalException):
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).update_trusted_job_group(group, 'algorithm-uuid')
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.delete_trusted_job_group')
+ def test_delete_trusted_job_group(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).delete_trusted_job_group(group)
+ self.assertEqual(mock_client.call_args_list, [(('uuid',),), (('uuid',),)])
+ # grpc abort with err
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT, 'trusted job is not deletable')
+ with self.assertRaises(InternalException):
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).delete_trusted_job_group(group)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_trusted_job_group')
+ def test_update_unauth_participant_ids(self, mock_client: MagicMock):
+ mock_client.side_effect = [
+ GetTrustedJobGroupResponse(auth_status='PENDING'),
+ GetTrustedJobGroupResponse(auth_status='AUTHORIZED')
+ ]
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).update_unauth_participant_ids(group)
+ self.assertCountEqual(group.get_unauth_participant_ids(), [1])
+ # grpc abort with err
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'trusted job group uuid not found')
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).get(1)
+ TrustedJobGroupController(session, 1).update_unauth_participant_ids(group)
+ self.assertCountEqual(group.get_unauth_participant_ids(), [1, 2])
+
+
+class TrustedJobControllerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ participant2 = Participant(id=2, name='part3', domain_name='fl-domain3.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ proj_part2 = ProjectParticipant(project_id=1, participant_id=2)
+ participants_info = ParticipantsInfo()
+ participants_info.participants_map['domain1'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['domain2'].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map['domain3'].auth_status = AuthStatus.WITHDRAW.name
+ tee_export_job = TrustedJob(id=1,
+ uuid='uuid1',
+ name='V1-domain1-1',
+ type=TrustedJobType.EXPORT,
+ version=1,
+ trusted_job_group_id=1,
+ ticket_status=TicketStatus.APPROVED,
+ ticket_uuid='ticket-uuid',
+ auth_status=AuthStatus.PENDING,
+ status=TrustedJobStatus.NEW,
+ export_count=1,
+ participants_info=MessageToString(participants_info))
+ tee_analyze_job = TrustedJob(id=2,
+ uuid='uuid2',
+ type=TrustedJobType.ANALYZE,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=1,
+ status=TrustedJobStatus.SUCCEEDED)
+ session.add_all(
+ [project, participant1, participant2, proj_part1, proj_part2, tee_export_job, tee_analyze_job])
+ session.commit()
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.inform_trusted_job')
+ def test_inform_auth_status(self, mock_client: MagicMock, mock_get_system_info: MagicMock):
+ mock_client.return_value = Empty()
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ with db.session_scope() as session:
+ trusted_job = session.query(TrustedJob).get(1)
+ TrustedJobController(session, 1).inform_auth_status(trusted_job, AuthStatus.AUTHORIZED)
+ self.assertEqual(trusted_job.auth_status, AuthStatus.AUTHORIZED)
+ self.assertEqual(trusted_job.get_participants_info().participants_map['domain1'].auth_status, 'AUTHORIZED')
+ self.assertEqual(mock_client.call_args_list, [(('uuid1', AuthStatus.AUTHORIZED),),
+ (('uuid1', AuthStatus.AUTHORIZED),)])
+ # grpc abort with error
+ mock_client.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'not found')
+ with db.session_scope() as session:
+            trusted_job = session.query(TrustedJob).get(1)
+            TrustedJobController(session, 1).inform_auth_status(trusted_job, AuthStatus.AUTHORIZED)
+            self.assertEqual(trusted_job.auth_status, AuthStatus.AUTHORIZED)
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_trusted_job')
+ def test_update_participants_info(self, mock_client: MagicMock):
+ mock_client.side_effect = [
+ GetTrustedJobGroupResponse(auth_status='WITHDRAW'),
+ GetTrustedJobGroupResponse(auth_status='AUTHORIZED')
+ ]
+ with db.session_scope() as session:
+ trusted_job = session.query(TrustedJob).get(1)
+ TrustedJobController(session, 1).update_participants_info(trusted_job)
+ self.assertEqual(trusted_job.get_participants_info().participants_map['domain2'].auth_status, 'WITHDRAW')
+ self.assertEqual(trusted_job.get_participants_info().participants_map['domain3'].auth_status, 'AUTHORIZED')
+
+ # grpc abort with err
+ mock_client.side_effect = [
+ FakeRpcError(grpc.StatusCode.NOT_FOUND, 'trusted job uuid not found'),
+ GetTrustedJobGroupResponse(auth_status='AUTHORIZED')
+ ]
+ with db.session_scope() as session:
+ trusted_job = session.query(TrustedJob).get(1)
+ TrustedJobController(session, 1).update_participants_info(trusted_job)
+ self.assertEqual(trusted_job.get_participants_info().participants_map['domain2'].auth_status, 'PENDING')
+ self.assertEqual(trusted_job.get_participants_info().participants_map['domain3'].auth_status, 'AUTHORIZED')
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_trusted_export_job')
+ def test_create_trusted_export_job(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).get(1)
+ tee_analyze_job = session.query(TrustedJob).get(2)
+ TrustedJobController(session, 1).create_trusted_export_job(tee_export_job, tee_analyze_job)
+ self.assertEqual(mock_client.call_args_list, [(('uuid1', 'V1-domain1-1', 1, 'uuid2', 'ticket-uuid'),)] * 2)
+ # grpc abort with err
+ mock_client.side_effect = [
+ FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT, 'tee_analyze_job uuid2 invalid'),
+ Empty()
+ ]
+ with self.assertRaises(InternalException):
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).get(1)
+ tee_analyze_job = session.query(TrustedJob).get(2)
+ TrustedJobController(session, 1).create_trusted_export_job(tee_export_job, tee_analyze_job)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/models.py b/web_console_v2/api/fedlearner_webconsole/tee/models.py
new file mode 100644
index 000000000..6c08708d7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/models.py
@@ -0,0 +1,368 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+from google.protobuf import text_format
+from typing import List, Optional
+from sqlalchemy.sql import func
+from sqlalchemy.sql.schema import Index
+
+from fedlearner_webconsole.algorithm.models import Algorithm
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.tee_pb2 import TrustedJobGroupPb, TrustedJobGroupRef, TrustedJobPb, TrustedJobRef, \
+ Resource, ParticipantDatasetList, TrustedNotification
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_and_auth_model import ReviewTicketAndAuthModel
+from fedlearner_webconsole.job.models import JobState, Job
+
+
+class GroupCreateStatus(enum.Enum):
+ PENDING = 'PENDING'
+ FAILED = 'FAILED'
+ SUCCEEDED = 'SUCCEEDED'
+
+
+class TicketAuthStatus(enum.Enum):
+ TICKET_PENDING = 'TICKET_PENDING'
+ TICKET_DECLINED = 'TICKET_DECLINED'
+ CREATE_PENDING = 'CREATE_PENDING'
+ CREATE_FAILED = 'CREATE_FAILED'
+ AUTH_PENDING = 'AUTH_PENDING'
+ AUTHORIZED = 'AUTHORIZED'
+
+
+class TrustedJobStatus(enum.Enum):
+ NEW = 'NEW'
+ CREATED = 'CREATED'
+ CREATE_FAILED = 'CREATE_FAILED'
+ PENDING = 'PENDING'
+ RUNNING = 'RUNNING'
+ SUCCEEDED = 'SUCCEEDED'
+ FAILED = 'FAILED'
+ STOPPED = 'STOPPED'
+
+
+class TrustedJobType(enum.Enum):
+ ANALYZE = 'ANALYZE'
+ EXPORT = 'EXPORT'
+
+
+class TrustedJobGroup(db.Model, ReviewTicketAndAuthModel):
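+    """A group of trusted (TEE) computing jobs within a project.
+
+    Trusted jobs in the group share the algorithm, dataset, and resource configuration recorded
+    here, and the group tracks ticket, creation, and authorization status across participants.
+    """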
+ __tablename__ = 'trusted_job_groups_v2'
+ __table_args__ = (
+ Index('idx_trusted_group_name', 'name'),
+ Index('idx_trusted_group_project_id', 'project_id'),
+ default_table_args('trusted_job_groups_v2'),
+ )
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ name = db.Column(db.String(255), comment='name')
+ uuid = db.Column(db.String(64), comment='uuid')
+ latest_version = db.Column(db.Integer, default=0, comment='latest version')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment of trusted job group')
+ project_id = db.Column(db.Integer, comment='project id')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created at', server_default=func.now())
+ updated_at = db.Column(db.DateTime(timezone=True),
+ comment='updated at',
+ server_default=func.now(),
+ onupdate=func.now())
+ creator_username = db.Column(db.String(255), comment='creator username')
+ coordinator_id = db.Column(db.Integer, comment='coordinator participant id')
+ analyzer_id = db.Column(db.Integer, comment='analyzer participant id')
+ status = db.Column(db.Enum(GroupCreateStatus, native_enum=False, length=32, create_constraint=False),
+ default=GroupCreateStatus.PENDING,
+ comment='create state')
+ unauth_participant_ids = db.Column(db.Text(), comment='unauth participant ids')
+ algorithm_uuid = db.Column(db.String(64), comment='algorithm uuid')
+ resource = db.Column('rsc', db.String(255), comment='resource')
+ dataset_id = db.Column(db.Integer, comment='dataset id')
+ participant_datasets = db.Column(db.Text(), comment='list of participant-to-dataset mapping')
+ # relationship to other tables
+ project = db.relationship(Project.__name__, primaryjoin='Project.id == foreign(TrustedJobGroup.project_id)')
+ algorithm = db.relationship(Algorithm.__name__,
+ primaryjoin='Algorithm.uuid == foreign(TrustedJobGroup.algorithm_uuid)')
+ trusted_jobs = db.relationship(
+ 'TrustedJob',
+ order_by='desc(TrustedJob.version)',
+ primaryjoin='TrustedJobGroup.id == foreign(TrustedJob.trusted_job_group_id)',
+ # To disable the warning of back_populates
+ overlaps='group')
+ dataset = db.relationship(Dataset.__name__, primaryjoin='Dataset.id == foreign(TrustedJobGroup.dataset_id)')
+
+ def to_proto(self) -> TrustedJobGroupPb:
+ group = TrustedJobGroupPb(
+ id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ latest_version=self.latest_version,
+ comment=self.comment,
+ project_id=self.project_id,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ creator_username=self.creator_username,
+ coordinator_id=self.coordinator_id,
+ analyzer_id=self.analyzer_id,
+ ticket_uuid=self.ticket_uuid,
+ ticket_status=self.ticket_status.name,
+ status=self.status.name,
+ auth_status=self.auth_status.name,
+ ticket_auth_status=self.get_ticket_auth_status().name,
+ latest_job_status=self.get_latest_job_status().name,
+ algorithm_id=self.algorithm.id if self.algorithm else 0,
+ algorithm_uuid=self.algorithm_uuid,
+ dataset_id=self.dataset_id,
+ )
+ if self.unauth_participant_ids is not None:
+ group.unauth_participant_ids.extend(self.get_unauth_participant_ids())
+ if self.resource is not None:
+ group.resource.MergeFrom(self.get_resource())
+ if self.participant_datasets is not None:
+ group.participant_datasets.MergeFrom(self.get_participant_datasets())
+ return group
+
+ def to_ref(self) -> TrustedJobGroupRef:
+ group = TrustedJobGroupRef(
+ id=self.id,
+ name=self.name,
+ created_at=to_timestamp(self.created_at),
+ ticket_status=self.ticket_status.name,
+ status=self.status.name,
+ auth_status=self.auth_status.name,
+ ticket_auth_status=self.get_ticket_auth_status().name,
+ latest_job_status=self.get_latest_job_status().name,
+ is_configured=self.resource is not None,
+ )
+ if self.coordinator_id == 0:
+ group.is_creator = True
+ else:
+ group.is_creator = False
+ group.creator_id = self.coordinator_id
+ group.unauth_participant_ids.extend(self.get_unauth_participant_ids())
+ return group
+
+ def get_latest_job_status(self) -> TrustedJobStatus:
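+        """Return the status of the latest ANALYZE trusted job in this group.
+
+        trusted_jobs is ordered by version descending, so the first ANALYZE job found is the
+        newest; NEW is returned when the group has no ANALYZE job yet.
+        """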
+ for trusted_job in self.trusted_jobs:
+ if trusted_job.type == TrustedJobType.ANALYZE:
+ return trusted_job.get_status()
+ return TrustedJobStatus.NEW
+
+ def get_ticket_auth_status(self) -> TicketAuthStatus:
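+        """Collapse ticket, creation, and authorization status into a single TicketAuthStatus,
+        checked in that order of precedence.
+        """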
+ if self.ticket_status == TicketStatus.PENDING:
+ return TicketAuthStatus.TICKET_PENDING
+ if self.ticket_status == TicketStatus.DECLINED:
+ return TicketAuthStatus.TICKET_DECLINED
+ if self.status == GroupCreateStatus.PENDING:
+ return TicketAuthStatus.CREATE_PENDING
+ if self.status == GroupCreateStatus.FAILED:
+ return TicketAuthStatus.CREATE_FAILED
+ if self.auth_status != AuthStatus.AUTHORIZED or len(self.get_unauth_participant_ids()) > 0:
+ return TicketAuthStatus.AUTH_PENDING
+ return TicketAuthStatus.AUTHORIZED
+
+ def get_resource(self) -> Optional[Resource]:
+ if self.resource is not None:
+ return text_format.Parse(self.resource, Resource())
+ return None
+
+ def set_resource(self, resource: Optional[Resource] = None):
+ if resource is None:
+ resource = Resource()
+ self.resource = text_format.MessageToString(resource)
+
+ def get_participant_datasets(self) -> Optional[ParticipantDatasetList]:
+ if self.participant_datasets is not None:
+ return text_format.Parse(self.participant_datasets, ParticipantDatasetList())
+ return None
+
+ def set_participant_datasets(self, participant_datasets: Optional[ParticipantDatasetList] = None):
+ if participant_datasets is None:
+ participant_datasets = ParticipantDatasetList()
+ self.participant_datasets = text_format.MessageToString(participant_datasets)
+
+ def get_unauth_participant_ids(self) -> List[int]:
+        if self.unauth_participant_ids:
+ sids = self.unauth_participant_ids.split(',')
+ return [int(s) for s in sids]
+ return []
+
+ def set_unauth_participant_ids(self, ids: List[int]):
+ if len(ids) > 0:
+ self.unauth_participant_ids = ','.join([str(i) for i in ids])
+ else:
+ self.unauth_participant_ids = None
+
+ def is_deletable(self) -> bool:
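+        """A group is deletable only when none of its trusted jobs is PENDING or RUNNING."""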
+ for trusted_job in self.trusted_jobs:
+ if trusted_job.get_status() in [TrustedJobStatus.PENDING, TrustedJobStatus.RUNNING]:
+ return False
+ return True
+
+ def to_notification(self) -> TrustedNotification:
+ return TrustedNotification(
+ type=TrustedNotification.TRUSTED_JOB_GROUP_CREATE,
+ id=self.id,
+ name=self.name,
+ created_at=to_timestamp(self.created_at),
+ coordinator_id=self.coordinator_id,
+ )
+
+
+class TrustedJob(db.Model, ReviewTicketAndAuthModel):
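+    """A single trusted (TEE) job run belonging to a TrustedJobGroup.
+
+    A trusted job is either an ANALYZE job or an EXPORT job and derives its status from the
+    underlying Job via update_status.
+    """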
+ __tablename__ = 'trusted_jobs_v2'
+ __table_args__ = (
+ Index('idx_trusted_name', 'name'),
+ Index('idx_trusted_project_id', 'project_id'),
+ default_table_args('trusted_jobs_v2'),
+ )
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+ name = db.Column(db.String(255), comment='name')
+ type = db.Column('trusted_job_type',
+ db.Enum(TrustedJobType, native_enum=False, length=32, create_constraint=False),
+ default=TrustedJobType.ANALYZE,
+ key='type',
+ comment='trusted job type')
+ job_id = db.Column(db.Integer, comment='job id')
+ uuid = db.Column(db.String(64), comment='uuid')
+ version = db.Column(db.Integer, comment='version')
+ export_count = db.Column(db.Integer, default=0, comment='export count')
+ comment = db.Column('cmt', db.Text(), key='comment', comment='comment of trusted job')
+ project_id = db.Column(db.Integer, comment='project id')
+ trusted_job_group_id = db.Column(db.Integer, comment='trusted job group id')
+ coordinator_id = db.Column(db.Integer, comment='coordinator participant id')
+ created_at = db.Column(db.DateTime(timezone=True), comment='created at', server_default=func.now())
+ updated_at = db.Column(db.DateTime(timezone=True),
+ comment='updated at',
+ server_default=func.now(),
+ onupdate=func.now())
+ started_at = db.Column(db.DateTime(timezone=True), comment='started_at')
+ finished_at = db.Column(db.DateTime(timezone=True), comment='finished_at')
+ status = db.Column(db.Enum(TrustedJobStatus, native_enum=False, length=32, create_constraint=False),
+ default=TrustedJobStatus.NEW,
+ comment='trusted job status')
+ algorithm_uuid = db.Column(db.String(64), comment='algorithm uuid')
+ resource = db.Column('rsc', db.String(255), comment='resource')
+ export_dataset_id = db.Column(db.Integer, comment='export dataset id')
+ result_key = db.Column(db.Text(), comment='result key')
+ # relationship to other tables
+ job = db.relationship(Job.__name__, primaryjoin='Job.id == foreign(TrustedJob.job_id)')
+ project = db.relationship(Project.__name__, primaryjoin='Project.id == foreign(TrustedJob.project_id)')
+ group = db.relationship('TrustedJobGroup',
+ primaryjoin='TrustedJobGroup.id == foreign(TrustedJob.trusted_job_group_id)')
+ algorithm = db.relationship(Algorithm.__name__, primaryjoin='Algorithm.uuid == foreign(TrustedJob.algorithm_uuid)')
+ export_dataset = db.relationship(Dataset.__name__,
+ primaryjoin='Dataset.id == foreign(TrustedJob.export_dataset_id)')
+
+ def to_proto(self) -> TrustedJobPb:
+ trusted_job = TrustedJobPb(
+ id=self.id,
+ type=self.type.name,
+ name=self.name,
+ job_id=self.job_id,
+ uuid=self.uuid,
+ version=self.version,
+ comment=self.comment,
+ project_id=self.project_id,
+ trusted_job_group_id=self.trusted_job_group_id,
+ coordinator_id=self.coordinator_id,
+ status=self.get_status().name,
+ created_at=to_timestamp(self.created_at),
+ updated_at=to_timestamp(self.updated_at),
+ started_at=to_timestamp(self.started_at) if self.started_at is not None else None,
+ finished_at=to_timestamp(self.finished_at) if self.finished_at is not None else None,
+ algorithm_id=self.algorithm.id if self.algorithm else 0,
+ algorithm_uuid=self.algorithm_uuid,
+ ticket_uuid=self.ticket_uuid,
+ ticket_status=self.ticket_status.name,
+ auth_status=self.auth_status.name,
+ participants_info=self.get_participants_info(),
+ ticket_auth_status=self.get_ticket_auth_status().name,
+ export_dataset_id=self.export_dataset_id,
+ )
+ if self.resource is not None:
+ trusted_job.resource.MergeFrom(self.get_resource())
+ return trusted_job
+
+ def to_ref(self) -> TrustedJobRef:
+ return TrustedJobRef(
+ id=self.id,
+ type=self.type.name,
+ name=self.name,
+ coordinator_id=self.coordinator_id,
+ job_id=self.job_id,
+ comment=self.comment,
+ status=self.get_status().name,
+ participants_info=self.get_participants_info(),
+ ticket_auth_status=self.get_ticket_auth_status().name,
+ started_at=to_timestamp(self.started_at) if self.started_at is not None else None,
+ finished_at=to_timestamp(self.finished_at) if self.finished_at is not None else None,
+ )
+
+ def get_resource(self) -> Optional[Resource]:
+ if self.resource is not None:
+ return text_format.Parse(self.resource, Resource())
+ return None
+
+ def set_resource(self, resource: Optional[Resource] = None):
+ if resource is None:
+ resource = Resource()
+ self.resource = text_format.MessageToString(resource)
+
+ def update_status(self):
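+        """Derive this trusted job's status from the underlying Job state.
+
+        Terminal statuses (FAILED, STOPPED, SUCCEEDED) are never overwritten, and finished_at
+        is set from the job's updated_at once a terminal status is reached.
+        """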
+ if self.status in [TrustedJobStatus.FAILED, TrustedJobStatus.STOPPED, TrustedJobStatus.SUCCEEDED]:
+ return
+ if self.job is None:
+ return
+ job_state = self.job.state
+ if job_state == JobState.FAILED:
+ self.status = TrustedJobStatus.FAILED
+ if job_state == JobState.COMPLETED:
+ self.status = TrustedJobStatus.SUCCEEDED
+ if job_state == JobState.STARTED:
+ self.status = TrustedJobStatus.RUNNING
+        if job_state == JobState.STOPPED:
+ self.status = TrustedJobStatus.STOPPED
+ if job_state in [JobState.NEW, JobState.WAITING]:
+ self.status = TrustedJobStatus.PENDING
+ if self.status in [TrustedJobStatus.FAILED, TrustedJobStatus.STOPPED, TrustedJobStatus.SUCCEEDED]:
+ self.finished_at = self.job.updated_at
+
+ def get_status(self) -> TrustedJobStatus:
+ self.update_status()
+ return self.status
+
+ def to_notification(self) -> TrustedNotification:
+ return TrustedNotification(
+ type=TrustedNotification.TRUSTED_JOB_EXPORT,
+ id=self.id,
+ name=f'{self.group.name}-{self.name}',
+ created_at=to_timestamp(self.created_at),
+ coordinator_id=self.coordinator_id,
+ )
+
+ def get_ticket_auth_status(self) -> TicketAuthStatus:
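+        """Collapse ticket, creation, and participant auth status into a single TicketAuthStatus."""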
+ if self.ticket_status == TicketStatus.PENDING:
+ return TicketAuthStatus.TICKET_PENDING
+ if self.ticket_status == TicketStatus.DECLINED:
+ return TicketAuthStatus.TICKET_DECLINED
+ if self.status == TrustedJobStatus.NEW:
+ return TicketAuthStatus.CREATE_PENDING
+ if self.status == TrustedJobStatus.CREATE_FAILED:
+ return TicketAuthStatus.CREATE_FAILED
+ if not self.is_all_participants_authorized():
+ return TicketAuthStatus.AUTH_PENDING
+ return TicketAuthStatus.AUTHORIZED
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/models_test.py b/web_console_v2/api/fedlearner_webconsole/tee/models_test.py
new file mode 100644
index 000000000..28c7c7646
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/models_test.py
@@ -0,0 +1,404 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime
+from google.protobuf import text_format
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob, TrustedJobStatus, \
+ GroupCreateStatus, TrustedJobType, TicketAuthStatus
+from fedlearner_webconsole.proto.tee_pb2 import TrustedJobGroupPb, TrustedJobGroupRef, TrustedJobPb, TrustedJobRef, \
+ Resource, ParticipantDataset, ParticipantDatasetList, TrustedNotification
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.job.models import JobState, Job, JobType
+
+
+class TrustedJobGroupTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ participant = Participant(id=1, name='p1', domain_name='test.domain.name')
+ group = TrustedJobGroup(
+ id=1,
+ name='trusted job group test',
+ latest_version=1,
+ comment='This is comment for test.',
+ project_id=2,
+ created_at=datetime(2022, 6, 15, 0, 0, 0),
+ updated_at=datetime(2022, 6, 15, 0, 0, 0),
+ creator_username='admin',
+ coordinator_id=1,
+ analyzer_id=1,
+ ticket_uuid='ticket-uuid',
+ ticket_status=TicketStatus.APPROVED,
+ status=GroupCreateStatus.SUCCEEDED,
+ auth_status=AuthStatus.AUTHORIZED,
+ unauth_participant_ids='1,2',
+ algorithm_uuid='algorithm-uuid3',
+ resource='cpu: 2\nmemory: 2\nreplicas: 1\n',
+ dataset_id=4,
+ participant_datasets="""
+ items {
+ participant_id: 1
+ uuid: "uuid1"
+ name: "name1"
+ }
+ items {
+ participant_id: 2
+ uuid: "uuid2"
+ name: "name2"
+ }
+ """,
+ )
+ session.add_all([group, participant])
+ session.commit()
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ pb = TrustedJobGroupPb(id=1,
+ name='trusted job group test',
+ latest_version=1,
+ comment='This is comment for test.',
+ project_id=2,
+ created_at=1655251200,
+ updated_at=1655251200,
+ creator_username='admin',
+ coordinator_id=1,
+ analyzer_id=1,
+ ticket_uuid='ticket-uuid',
+ ticket_status='APPROVED',
+ status='SUCCEEDED',
+ auth_status='AUTHORIZED',
+ latest_job_status='NEW',
+ ticket_auth_status='AUTH_PENDING',
+ unauth_participant_ids=[1, 2],
+ algorithm_id=0,
+ algorithm_uuid='algorithm-uuid3',
+ resource=Resource(cpu=2, memory=2, replicas=1),
+ dataset_id=4,
+ participant_datasets=ParticipantDatasetList(items=[
+ ParticipantDataset(participant_id=1, uuid='uuid1', name='name1'),
+ ParticipantDataset(participant_id=2, uuid='uuid2', name='name2'),
+ ]))
+ self.assertEqual(pb, group.to_proto())
+
+ def test_to_ref(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ ref = TrustedJobGroupRef(
+ id=1,
+ name='trusted job group test',
+ created_at=1655251200,
+ is_creator=False,
+ creator_id=1,
+ ticket_status='APPROVED',
+ status='SUCCEEDED',
+ auth_status='AUTHORIZED',
+ latest_job_status='NEW',
+ ticket_auth_status='AUTH_PENDING',
+ unauth_participant_ids=[1, 2],
+ is_configured=True,
+ )
+ self.assertEqual(ref, group.to_ref())
+
+ def test_get_latest_job_status(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ job_status = group.get_latest_job_status()
+ self.assertEqual(TrustedJobStatus.NEW, job_status)
+ with db.session_scope() as session:
+ new_job = TrustedJob(id=1, version=1, trusted_job_group_id=1, status=TrustedJobStatus.RUNNING)
+ session.add(new_job)
+ session.commit()
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ job_status = group.get_latest_job_status()
+ self.assertEqual(TrustedJobStatus.RUNNING, job_status)
+
+ def test_get_ticket_auth_status(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ # AUTH_PENDING
+ self.assertEqual(group.get_ticket_auth_status(), TicketAuthStatus.AUTH_PENDING)
+ # AUTHORIZED
+ group.unauth_participant_ids = None
+ self.assertEqual(group.get_ticket_auth_status(), TicketAuthStatus.AUTHORIZED)
+ # CREATED_FAILED
+ group.status = GroupCreateStatus.FAILED
+ self.assertEqual(group.get_ticket_auth_status(), TicketAuthStatus.CREATE_FAILED)
+ # CREATED_PENDING
+ group.status = GroupCreateStatus.PENDING
+ self.assertEqual(group.get_ticket_auth_status(), TicketAuthStatus.CREATE_PENDING)
+ # TICKET_DECLINED
+ group.ticket_status = TicketStatus.DECLINED
+ self.assertEqual(group.get_ticket_auth_status(), TicketAuthStatus.TICKET_DECLINED)
+ # TICKET_PENDING
+ group.ticket_status = TicketStatus.PENDING
+ self.assertEqual(group.get_ticket_auth_status(), TicketAuthStatus.TICKET_PENDING)
+
+ def test_get_resource(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ self.assertEqual(Resource(cpu=2, memory=2, replicas=1), group.get_resource())
+ group.resource = None
+ self.assertIsNone(group.get_resource())
+
+ def test_set_resource(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ group.set_resource(Resource(cpu=4, memory=4, replicas=1))
+ self.assertEqual('cpu: 4\nmemory: 4\nreplicas: 1\n', group.resource)
+ group.set_resource()
+ self.assertEqual('', group.resource)
+
+ def test_get_participant_datasets(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ expected = ParticipantDatasetList(items=[
+ ParticipantDataset(participant_id=1, uuid='uuid1', name='name1'),
+ ParticipantDataset(participant_id=2, uuid='uuid2', name='name2'),
+ ])
+ self.assertEqual(expected, group.get_participant_datasets())
+ group.participant_datasets = None
+ self.assertIsNone(group.get_participant_datasets())
+
+ def test_set_participant_datasets(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ pds = ParticipantDatasetList(items=[ParticipantDataset(participant_id=1, uuid='uuid1', name='name1')])
+ group.set_participant_datasets(pds)
+ self.assertEqual('items {\n participant_id: 1\n uuid: "uuid1"\n name: "name1"\n}\n',
+ group.participant_datasets)
+
+ def test_get_unauth_participant_ids(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ group.unauth_participant_ids = '2'
+ self.assertEqual([2], group.get_unauth_participant_ids())
+ group.unauth_participant_ids = '2,3,4'
+ self.assertEqual([2, 3, 4], group.get_unauth_participant_ids())
+ group.unauth_participant_ids = None
+ self.assertEqual([], group.get_unauth_participant_ids())
+
+ def test_set_unauth_participant_ids(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ group.set_unauth_participant_ids([])
+ self.assertIsNone(group.unauth_participant_ids)
+ group.set_unauth_participant_ids([1, 2, 3])
+ self.assertEqual('1,2,3', group.unauth_participant_ids)
+
+ def test_is_deletable(self):
+ with db.session_scope() as session:
+ trusted_job1 = TrustedJob(id=1, trusted_job_group_id=1, status=TrustedJobStatus.STOPPED)
+ trusted_job2 = TrustedJob(id=2, trusted_job_group_id=1, status=TrustedJobStatus.FAILED)
+ session.add_all([trusted_job1, trusted_job2])
+ session.commit()
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ self.assertTrue(group.is_deletable())
+ # not deletable
+ with db.session_scope() as session:
+ session.query(TrustedJob).filter_by(id=1).update({'status': TrustedJobStatus.RUNNING})
+ session.commit()
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ self.assertFalse(group.is_deletable())
+
+ def test_to_notification(self):
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ notif = group.to_notification()
+ self.assertEqual(notif.type, TrustedNotification.TRUSTED_JOB_GROUP_CREATE)
+ self.assertEqual(notif.id, 1)
+ self.assertEqual(notif.name, 'trusted job group test')
+ self.assertEqual(notif.created_at, 1655251200)
+ self.assertEqual(notif.coordinator_id, 1)
+
+
+class TrustedJobTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ participant = Participant(id=1, name='p1', domain_name='test.domain.name')
+ group = TrustedJobGroup(id=1, name='trusted job group name', project_id=2, coordinator_id=1)
+ self.participants_info = ParticipantsInfo(participants_map={
+ 'self': ParticipantInfo(auth_status='AUTHORIZED'),
+ 'part1': ParticipantInfo(auth_status='WITHDRAW'),
+ })
+ trusted_job = TrustedJob(
+ id=1,
+ type=TrustedJobType.ANALYZE,
+ name='V1',
+ job_id=1,
+ uuid='uuid test',
+ version=1,
+ comment='This is comment for test.',
+ project_id=2,
+ trusted_job_group_id=1,
+ coordinator_id=1,
+ auth_status=AuthStatus.AUTHORIZED,
+ ticket_status=TicketStatus.APPROVED,
+ participants_info=text_format.MessageToString(self.participants_info),
+ created_at=datetime(2022, 6, 14, 0, 0, 0),
+ updated_at=datetime(2022, 6, 14, 0, 0, 1),
+ started_at=datetime(2022, 6, 15, 0, 0, 0),
+ finished_at=datetime(2022, 6, 15, 0, 0, 1),
+ status=TrustedJobStatus.PENDING,
+ algorithm_uuid='algorithm-uuid3',
+ resource='cpu: 2\nmemory: 2\nreplicas: 1\n',
+ export_dataset_id=2,
+ )
+ job = Job(id=1,
+ name='trusted-job-1-1-1-uuid',
+ state=JobState.NEW,
+ job_type=JobType.CUSTOMIZED,
+ workflow_id=0,
+ project_id=1)
+ session.add_all([group, trusted_job, job, participant])
+ session.commit()
+
+ def test_to_proto(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ pb = TrustedJobPb(
+ id=1,
+ type='ANALYZE',
+ name='V1',
+ job_id=1,
+ uuid='uuid test',
+ version=1,
+ comment='This is comment for test.',
+ project_id=2,
+ trusted_job_group_id=1,
+ coordinator_id=1,
+ ticket_status='APPROVED',
+ auth_status='AUTHORIZED',
+ participants_info=self.participants_info,
+ ticket_auth_status='AUTH_PENDING',
+ created_at=1655164800,
+ updated_at=1655164801,
+ started_at=1655251200,
+ finished_at=1655251201,
+ status='PENDING',
+ algorithm_id=0,
+ algorithm_uuid='algorithm-uuid3',
+ resource=Resource(cpu=2, memory=2, replicas=1),
+ export_dataset_id=2,
+ )
+ self.assertEqual(pb, trusted_job.to_proto())
+
+ def test_to_ref(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ ref = TrustedJobRef(
+ id=1,
+ type='ANALYZE',
+ name='V1',
+ coordinator_id=1,
+ job_id=1,
+ comment='This is comment for test.',
+ started_at=1655251200,
+ finished_at=1655251201,
+ status='PENDING',
+ participants_info=self.participants_info,
+ ticket_auth_status='AUTH_PENDING',
+ )
+ self.assertEqual(ref, trusted_job.to_ref())
+
+ def test_get_resource(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ self.assertEqual(Resource(cpu=2, memory=2, replicas=1), trusted_job.get_resource())
+ trusted_job.resource = None
+ self.assertIsNone(trusted_job.get_resource())
+
+ def test_set_resource(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ trusted_job.set_resource(Resource(cpu=4, memory=4, replicas=1))
+ self.assertEqual('cpu: 4\nmemory: 4\nreplicas: 1\n', trusted_job.resource)
+ trusted_job.set_resource()
+ self.assertEqual('', trusted_job.resource)
+
+ def test_update_status(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ job: Job = session.query(Job).get(1)
+ # case 1
+ trusted_job.update_status()
+ self.assertEqual(TrustedJobStatus.PENDING, trusted_job.status)
+ # case 2
+ job.state = JobState.FAILED
+ trusted_job.update_status()
+ self.assertEqual(TrustedJobStatus.FAILED, trusted_job.status)
+ # case 3
+ trusted_job.status = TrustedJobStatus.RUNNING
+ job.state = JobState.COMPLETED
+ trusted_job.update_status()
+ self.assertEqual(TrustedJobStatus.SUCCEEDED, trusted_job.status)
+ # case 4
+ job.state = JobState.WAITING
+ trusted_job.update_status()
+ self.assertEqual(TrustedJobStatus.SUCCEEDED, trusted_job.status)
+
+ def test_get_status(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ job: Job = session.query(Job).get(1)
+ self.assertEqual(trusted_job.get_status(), TrustedJobStatus.PENDING)
+ job.state = JobState.STARTED
+ self.assertEqual(trusted_job.get_status(), TrustedJobStatus.RUNNING)
+ job.state = JobState.FAILED
+ self.assertEqual(trusted_job.get_status(), TrustedJobStatus.FAILED)
+ job.state = JobState.WAITING
+ self.assertEqual(trusted_job.get_status(), TrustedJobStatus.FAILED)
+
+ def test_to_notification(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ notif = trusted_job.to_notification()
+ self.assertEqual(notif.type, TrustedNotification.TRUSTED_JOB_EXPORT)
+ self.assertEqual(notif.id, 1)
+ self.assertEqual(notif.name, 'trusted job group name-V1')
+ self.assertEqual(notif.created_at, 1655164800)
+ self.assertEqual(notif.coordinator_id, 1)
+
+ def test_get_ticket_auth_status(self):
+ with db.session_scope() as session:
+ trusted_job: TrustedJob = session.query(TrustedJob).get(1)
+ self.assertEqual(trusted_job.get_ticket_auth_status(), TicketAuthStatus.AUTH_PENDING)
+ self.participants_info.participants_map['part1'].auth_status = 'AUTHORIZED'
+ trusted_job.set_participants_info(self.participants_info)
+ self.assertEqual(trusted_job.get_ticket_auth_status(), TicketAuthStatus.AUTHORIZED)
+ trusted_job.status = TrustedJobStatus.CREATE_FAILED
+ self.assertEqual(trusted_job.get_ticket_auth_status(), TicketAuthStatus.CREATE_FAILED)
+ trusted_job.status = TrustedJobStatus.NEW
+ self.assertEqual(trusted_job.get_ticket_auth_status(), TicketAuthStatus.CREATE_PENDING)
+ trusted_job.ticket_status = TicketStatus.PENDING
+ self.assertEqual(trusted_job.get_ticket_auth_status(), TicketAuthStatus.TICKET_PENDING)
+ trusted_job.ticket_status = TicketStatus.DECLINED
+ self.assertEqual(trusted_job.get_ticket_auth_status(), TicketAuthStatus.TICKET_DECLINED)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/runners.py b/web_console_v2/api/fedlearner_webconsole/tee/runners.py
new file mode 100644
index 000000000..db9811ae2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/runners.py
@@ -0,0 +1,222 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from envs import Envs
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, TeeRunnerOutput
+from fedlearner_webconsole.proto import dataset_pb2
+from fedlearner_webconsole.tee.models import TrustedJobGroup, GroupCreateStatus, TrustedJob, TrustedJobType, \
+ TrustedJobStatus
+from fedlearner_webconsole.tee.utils import get_pure_path
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.tee.controller import create_trusted_job_group, TrustedJobGroupController, \
+ launch_trusted_job, TrustedJobController, launch_trusted_export_job
+from fedlearner_webconsole.exceptions import WebConsoleApiException
+from fedlearner_webconsole.dataset.services import DatasetService, BatchService
+from fedlearner_webconsole.dataset.models import DatasetType, DatasetKindV2, DatasetFormat, ImportType, StoreFormat
+
+
+class TeeCreateRunner(IRunnerV2):
+
+ @staticmethod
+ def _create_trusted_job_group():
+ # schedule all groups with ticket APPROVED, status PENDING and coordinator_id 0
+ processed_groups = set()
+ with db.session_scope() as session:
+ groups_ids = session.query(TrustedJobGroup.id).filter_by(ticket_status=TicketStatus.APPROVED,
+ status=GroupCreateStatus.PENDING,
+ coordinator_id=0).all()
+ for group_id, *_ in groups_ids:
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).populate_existing().with_for_update().get(
+ group_id)
+ if group.status != GroupCreateStatus.PENDING:
+ continue
+ processed_groups.add(group.id)
+ try:
+ succeeded, msg = create_trusted_job_group(group)
+ except WebConsoleApiException as e:
+ succeeded = False
+ msg = e.details
+ if not succeeded:
+ group.status = GroupCreateStatus.FAILED
+ logging.warning(f'[create trusted job group scheduler]: group {group.id} failed, exception {msg}')
+ else:
+ group.status = GroupCreateStatus.SUCCEEDED
+ session.commit()
+ return processed_groups
+
+ @staticmethod
+ def _launch_trusted_job():
+        # schedule all newly created trusted job groups satisfying
+        # status == SUCCEEDED and coordinator_id == 0 and latest_version == 0 and
+        # auth_status == AUTHORIZED and unauth_participant_ids is None
+ processed_groups = set()
+ with db.session_scope() as session:
+ groups_ids = session.query(TrustedJobGroup.id).filter_by(status=GroupCreateStatus.SUCCEEDED,
+ coordinator_id=0,
+ latest_version=0,
+ unauth_participant_ids=None,
+                                                                      auth_status=AuthStatus.AUTHORIZED).all()
+ for group_id, *_ in groups_ids:
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).populate_existing().with_for_update().get(
+ group_id)
+ if group.latest_version or group.auth_status != AuthStatus.AUTHORIZED or group.unauth_participant_ids:
+ continue
+ group.latest_version = 1
+ session.commit()
+ processed_groups.add(group.id)
+ succeeded, msg = launch_trusted_job(group.project_id, group.uuid, group.latest_version)
+ if not succeeded:
+ logging.warning(f'[launch trusted job scheduler]: group {group.id} failed, exception {msg}')
+ return processed_groups
+
+ @staticmethod
+ def _create_trusted_export_job():
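+        # create trusted export jobs at the participants for local export jobs whose ticket_status is
+        # APPROVED and status is NEW; the status becomes CREATED on success and CREATE_FAILED on failure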
+ processed_ids = set()
+ with db.session_scope() as session:
+ trusted_jobs_ids = session.query(TrustedJob.id).filter_by(
+ type=TrustedJobType.EXPORT,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.NEW,
+ ).all()
+ for trusted_job_id, *_ in trusted_jobs_ids:
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).populate_existing().with_for_update().get(trusted_job_id)
+ tee_analyze_job = session.query(TrustedJob).filter_by(
+ type=TrustedJobType.ANALYZE,
+ trusted_job_group_id=tee_export_job.trusted_job_group_id,
+ version=tee_export_job.version).first()
+ processed_ids.add(tee_export_job.id)
+ try:
+ TrustedJobController(session, tee_export_job.project_id).create_trusted_export_job(
+ tee_export_job, tee_analyze_job)
+ tee_export_job.status = TrustedJobStatus.CREATED
+ except WebConsoleApiException as e:
+ tee_export_job.status = TrustedJobStatus.CREATE_FAILED
+ logging.warning(
+ f'[create trusted export job scheduler]: {tee_export_job.id} failed, exception {e.details}')
+ session.commit()
+ return processed_ids
+
+ @staticmethod
+ def _launch_trusted_export_job():
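+        # launch locally coordinated (coordinator_id 0) trusted export jobs in status CREATED once all
+        # participants have authorized; the job is marked FAILED if launching fails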
+ processed_ids = set()
+ with db.session_scope() as session:
+ trusted_jobs_ids = session.query(TrustedJob.id).filter_by(
+ type=TrustedJobType.EXPORT,
+ status=TrustedJobStatus.CREATED,
+ coordinator_id=0,
+ ).all()
+ for trusted_job_id, *_ in trusted_jobs_ids:
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).get(trusted_job_id)
+ if not tee_export_job.is_all_participants_authorized():
+ continue
+ processed_ids.add(tee_export_job.id)
+ succeeded, msg = launch_trusted_export_job(tee_export_job.project_id, tee_export_job.uuid)
+ if not succeeded:
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).get(trusted_job_id)
+ tee_export_job.status = TrustedJobStatus.FAILED
+ session.commit()
+ logging.warning(f'[launch trusted export job scheduler]: {tee_export_job.id} failed, exception {msg}')
+ return processed_ids
+
+ @staticmethod
+ def _create_export_dataset():
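+        # create an internal export dataset and its first data batch for succeeded export jobs that
+        # do not have an export_dataset_id yet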
+ processed_ids = set()
+ with db.session_scope() as session:
+ trusted_jobs_ids = session.query(TrustedJob.id).filter_by(
+ type=TrustedJobType.EXPORT,
+ status=TrustedJobStatus.SUCCEEDED,
+ coordinator_id=0,
+ export_dataset_id=None,
+ ).all()
+ for trusted_job_id, *_ in trusted_jobs_ids:
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).populate_existing().with_for_update().get(trusted_job_id)
+ if tee_export_job.export_dataset_id:
+ continue
+ processed_ids.add(tee_export_job.id)
+ dataset = DatasetService(session).create_dataset(
+ dataset_pb2.DatasetParameter(
+ name=f'{tee_export_job.group.name}-{tee_export_job.name}',
+ is_published=False,
+ type=DatasetType.PSI.value,
+ project_id=tee_export_job.project_id,
+ kind=DatasetKindV2.INTERNAL_PROCESSED.value,
+ format=DatasetFormat.NONE_STRUCTURED.name,
+ path=f'{get_pure_path(Envs.STORAGE_ROOT)}/job_output/{tee_export_job.job.name}/export',
+ import_type=ImportType.COPY.value,
+ store_format=StoreFormat.UNKNOWN.value,
+ auth_status=AuthStatus.AUTHORIZED.name))
+ session.flush()
+ BatchService(session).create_batch(dataset_pb2.BatchParameter(dataset_id=dataset.id))
+ tee_export_job.export_dataset_id = dataset.id
+ session.commit()
+ return processed_ids
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+        # high-frequency runner (compared to TeeResourceCheckRunner) that does the following 5 tasks:
+        # 1. create trusted job groups whose status is PENDING and ticket_status is APPROVED
+        # 2. launch the first trusted job of self-created groups (latest_version 0, SUCCEEDED) once fully authorized
+        # 3. create trusted export jobs remotely for local export jobs with ticket_status APPROVED and status NEW
+        # 4. launch trusted export jobs with coordinator_id 0 once fully created and authorized
+        # 5. create an export dataset for each succeeded trusted export job
+ created_group_ids = self._create_trusted_job_group()
+ launched_group_ids = self._launch_trusted_job()
+ created_trusted_export_job_ids = self._create_trusted_export_job()
+ launched_trusted_export_job_ids = self._launch_trusted_export_job()
+ created_dataset_trusted_export_job_ids = self._create_export_dataset()
+ return RunnerStatus.DONE, RunnerOutput(tee_runner_output=TeeRunnerOutput(
+ created_group_ids=list(created_group_ids),
+ launched_group_ids=list(launched_group_ids),
+ created_trusted_export_job_ids=list(created_trusted_export_job_ids),
+ launched_trusted_export_job_ids=list(launched_trusted_export_job_ids),
+ created_dataset_trusted_export_job_ids=list(created_dataset_trusted_export_job_ids),
+ ))
+
+
+class TeeResourceCheckRunner(IRunnerV2):
+
+ @staticmethod
+ def _update_unauth_participant_ids():
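+        # query each participant's trusted job group to refresh unauth_participant_ids for all groups
+        # whose creation has SUCCEEDED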
+ processed_groups = set()
+ with db.session_scope() as session:
+ group_ids = session.query(TrustedJobGroup.id).filter_by(status=GroupCreateStatus.SUCCEEDED).all()
+ for group_id, *_ in group_ids:
+ with db.session_scope() as session:
+ group = session.query(TrustedJobGroup).populate_existing().with_for_update().get(group_id)
+ TrustedJobGroupController(session, group.project_id).update_unauth_participant_ids(group)
+ processed_groups.add(group_id)
+ session.commit()
+ return processed_groups
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+        # low-frequency runner (compared to TeeCreateRunner) that does the following 2 tasks:
+        # 1. actively fetch participants' auth_status in case the grpc InformTrustedJobGroup call failed
+        # 2. TODO(liuledian): actively fetch export_auth_status
+ checked_group_ids = self._update_unauth_participant_ids()
+ return RunnerStatus.DONE, RunnerOutput(tee_runner_output=TeeRunnerOutput(
+ checked_group_ids=list(checked_group_ids)))
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/runners_test.py b/web_console_v2/api/fedlearner_webconsole/tee/runners_test.py
new file mode 100644
index 000000000..bd97c39e5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/runners_test.py
@@ -0,0 +1,403 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+from google.protobuf.empty_pb2 import Empty
+from google.protobuf.text_format import MessageToString
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.dataset.models import Dataset, DataBatch
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmProject, AlgorithmType
+from fedlearner_webconsole.tee.models import TrustedJobGroup, GroupCreateStatus, TrustedJob, TrustedJobType, \
+ TrustedJobStatus
+from fedlearner_webconsole.tee.runners import TeeCreateRunner, TeeResourceCheckRunner
+from fedlearner_webconsole.proto.tee_pb2 import ParticipantDatasetList, ParticipantDataset, Resource
+from fedlearner_webconsole.proto.rpc.v2.job_service_pb2 import GetTrustedJobGroupResponse
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.review.common import NO_CENTRAL_SERVER_UUID
+
+
+class TeeRunnerTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ participant2 = Participant(id=2, name='part3', domain_name='fl-domain3.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ proj_part2 = ProjectParticipant(project_id=1, participant_id=2)
+ dataset1 = Dataset(id=1, name='dataset-name1', uuid='dataset-uuid1', is_published=True)
+ data_batch1 = DataBatch(id=1, dataset_id=1)
+ dataset2 = Dataset(id=2, name='dataset-name2', uuid='dataset-uuid2', is_published=False)
+ algorithm = Algorithm(id=1,
+ uuid='algorithm-uuid1',
+ type=AlgorithmType.TRUSTED_COMPUTING,
+ algorithm_project_id=1)
+ algorithm_proj = AlgorithmProject(id=1, uuid='algorithm-proj-uuid')
+ session.add_all([
+ project, participant1, proj_part1, participant2, proj_part2, dataset1, data_batch1, dataset2, algorithm,
+ algorithm_proj
+ ])
+ session.commit()
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_create_trusted_job_group(self, mock_remote_do_two_pc, mock_get_system_info):
+ mock_remote_do_two_pc.return_value = True, ''
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ with db.session_scope() as session:
+ participant_datasets = ParticipantDatasetList(
+ items=[ParticipantDataset(participant_id=1, uuid='dataset-uuid3', name='dataset-name3')])
+ # group in ticket_status APPROVED / status PENDING / valid params
+ group1 = TrustedJobGroup(id=1,
+ project_id=1,
+ ticket_status=TicketStatus.APPROVED,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID,
+ algorithm_uuid='algorithm-uuid1',
+ dataset_id=1,
+ coordinator_id=0,
+ analyzer_id=1,
+ uuid='uuid1')
+ group1.set_participant_datasets(participant_datasets)
+ # group in ticket_status APPROVED / status PENDING / invalid params
+ # error at controller run
+ group2 = TrustedJobGroup(id=2,
+ project_id=10,
+ ticket_status=TicketStatus.APPROVED,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID,
+ algorithm_uuid='algorithm-uuid2',
+ dataset_id=1,
+ coordinator_id=0,
+ analyzer_id=0,
+ uuid='uuid2')
+ # error at prepare
+ group3 = TrustedJobGroup(id=3,
+ project_id=1,
+ ticket_status=TicketStatus.APPROVED,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID,
+ algorithm_uuid='algorithm-uuid1',
+ dataset_id=2,
+ coordinator_id=0,
+ analyzer_id=0,
+ uuid='uuid3')
+ # status FAILED
+ group4 = TrustedJobGroup(id=4,
+ project_id=1,
+ ticket_status=TicketStatus.APPROVED,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID,
+ algorithm_uuid='algorithm-uuid1',
+ dataset_id=1,
+ coordinator_id=0,
+ analyzer_id=0,
+ status=GroupCreateStatus.FAILED,
+ uuid='uuid4')
+ # status SUCCEEDED
+ group5 = TrustedJobGroup(id=5,
+ project_id=1,
+ ticket_status=TicketStatus.APPROVED,
+ ticket_uuid=NO_CENTRAL_SERVER_UUID,
+ algorithm_uuid='algorithm-uuid1',
+ dataset_id=1,
+ coordinator_id=0,
+ analyzer_id=0,
+ status=GroupCreateStatus.SUCCEEDED,
+ uuid='uuid5')
+ session.add_all([group1, group2, group3, group4, group5])
+ session.commit()
+ runner = TeeCreateRunner()
+ # first run
+ # pylint: disable=protected-access
+ processed_groups = runner._create_trusted_job_group()
+ self.assertEqual(processed_groups, set([1, 2, 3]))
+ with db.session_scope() as session:
+ group1: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ self.assertEqual(group1.status, GroupCreateStatus.SUCCEEDED)
+ group2: TrustedJobGroup = session.query(TrustedJobGroup).get(2)
+ self.assertEqual(group2.status, GroupCreateStatus.FAILED)
+ group3: TrustedJobGroup = session.query(TrustedJobGroup).get(3)
+ self.assertEqual(group3.status, GroupCreateStatus.FAILED)
+ # second run should do nothing
+ processed_groups = runner._create_trusted_job_group()
+ self.assertEqual(processed_groups, set())
+
+ @patch('fedlearner_webconsole.tee.services.get_batch_data_path')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager._remote_do_two_pc')
+ def test_launch_trusted_job(self, mock_remote_do_two_pc, mock_get_system_info, mock_get_batch_data_path):
+ mock_remote_do_two_pc.return_value = True, ''
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ mock_get_batch_data_path.return_value = 'file:///data/test'
+ with db.session_scope() as session:
+ # valid
+ group1 = TrustedJobGroup(id=1,
+ project_id=1,
+ uuid='uuid1',
+ status=GroupCreateStatus.SUCCEEDED,
+ coordinator_id=0,
+ latest_version=0,
+ algorithm_uuid='algorithm-uuid1',
+ dataset_id=1,
+ auth_status=AuthStatus.AUTHORIZED,
+ resource=MessageToString(Resource(cpu=1000, memory=1, replicas=1)))
+ # not fully authorized
+ group2 = TrustedJobGroup(id=2,
+ project_id=1,
+ uuid='uuid2',
+ status=GroupCreateStatus.SUCCEEDED,
+ coordinator_id=0,
+ latest_version=0,
+ algorithm_uuid='algorithm-uuid1',
+ auth_status=AuthStatus.AUTHORIZED,
+ unauth_participant_ids='1,2')
+ # not creator
+ group3 = TrustedJobGroup(id=3,
+ project_id=1,
+ uuid='uuid3',
+ status=GroupCreateStatus.SUCCEEDED,
+ coordinator_id=1,
+ latest_version=0,
+ algorithm_uuid='algorithm-uuid1',
+ auth_status=AuthStatus.AUTHORIZED)
+ session.add_all([group1, group2, group3])
+ session.commit()
+ runner = TeeCreateRunner()
+ # first run
+ # pylint: disable=protected-access
+ processed_groups = runner._launch_trusted_job()
+ self.assertCountEqual(processed_groups, [1])
+ with db.session_scope() as session:
+ group1: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ self.assertEqual(group1.latest_version, 1)
+ trusted_job: TrustedJob = session.query(TrustedJob).filter_by(trusted_job_group_id=1, version=1).first()
+ self.assertIsNotNone(trusted_job)
+            group2: TrustedJobGroup = session.query(TrustedJobGroup).get(2)
+            self.assertEqual(group2.latest_version, 0)
+            group3: TrustedJobGroup = session.query(TrustedJobGroup).get(3)
+            self.assertEqual(group3.latest_version, 0)
+ # second run
+ processed_groups = runner._launch_trusted_job()
+ self.assertEqual(processed_groups, set())
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.get_trusted_job_group')
+ def test_update_unauth_participant_ids(self, mock_client: MagicMock):
+ with db.session_scope() as session:
+ group1 = TrustedJobGroup(id=1,
+ uuid='uuid1',
+ project_id=1,
+ unauth_participant_ids='1,2',
+ status=GroupCreateStatus.SUCCEEDED)
+ group2 = TrustedJobGroup(id=2,
+ uuid='uuid2',
+ project_id=1,
+ unauth_participant_ids='2',
+ status=GroupCreateStatus.SUCCEEDED)
+ group3 = TrustedJobGroup(id=3,
+ uuid='uuid3',
+ project_id=1,
+ unauth_participant_ids='1,2',
+ status=GroupCreateStatus.FAILED)
+ session.add_all([group1, group2, group3])
+ session.commit()
+ # for 2 participants and 2 valid groups, client should be called 4 times
+ # pylint: disable=protected-access
+ mock_client.side_effect = [
+ GetTrustedJobGroupResponse(auth_status='PENDING'),
+ GetTrustedJobGroupResponse(auth_status='AUTHORIZED'),
+ GetTrustedJobGroupResponse(auth_status='AUTHORIZED'),
+ GetTrustedJobGroupResponse(auth_status='AUTHORIZED')
+ ]
+ runner = TeeResourceCheckRunner()
+ processed_groups = runner._update_unauth_participant_ids()
+ self.assertCountEqual(processed_groups, [1, 2])
+ with db.session_scope() as session:
+ group1: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ self.assertCountEqual(group1.get_unauth_participant_ids(), [1])
+ group2: TrustedJobGroup = session.query(TrustedJobGroup).get(2)
+ self.assertCountEqual(group2.get_unauth_participant_ids(), [])
+
+ @patch('fedlearner_webconsole.rpc.v2.job_service_client.JobServiceClient.create_trusted_export_job')
+ def test_create_trusted_export_job(self, mock_client: MagicMock):
+ mock_client.return_value = Empty()
+ with db.session_scope() as session:
+ tee_analyze_job = TrustedJob(id=1,
+ uuid='uuid1',
+ type=TrustedJobType.ANALYZE,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=2,
+ status=TrustedJobStatus.SUCCEEDED)
+ tee_export_job1 = TrustedJob(id=2,
+ uuid='uuid2',
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=1,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.NEW)
+ tee_export_job2 = TrustedJob(id=3,
+ uuid='uuid3',
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=2,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.NEW)
+ tee_export_job3 = TrustedJob(id=4,
+ uuid='uuid4',
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=1,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.CREATED)
+ tee_export_job4 = TrustedJob(id=5,
+ uuid='uuid5',
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=3,
+ ticket_status=TicketStatus.PENDING,
+ status=TrustedJobStatus.NEW)
+ session.add_all([tee_analyze_job, tee_export_job1, tee_export_job2, tee_export_job3, tee_export_job4])
+ session.commit()
+ runner = TeeCreateRunner()
+ # pylint: disable=protected-access
+ processed_ids = runner._create_trusted_export_job()
+ self.assertCountEqual(processed_ids, [2, 3])
+ with db.session_scope() as session:
+ tee_export_job1 = session.query(TrustedJob).get(2)
+ self.assertEqual(tee_export_job1.status, TrustedJobStatus.CREATED)
+ tee_export_job2 = session.query(TrustedJob).get(3)
+ self.assertEqual(tee_export_job2.status, TrustedJobStatus.CREATED)
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager.run')
+ def test_launch_trusted_export_job(self, mock_run):
+ mock_run.return_value = True, ''
+ with db.session_scope() as session:
+ tee_export_job1 = TrustedJob(id=1,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=1,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.CREATED,
+ coordinator_id=0)
+ tee_export_job2 = TrustedJob(id=2,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=2,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.CREATED,
+ coordinator_id=0)
+ participants_info = tee_export_job2.get_participants_info()
+ participants_info.participants_map['domain1'].auth_status = 'WITHDRAW'
+ tee_export_job2.set_participants_info(participants_info)
+ tee_export_job3 = TrustedJob(id=3,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=3,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.CREATED,
+ coordinator_id=1)
+ tee_export_job4 = TrustedJob(id=4,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=2,
+ trusted_job_group_id=1,
+ export_count=1,
+ ticket_status=TicketStatus.APPROVED,
+ status=TrustedJobStatus.NEW,
+ coordinator_id=0)
+ session.add_all([tee_export_job1, tee_export_job2, tee_export_job3, tee_export_job4])
+ session.commit()
+ runner = TeeCreateRunner()
+ # pylint: disable=protected-access
+ processed_ids = runner._launch_trusted_export_job()
+ self.assertCountEqual(processed_ids, [1])
+
+ def test_create_export_dataset(self):
+ with db.session_scope() as session:
+ group1 = TrustedJobGroup(id=1, name='group1', project_id=1, uuid='group-uuid1')
+ tee_export_job1 = TrustedJob(id=1,
+ name='V1-me-1',
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ job_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=1,
+ status=TrustedJobStatus.SUCCEEDED,
+ coordinator_id=0)
+ tee_export_job2 = TrustedJob(id=2,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=2,
+ status=TrustedJobStatus.RUNNING,
+ coordinator_id=0)
+ tee_export_job3 = TrustedJob(id=3,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=1,
+ trusted_job_group_id=1,
+ export_count=3,
+ status=TrustedJobStatus.SUCCEEDED,
+ coordinator_id=1)
+ tee_export_job4 = TrustedJob(id=4,
+ type=TrustedJobType.EXPORT,
+ project_id=1,
+ version=2,
+ trusted_job_group_id=1,
+ export_count=1,
+ status=TrustedJobStatus.SUCCEEDED,
+ coordinator_id=0,
+ export_dataset_id=1)
+ job1 = Job(id=1,
+ name='job-name1',
+ job_type=JobType.CUSTOMIZED,
+ workflow_id=0,
+ project_id=1,
+ state=JobState.COMPLETED)
+ session.add_all([group1, tee_export_job1, tee_export_job2, tee_export_job3, tee_export_job4, job1])
+ session.commit()
+ runner = TeeCreateRunner()
+ # pylint: disable=protected-access
+ processed_ids = runner._create_export_dataset()
+ self.assertCountEqual(processed_ids, [1])
+ with db.session_scope() as session:
+ tee_export_job1 = session.query(TrustedJob).get(1)
+ dataset = session.query(Dataset).get(tee_export_job1.export_dataset_id)
+ self.assertEqual(dataset.name, 'group1-V1-me-1')
+ self.assertEqual(len(dataset.data_batches), 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/services.py b/web_console_v2/api/fedlearner_webconsole/tee/services.py
new file mode 100644
index 000000000..28849626f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/services.py
@@ -0,0 +1,262 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+from google.protobuf.text_format import MessageToString
+from envs import Envs
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob, TrustedJobStatus, TrustedJobType
+from fedlearner_webconsole.tee.tee_job_template import TEE_YAML_TEMPLATE
+from fedlearner_webconsole.tee.utils import get_pure_path
+from fedlearner_webconsole.utils.pp_datetime import now
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.flag.models import Flag
+from fedlearner_webconsole.workflow_template.utils import make_variable
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.job.controller import create_job_without_workflow, stop_job, \
+ schedule_job, start_job_if_ready
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.dataset.data_path import get_batch_data_path
+
+
+def check_tee_enabled() -> bool:
+    # TODO(liuledian): call the k8s api to check whether the cluster has SGX machines instead of using system variables
+ return Flag.TEE_MACHINE_DEPLOYED.value
+
+
+def creat_tee_analyze_job_definition(session: Session, job_name: str, trusted_job: TrustedJob,
+ group: TrustedJobGroup) -> JobDefinition:
+ domain_name = SettingService.get_system_info().domain_name
+ analyzer_domain = domain_name
+ if group.analyzer_id:
+ analyzer_domain = session.query(Participant).get(group.analyzer_id).domain_name
+ provider_domain_ls = []
+ pds = group.get_participant_datasets()
+ if pds:
+ for pd in pds.items:
+ provider_domain_ls.append(session.query(Participant).get(pd.participant_id).domain_name)
+ input_data_path = ''
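+    # when this party contributes a dataset, it also joins the provider list alongside the other participants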
+ if group.dataset_id:
+ input_data_path = get_pure_path(get_batch_data_path(group.dataset.get_single_batch()))
+ provider_domain_ls.append(domain_name)
+ algorithm = AlgorithmFetcher(trusted_job.project_id).get_algorithm(trusted_job.algorithm_uuid)
+ variables_dict = {
+ 'project_name': group.project.name,
+ 'data_role': 'PROVIDER' if group.analyzer_id else 'ANALYZER',
+ 'task_type': 'ANALYZE',
+ 'input_data_path': input_data_path,
+ 'output_data_path': f'{get_pure_path(Envs.STORAGE_ROOT)}/job_output/{job_name}/output',
+ 'algorithm': {
+ 'path': get_pure_path(algorithm.path),
+ 'config': MessageToString(algorithm.parameter, as_one_line=True),
+ },
+ 'domain_name': domain_name,
+ 'analyzer_domain': analyzer_domain,
+ 'providers_domain': ','.join(provider_domain_ls),
+ 'pccs_url': 'https://sgx-dcap-server.bytedance.com/sgx/certification/v3/',
+ 'sgx_mem': '' if group.analyzer_id else '100',
+ 'worker_cpu': f'{trusted_job.get_resource().cpu}m',
+ 'worker_mem': f'{trusted_job.get_resource().memory}Gi',
+ 'worker_replicas': 1,
+ }
+ variables = [make_variable(name=k, typed_value=v) for k, v in variables_dict.items()]
+ return JobDefinition(name=job_name,
+ job_type=JobDefinition.CUSTOMIZED,
+ is_federated=False,
+ variables=variables,
+ yaml_template=TEE_YAML_TEMPLATE)
+
+
+def creat_tee_export_job_definition(session: Session, job_name: str, tee_export_job: TrustedJob,
+ tee_analyze_job: TrustedJob, group: TrustedJobGroup) -> JobDefinition:
+ domain_name = SettingService.get_system_info().domain_name
+ analyzer_domain = domain_name
+ if group.analyzer_id:
+ analyzer_domain = session.query(Participant).get(group.analyzer_id).domain_name
+ receiver_domain = domain_name
+ if tee_export_job.coordinator_id:
+ receiver_domain = session.query(Participant).get(tee_export_job.coordinator_id).domain_name
+ variables_dict = {
+ 'data_role': 'PROVIDER' if group.analyzer_id else 'ANALYZER',
+ 'task_type': 'EXPORT',
+ 'output_data_path': f'{get_pure_path(Envs.STORAGE_ROOT)}/job_output/{tee_analyze_job.job.name}/output',
+ 'export_data_path': f'{get_pure_path(Envs.STORAGE_ROOT)}/job_output/{job_name}/export/batch/0',
+ 'algorithm': {
+ 'path': '',
+ 'config': '',
+ },
+ 'domain_name': domain_name,
+ 'receiver_domain': receiver_domain,
+ 'analyzer_domain': analyzer_domain,
+ 'worker_cpu': f'{tee_export_job.get_resource().cpu}m',
+ 'worker_mem': f'{tee_export_job.get_resource().memory}Gi',
+ 'worker_replicas': 1,
+ }
+ variables = [make_variable(name=k, typed_value=v) for k, v in variables_dict.items()]
+ return JobDefinition(name=job_name,
+ job_type=JobDefinition.CUSTOMIZED,
+ is_federated=False,
+ variables=variables,
+ yaml_template=TEE_YAML_TEMPLATE)
+
+
+class TrustedJobGroupService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def delete(self, group: TrustedJobGroup):
+ for trusted_job in group.trusted_jobs:
+ self._session.delete(trusted_job.job)
+ self._session.query(TrustedJob).filter_by(trusted_job_group_id=group.id).delete()
+ self._session.delete(group)
+
+ def lock_and_update_version(self, group_id: int, version: Optional[int] = None) -> TrustedJobGroup:
+ """
+        If version is None, increment the latest version by 1.
+        Otherwise, set the latest version to the given version only if it is larger than the current one.
+ """
+ group: TrustedJobGroup = self._session.query(TrustedJobGroup).populate_existing().with_for_update().get(
+ group_id)
+ if version is None:
+ group.latest_version = group.latest_version + 1
+ elif version > group.latest_version:
+ group.latest_version = version
+ return group
+
+ def launch_trusted_job(self, group: TrustedJobGroup, uuid: str, version: int, coordinator_id: int):
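+        # lock the group row first so concurrent launches cannot assign the same version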
+ self.lock_and_update_version(group.id, version)
+ name = f'V{version}'
+ trusted_job = TrustedJob(
+ name=name,
+ uuid=uuid,
+ version=version,
+ coordinator_id=coordinator_id,
+ project_id=group.project_id,
+ trusted_job_group_id=group.id,
+ started_at=now(),
+ status=TrustedJobStatus.CREATED,
+ algorithm_uuid=group.algorithm_uuid,
+ resource=group.resource,
+ )
+ job_name = f'trusted-job-{version}-{uuid}'
+ job_definition = creat_tee_analyze_job_definition(self._session, job_name, trusted_job, group)
+ job = create_job_without_workflow(self._session, job_definition, group.project_id, job_name)
+ schedule_job(self._session, job)
+ start_job_if_ready(self._session, job)
+ trusted_job.job_id = job.id
+ trusted_job.update_status()
+ self._session.add(trusted_job)
+ self._session.flush()
+
+
+class TrustedJobService:
+
+ def __init__(self, session: Session):
+ self._session = session
+
+ def lock_and_update_export_count(self, trusted_job_id: int) -> TrustedJob:
+ trusted_job = self._session.query(TrustedJob).populate_existing().with_for_update().get(trusted_job_id)
+ if not trusted_job.export_count:
+ trusted_job.export_count = 1
+ else:
+ trusted_job.export_count += 1
+ return trusted_job
+
+ def stop_trusted_job(self, trusted_job: TrustedJob):
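+        # stopping is a no-op unless the trusted job is currently running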
+ if trusted_job.get_status() != TrustedJobStatus.RUNNING:
+ return
+ if trusted_job.job is not None:
+ stop_job(self._session, trusted_job.job)
+ self._session.flush()
+
+ def create_external_export(self, uuid: str, name: str, coordinator_id: int, export_count: int, ticket_uuid: str,
+ tee_analyze_job: TrustedJob):
+ """Create trusted export job for non-coordinator, called by rpc server layer"""
+ tee_export_job = TrustedJob(
+ name=name,
+ type=TrustedJobType.EXPORT,
+ uuid=uuid,
+ version=tee_analyze_job.version,
+ export_count=export_count,
+ project_id=tee_analyze_job.project_id,
+ trusted_job_group_id=tee_analyze_job.trusted_job_group_id,
+ coordinator_id=coordinator_id,
+ ticket_uuid=ticket_uuid,
+ ticket_status=TicketStatus.APPROVED,
+ auth_status=AuthStatus.PENDING,
+ status=TrustedJobStatus.CREATED,
+ resource=tee_analyze_job.resource,
+ result_key=tee_analyze_job.result_key,
+ )
+ participants = ParticipantService(self._session).get_participants_by_project(tee_analyze_job.project_id)
+ participants_info = ParticipantsInfo()
+ for p in participants:
+ participants_info.participants_map[p.pure_domain_name()].CopyFrom(
+ ParticipantInfo(auth_status=AuthStatus.PENDING.name))
+ self_pure_dn = SettingService.get_system_info().pure_domain_name
+ participants_info.participants_map[self_pure_dn].auth_status = AuthStatus.PENDING.name
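+        # the coordinator initiated this export job, so it is recorded as already authorized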
+ coordinator_pure_dn = self._session.query(Participant).get(coordinator_id).pure_domain_name()
+ participants_info.participants_map[coordinator_pure_dn].auth_status = AuthStatus.AUTHORIZED.name
+ tee_export_job.set_participants_info(participants_info)
+ self._session.add(tee_export_job)
+ self._session.flush()
+
+ def create_internal_export(self, uuid: str, tee_analyze_job: TrustedJob):
+ """Create trusted export job for coordinator, called by api layer"""
+ self_pure_dn = SettingService.get_system_info().pure_domain_name
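+        # internal exports start in NEW; TeeCreateRunner later notifies the participants and moves them to CREATED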
+ tee_export_job = TrustedJob(
+ name=f'V{tee_analyze_job.version}-{self_pure_dn}-{tee_analyze_job.export_count}',
+ type=TrustedJobType.EXPORT,
+ uuid=uuid,
+ version=tee_analyze_job.version,
+ export_count=tee_analyze_job.export_count,
+ project_id=tee_analyze_job.project_id,
+ trusted_job_group_id=tee_analyze_job.trusted_job_group_id,
+ coordinator_id=0,
+ auth_status=AuthStatus.AUTHORIZED,
+ status=TrustedJobStatus.NEW,
+ resource=tee_analyze_job.resource,
+ result_key=tee_analyze_job.result_key,
+ )
+ participants = ParticipantService(self._session).get_participants_by_project(tee_analyze_job.project_id)
+ participants_info = ParticipantsInfo()
+ for p in participants:
+ participants_info.participants_map[p.pure_domain_name()].auth_status = AuthStatus.PENDING.name
+ participants_info.participants_map[self_pure_dn].auth_status = AuthStatus.AUTHORIZED.name
+ tee_export_job.set_participants_info(participants_info)
+ self._session.add(tee_export_job)
+ self._session.flush()
+
+ def launch_trusted_export_job(self, tee_export_job: TrustedJob):
+ job_name = f'trusted-job-{tee_export_job.version}-{tee_export_job.uuid}'
+ tee_analyze_job = self._session.query(TrustedJob).filter_by(
+ type=TrustedJobType.ANALYZE,
+ trusted_job_group_id=tee_export_job.trusted_job_group_id,
+ version=tee_export_job.version).first()
+ job_definition = creat_tee_export_job_definition(self._session, job_name, tee_export_job, tee_analyze_job,
+ tee_export_job.group)
+ job = create_job_without_workflow(self._session, job_definition, tee_export_job.project_id, job_name)
+ tee_export_job.started_at = now()
+ schedule_job(self._session, job)
+ start_job_if_ready(self._session, job)
+ tee_export_job.job_id = job.id
+ tee_export_job.update_status()
+ self._session.flush()
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/tee_job_template.py b/web_console_v2/api/fedlearner_webconsole/tee/tee_job_template.py
new file mode 100644
index 000000000..08d063a0a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/tee_job_template.py
@@ -0,0 +1,154 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
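+# NOTE: this template is not plain JSON; it is rendered by the web console's YAML template
+# formatter, which resolves expressions such as self.name, self.variables and
+# system.variables when the job is submitted.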
+TEE_YAML_TEMPLATE = """{
+ "apiVersion": "fedlearner.k8s.io/v1alpha1",
+ "kind": "FedApp",
+ "metadata": {
+ "name": self.name,
+ "namespace": system.variables.namespace,
+ "labels": dict(system.variables.labels),
+ "annotations": {
+ "queue": "fedlearner",
+ "schedulerName": "batch",
+ },
+ },
+ "spec": {
+ "activeDeadlineSeconds": 86400,
+ "fedReplicaSpecs": {
+ "Worker": {
+ "mustSuccess": True,
+ "port": {
+ "containerPort": 50051,
+ "name": "grpc-port",
+ "protocol": "TCP"
+ },
+ "template": {
+ "spec": {
+ "restartPolicy": "Never",
+ "containers": [
+ {
+ "name": "gene-analysis",
+ "image": str(system.variables.sgx_image),
+ "volumeMounts": list(system.variables.volume_mounts_list),
+ "command": [
+ '/bin/bash'
+ ],
+ "args": [
+ '/app/entrypoint.sh'
+ ],
+ "env": system.basic_envs_list + [
+ {
+ "name": "PROJECT_NAME",
+ "value": str(self.variables.get("project_name", ""))
+ },
+ {
+ "name": "DATA_ROLE",
+ "value": str(self.variables.get("data_role", ""))
+ },
+ {
+ "name": "TASK_TYPE",
+ "value": str(self.variables.get("task_type", ""))
+ },
+ {
+ "name": "INPUT_DATA_PATH",
+ "value": str(self.variables.get("input_data_path", ""))
+ },
+ {
+ "name": "OUTPUT_DATA_PATH",
+ "value": str(self.variables.get("output_data_path", ""))
+ },
+ {
+ "name": "EXPORT_DATA_PATH",
+ "value": str(self.variables.get("export_data_path", ""))
+ },
+ {
+ "name": "ALGORITHM_PATH",
+ "value": str(self.variables.algorithm.path)
+ },
+ {
+ "name": "ALGORITHM_PARAMETERS",
+ "value": str(self.variables.algorithm.config)
+ },
+ {
+ "name": "DOMAIN_NAME",
+ "value": str(self.variables.get("domain_name", ""))
+ },
+ {
+ "name": "ANALYZER_DOMAIN",
+ "value": str(self.variables.get("analyzer_domain", ""))
+ },
+ {
+ "name": "PROVIDERS_DOMAIN",
+ "value": str(self.variables.get("providers_domain", ""))
+ },
+ {
+ "name": "RECEIVER_DOMAIN",
+ "value": str(self.variables.get("receiver_domain", ""))
+ },
+ {
+ "name": "PCCS_URL",
+ "value": str(self.variables.get("pccs_url", ""))
+ },
+ {
+ "name": "RESULT_KEY",
+ "value": str(self.variables.get("result_key", ""))
+ }
+ ] + [],
+ "imagePullPolicy": "IfNotPresent",
+ "ports": [
+ {
+ "containerPort": 50051,
+ "name": "grpc-port",
+ "protocol": "TCP"
+ }
+ ],
+ "resources": {
+ "limits": {
+ "cpu": self.variables.worker_cpu,
+ "memory": self.variables.worker_mem
+                                        },
+ "requests": {
+ "cpu": self.variables.worker_cpu,
+ "memory": self.variables.worker_mem
+ }
+ } if not self.variables.get("sgx_mem", "") else {
+                                        "limits": {
+                                            "sgx_epc_mem": str(self.variables.sgx_mem),
+                                            "cpu": self.variables.worker_cpu,
+                                            "memory": self.variables.worker_mem
+                                        },
+ "requests": {
+ "sgx_epc_mem": str(self.variables.sgx_mem),
+ "cpu": self.variables.worker_cpu,
+ "memory": self.variables.worker_mem
+ }
+ }
+
+ }
+ ],
+ "imagePullSecrets": [
+ {
+ "name": "regcred"
+ }
+ ],
+ "volumes": list(system.variables.volumes_list)
+ }
+ },
+ "replicas": self.variables.worker_replicas
+ }
+ }
+ }
+}"""
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/utils.py b/web_console_v2/api/fedlearner_webconsole/tee/utils.py
new file mode 100644
index 000000000..9421cca0c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/utils.py
@@ -0,0 +1,78 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from urllib.parse import urlparse
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.algorithm.models import Algorithm
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.exceptions import InvalidArgumentException, NotFoundException
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob
+from fedlearner_webconsole.proto.algorithm_pb2 import AlgorithmPb
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+
+
+def get_project(session: Session, project_id: int) -> Project:
+ project = session.query(Project).get(project_id)
+ if project is None:
+ raise InvalidArgumentException(f'project {project_id} is not found')
+ return project
+
+
+def get_dataset(session: Session, dataset_id: int) -> Dataset:
+ dataset = session.query(Dataset).get(dataset_id)
+ if dataset is None:
+ raise InvalidArgumentException(f'dataset {dataset_id} is not found')
+ return dataset
+
+
+def get_algorithm(session: Session, algorithm_id: int) -> Algorithm:
+ algorithm = session.query(Algorithm).get(algorithm_id)
+ if algorithm is None:
+ raise InvalidArgumentException(f'algorithm {algorithm_id} is not found')
+ return algorithm
+
+
+def get_participant(session: Session, participant_id: int) -> Participant:
+ participant = session.query(Participant).get(participant_id)
+ if participant is None:
+ raise InvalidArgumentException(f'participant {participant_id} is not found')
+ return participant
+
+
+def get_trusted_job_group(session: Session, project_id: int, group_id: int) -> TrustedJobGroup:
+ group = session.query(TrustedJobGroup).filter_by(id=group_id, project_id=project_id).first()
+ if group is None:
+ raise NotFoundException(f'trusted job group {group_id} is not found')
+ return group
+
+
+def get_trusted_job(session: Session, project_id: int, trusted_job_id: int) -> TrustedJob:
+ trusted_job = session.query(TrustedJob).filter_by(id=trusted_job_id, project_id=project_id).first()
+ if trusted_job is None:
+ raise NotFoundException(f'trusted job {trusted_job_id} is not found')
+ return trusted_job
+
+
+def get_algorithm_with_uuid(project_id: int, algorithm_uuid: str) -> AlgorithmPb:
+ try:
+ return AlgorithmFetcher(project_id).get_algorithm(algorithm_uuid)
+ except NotFoundException as e:
+ raise InvalidArgumentException(f'algorithm {algorithm_uuid} is not found') from e
+
+
+def get_pure_path(path: str) -> str:
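+    """Strip the scheme and host from a URI (e.g. file:// or hdfs://), returning only the path component."""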
+ return urlparse(path).path
diff --git a/web_console_v2/api/fedlearner_webconsole/tee/utils_test.py b/web_console_v2/api/fedlearner_webconsole/tee/utils_test.py
new file mode 100644
index 000000000..8e2246bd4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/tee/utils_test.py
@@ -0,0 +1,104 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+import grpc
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.rpc.client import FakeRpcError
+from fedlearner_webconsole.tee.utils import get_project, get_dataset, get_algorithm, get_participant, \
+ get_trusted_job_group, get_trusted_job, get_algorithm_with_uuid, get_pure_path
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.algorithm.models import Algorithm
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.exceptions import InvalidArgumentException, NotFoundException
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob
+
+
+class UtilsTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project')
+ dataset = Dataset(id=1, name='dataset')
+ algorithm = Algorithm(id=1, name='algorithm', project_id=1, uuid='uuid1')
+ participant = Participant(id=1, name='part', domain_name='domain')
+ group = TrustedJobGroup(id=1, name='trusted-group', project_id=1)
+ trusted_job = TrustedJob(id=1, name='V1', version=1, project_id=1, trusted_job_group_id=1)
+ session.add_all([project, dataset, algorithm, participant, group, trusted_job])
+ session.commit()
+
+ def test_get_project(self):
+ with db.session_scope() as session:
+ project = get_project(session, 1)
+ self.assertEqual(project.name, 'project')
+ with self.assertRaises(InvalidArgumentException):
+ get_project(session, 2)
+
+ def test_get_dataset(self):
+ with db.session_scope() as session:
+ dataset = get_dataset(session, 1)
+ self.assertEqual(dataset.name, 'dataset')
+ with self.assertRaises(InvalidArgumentException):
+ get_dataset(session, 2)
+
+ def test_get_algorithm(self):
+ with db.session_scope() as session:
+ algorithm = get_algorithm(session, 1)
+ self.assertEqual(algorithm.name, 'algorithm')
+ with self.assertRaises(InvalidArgumentException):
+ get_algorithm(session, 2)
+
+ def test_get_participant(self):
+ with db.session_scope() as session:
+ participant = get_participant(session, 1)
+ self.assertEqual(participant.name, 'part')
+ with self.assertRaises(InvalidArgumentException):
+ get_participant(session, 2)
+
+ def test_get_trusted_job_group(self):
+ with db.session_scope() as session:
+ group = get_trusted_job_group(session, 1, 1)
+ self.assertEqual(group.name, 'trusted-group')
+ with self.assertRaises(NotFoundException):
+ get_trusted_job_group(session, 1, 2)
+
+ def test_get_trusted_job(self):
+ with db.session_scope() as session:
+ trusted_job = get_trusted_job(session, 1, 1)
+ self.assertEqual(trusted_job.name, 'V1')
+ with self.assertRaises(NotFoundException):
+ get_trusted_job(session, 1, 2)
+
+ @patch('fedlearner_webconsole.algorithm.fetcher.AlgorithmFetcher.get_algorithm_from_participant')
+ def test_get_algorithm_with_uuid(self, mock_get_algorithm: MagicMock):
+ mock_get_algorithm.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'not found')
+ algorithm = get_algorithm_with_uuid(1, 'uuid1')
+ self.assertEqual(algorithm.name, 'algorithm')
+ with self.assertRaises(InvalidArgumentException):
+ get_algorithm_with_uuid(1, 'not-exist')
+
+ def test_get_pure_path(self):
+ self.assertEqual(get_pure_path('file:///data/test'), '/data/test')
+ self.assertEqual(get_pure_path('/data/test'), '/data/test')
+ self.assertEqual(get_pure_path('hdfs:///data/test'), '/data/test')
+ self.assertEqual(get_pure_path('hdfs://fl.net/data/test'), '/data/test')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/two_pc/BUILD.bazel
new file mode 100644
index 000000000..f8d1585b0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/BUILD.bazel
@@ -0,0 +1,519 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "dataset_job_launcher_lib",
+ srcs = ["dataset_job_launcher.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "dataset_job_launcher_lib_test",
+ size = "small",
+ srcs = [
+ "dataset_job_launcher_test.py",
+ ],
+ imports = ["../.."],
+ main = "dataset_job_launcher_test.py",
+ deps = [
+ ":dataset_job_launcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "dataset_job_stopper_lib",
+ srcs = ["dataset_job_stopper.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "dataset_job_stopper_lib_test",
+ size = "small",
+ srcs = [
+ "dataset_job_stopper_test.py",
+ ],
+ imports = ["../.."],
+ main = "dataset_job_stopper_test.py",
+ deps = [
+ ":dataset_job_stopper_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "handlers_lib",
+ srcs = ["handlers.py"],
+ imports = ["../.."],
+ deps = [
+ ":dataset_job_launcher_lib",
+ ":dataset_job_stage_launcher_lib",
+ ":dataset_job_stage_stopper_lib",
+ ":dataset_job_stopper_lib",
+ ":model_job_creator_lib",
+ ":model_job_group_creator_lib",
+ ":model_job_launcher_lib",
+ ":models_lib",
+ ":trusted_export_job_launcher_lib",
+ ":trusted_job_group_creator_lib",
+ ":trusted_job_launcher_lib",
+ ":trusted_job_stopper_lib",
+ ":workflow_state_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "handlers_lib_test",
+ size = "small",
+ srcs = [
+ "handlers_test.py",
+ ],
+ imports = ["../.."],
+ main = "handlers_test.py",
+ deps = [
+ ":handlers_lib",
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "model_job_creator_lib",
+ srcs = ["model_job_creator.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ ":trusted_job_group_creator_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "model_job_creator_lib_test",
+ size = "small",
+ srcs = [
+ "model_job_creator_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_job_creator_test.py",
+ deps = [
+ ":model_job_creator_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "model_job_group_creator_lib",
+ srcs = ["model_job_group_creator.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "model_job_group_creator_lib_test",
+ size = "small",
+ srcs = [
+ "model_job_group_creator_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_job_group_creator_test.py",
+ deps = [
+ ":model_job_group_creator_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "model_job_launcher_lib",
+ srcs = ["model_job_launcher.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:service_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "model_job_launcher_lib_test",
+ size = "small",
+ srcs = [
+ "model_job_launcher_test.py",
+ ],
+ imports = ["../.."],
+ main = "model_job_launcher_test.py",
+ deps = [
+ ":model_job_launcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/mmgr:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:mixins_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_library(
+ name = "resource_manager_lib",
+ srcs = ["resource_manager.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto"],
+)
+
+py_library(
+ name = "trusted_job_group_creator_lib",
+ srcs = ["trusted_job_group_creator.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:ticket_helper_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "trusted_job_group_creator_lib_test",
+ size = "small",
+ srcs = [
+ "trusted_job_group_creator_test.py",
+ ],
+ imports = ["../.."],
+ main = "trusted_job_group_creator_test.py",
+ deps = [
+ ":trusted_job_group_creator_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/review:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/rpc:client_lib",
+ "@com_github_grpc_grpc//src/python/grpcio/grpc:grpcio",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "workflow_state_controller_lib",
+ srcs = ["workflow_state_controller.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:workflow_controller_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "transaction_manager_lib",
+ srcs = [
+ "transaction_manager.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":handlers_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "transaction_manager_lib_test",
+ size = "small",
+ srcs = [
+ "transaction_manager_test.py",
+ ],
+ imports = ["../.."],
+ main = "transaction_manager_test.py",
+ deps = [
+ ":transaction_manager_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "trusted_job_launcher_lib",
+ srcs = [
+ "trusted_job_launcher.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:fetcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:domain_name_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "trusted_job_launcher_lib_test",
+ size = "small",
+ srcs = [
+ "trusted_job_launcher_test.py",
+ ],
+ imports = ["../.."],
+ main = "trusted_job_launcher_test.py",
+ deps = [
+ ":trusted_job_launcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/algorithm:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:data_path_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "trusted_job_stopper_lib",
+ srcs = [
+ "trusted_job_stopper.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:resource_manager_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "trusted_job_stopper_lib_test",
+ size = "small",
+ srcs = [
+ "trusted_job_stopper_test.py",
+ ],
+ imports = ["../.."],
+ main = "trusted_job_stopper_test.py",
+ deps = [
+ ":trusted_job_stopper_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "trusted_export_job_launcher_lib",
+ srcs = [
+ "trusted_export_job_launcher.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "trusted_export_job_launcher_lib_test",
+ size = "small",
+ srcs = [
+ "trusted_export_job_launcher_test.py",
+ ],
+ imports = ["../.."],
+ main = "trusted_export_job_launcher_test.py",
+ deps = [
+ ":trusted_export_job_launcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/tee:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/base_model:base_model_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "dataset_job_stage_launcher_lib",
+ srcs = ["dataset_job_stage_launcher.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:local_controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "dataset_job_stage_launcher_lib_test",
+ size = "small",
+ srcs = [
+ "dataset_job_stage_launcher_test.py",
+ ],
+ imports = ["../.."],
+ main = "dataset_job_stage_launcher_test.py",
+ deps = [
+ ":dataset_job_stage_launcher_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "dataset_job_stage_stopper_lib",
+ srcs = ["dataset_job_stage_stopper.py"],
+ imports = ["../.."],
+ deps = [
+ ":resource_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:local_controllers_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "dataset_job_stage_stopper_lib_test",
+ size = "small",
+ srcs = [
+ "dataset_job_stage_stopper_test.py",
+ ],
+ imports = ["../.."],
+ main = "dataset_job_stage_stopper_test.py",
+ deps = [
+ ":dataset_job_stage_stopper_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/__init__.py b/web_console_v2/api/fedlearner_webconsole/two_pc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_launcher.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_launcher.py
new file mode 100644
index 000000000..55db153e5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_launcher.py
@@ -0,0 +1,76 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.workflow.workflow_controller import start_workflow_locally
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobState
+from fedlearner_webconsole.dataset.services import DatasetJobService
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+
+
+class DatasetJobLauncher(ResourceManager):
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.launch_dataset_job_data is not None
+ self._data = data.launch_dataset_job_data
+ self._session = session
+
+ def prepare(self) -> Tuple[bool, str]:
+ dataset_job = self._session.query(DatasetJob).populate_existing().with_for_update(read=True).filter_by(
+ uuid=self._data.dataset_job_uuid).first()
+ if dataset_job is None:
+ message = f'failed to find dataset_job, uuid is {self._data.dataset_job_uuid}'
+ logging.warning(f'[dataset_job launch 2pc] prepare: {message}, uuid: {self._data.dataset_job_uuid}')
+ return False, message
+        if dataset_job.state not in [DatasetJobState.PENDING, DatasetJobState.RUNNING]:
+            message = f'dataset_job state check failed! current: {dataset_job.state.value}, ' \
+                      f'expected: {DatasetJobState.PENDING.value} or {DatasetJobState.RUNNING.value}, ' \
+                      f'uuid is {self._data.dataset_job_uuid}'
+ logging.warning(f'[dataset_job launch 2pc] prepare: {message}, uuid: {self._data.dataset_job_uuid}')
+ return False, message
+ if dataset_job.workflow is None:
+ message = f'failed to find workflow, uuid is {self._data.dataset_job_uuid}'
+ logging.warning(f'[dataset_job launch 2pc] prepare: {message}, uuid: {self._data.dataset_job_uuid}')
+ return False, message
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+ dataset_job = self._session.query(DatasetJob).populate_existing().with_for_update().filter_by(
+ uuid=self._data.dataset_job_uuid).first()
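+        # a RUNNING job means a previous commit already started the workflow, so this commit is idempotent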
+ if dataset_job.state == DatasetJobState.RUNNING:
+ return True, ''
+ if dataset_job.state != DatasetJobState.PENDING:
+ message = f'dataset_job state check failed! current: {dataset_job.state.value}, ' \
+                      f'expected: {DatasetJobState.PENDING.value} or {DatasetJobState.RUNNING.value}, ' \
+                      f'uuid is {self._data.dataset_job_uuid}'
+            logging.warning(f'[dataset_job launch 2pc] commit: {message}, uuid: {self._data.dataset_job_uuid}')
+ return False, message
+ try:
+ start_workflow_locally(self._session, dataset_job.workflow)
+ except RuntimeError as e:
+ logging.error(f'[dataset_job launch 2pc] commit: {e}, uuid: {self._data.dataset_job_uuid}')
+ raise
+ DatasetJobService(self._session).start_dataset_job(dataset_job)
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[dataset_job launch 2pc] abort')
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_launcher_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_launcher_test.py
new file mode 100644
index 000000000..b28b9367a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_launcher_test.py
@@ -0,0 +1,163 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock, ANY
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobKind, DatasetJobState
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.two_pc.dataset_job_launcher import DatasetJobLauncher
+from fedlearner_webconsole.proto.two_pc_pb2 import LaunchDatasetJobData, \
+ TransactionData
+
+
+class DatasetJobLauncherTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _DATASET_JOB_ID = 1
+ _WORKFLOW_ID = 1
+ _DATASET_JOB_UUID = 'test_uuid'
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=self._PROJECT_ID, name='test')
+ session.add(project)
+ workflow = Workflow(id=self._WORKFLOW_ID, uuid=self._DATASET_JOB_UUID)
+ session.add(workflow)
+ session.commit()
+ launch_dataset_job_data = LaunchDatasetJobData(dataset_job_uuid=self._DATASET_JOB_UUID)
+ self.data = TransactionData(launch_dataset_job_data=launch_dataset_job_data)
+
+ def test_prepare_no_dataset_job(self):
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_illegal_state(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.FAILED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_no_related_workflow(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=0,
+ state=DatasetJobState.PENDING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_successfully(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.PENDING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.RUNNING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ @patch('fedlearner_webconsole.two_pc.dataset_job_launcher.start_workflow_locally')
+ def test_commit(self, mock_start_workflow_locally: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.RUNNING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ session.commit()
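+        # an already RUNNING dataset_job: commit succeeds but must not start the workflow again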
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ session.commit()
+ mock_start_workflow_locally.assert_not_called()
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.state, DatasetJobState.RUNNING)
+ self.assertIsNone(dataset_job.started_at)
+
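+        # a PENDING dataset_job: commit starts the workflow locally and marks the job RUNNING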
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.PENDING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ workflow = session.query(Workflow).get(self._WORKFLOW_ID)
+ mock_start_workflow_locally.assert_called_once_with(ANY, workflow)
+ session.commit()
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.state, DatasetJobState.RUNNING)
+ self.assertIsNotNone(dataset_job.started_at)
+
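+        # a SUCCEEDED dataset_job can no longer be launched, so commit fails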
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.SUCCEEDED
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertFalse(flag)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_launcher.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_launcher.py
new file mode 100644
index 000000000..feb4e092d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_launcher.py
@@ -0,0 +1,76 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.dataset.models import DatasetJobStage, DatasetJobState
+from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+
+
+class DatasetJobStageLauncher(ResourceManager):
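+    """Launches a dataset job stage via 2PC.
+
+    prepare() validates that the stage exists, is PENDING or RUNNING and has a
+    workflow; commit() starts it through DatasetJobStageLocalController;
+    abort() is a no-op.
+    """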
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.launch_dataset_job_stage_data is not None
+ self._data = data.launch_dataset_job_stage_data
+ self._session = session
+
+ def prepare(self) -> Tuple[bool, str]:
+ dataset_job_stage: DatasetJobStage = self._session.query(DatasetJobStage).filter_by(
+ uuid=self._data.dataset_job_stage_uuid).first()
+ if dataset_job_stage is None:
+ message = 'failed to find dataset_job_stage'
+ logging.warning(
+ f'[dataset_job_stage launch 2pc] prepare: {message}, uuid: {self._data.dataset_job_stage_uuid}')
+ return False, message
+        if dataset_job_stage.state not in [DatasetJobState.PENDING, DatasetJobState.RUNNING]:
+ message = 'dataset_job_stage state check failed! invalid state!'
+ logging.warning(
+ f'[dataset_job_stage launch 2pc] prepare: {message}, uuid: {self._data.dataset_job_stage_uuid}')
+ return False, message
+ if dataset_job_stage.workflow is None:
+ message = 'failed to find workflow'
+ logging.warning(
+ f'[dataset_job_stage launch 2pc] prepare: {message}, uuid: {self._data.dataset_job_stage_uuid}')
+ return False, message
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+        # Take an exclusive (X) lock: it waits for any other lock on the row until that lock is released
+        # or the wait times out. We don't use a shared (S) lock because it may lead to deadlocks.
+ dataset_job_stage: DatasetJobStage = self._session.query(DatasetJobStage).populate_existing().with_for_update(
+ ).filter_by(uuid=self._data.dataset_job_stage_uuid).first()
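+        # a stage that is already RUNNING has been started before, so return success to keep commit idempotent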
+ if dataset_job_stage.state == DatasetJobState.RUNNING:
+ return True, ''
+ if dataset_job_stage.state != DatasetJobState.PENDING:
+ message = 'dataset_job_stage state check failed! invalid state!'
+ logging.warning(
+                f'[dataset_job_stage launch 2pc] commit: {message}, uuid: {self._data.dataset_job_stage_uuid}')
+ return False, message
+ try:
+ DatasetJobStageLocalController(session=self._session).start(dataset_job_stage)
+ except RuntimeError as e:
+ logging.error(f'[dataset_job_stage launch 2pc] commit: {e}, uuid: {self._data.dataset_job_stage_uuid}')
+ raise
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[dataset_job_stage launch 2pc] abort')
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_launcher_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_launcher_test.py
new file mode 100644
index 000000000..a23e3e97b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_launcher_test.py
@@ -0,0 +1,202 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobKind, DatasetJobStage, DatasetJobState
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.two_pc.dataset_job_stage_launcher import DatasetJobStageLauncher
+from fedlearner_webconsole.proto.two_pc_pb2 import LaunchDatasetJobStageData, \
+ TransactionData
+
+
+class DatasetJobStageLauncherTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _DATASET_JOB_ID = 1
+ _DATA_BATCH_ID = 1
+ _DATASET_JOB_STAGE_ID = 1
+ _WORKFLOW_ID = 1
+ _DATASET_JOB_UUID = 'dataset_job uuid'
+ _DATASET_JOB_STAGE_UUID = 'dataset_job_stage uuid'
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=self._PROJECT_ID, name='test')
+ session.add(project)
+ workflow = Workflow(id=self._WORKFLOW_ID, uuid=self._DATASET_JOB_UUID)
+ session.add(workflow)
+ session.commit()
+ launch_dataset_job_stage_data = LaunchDatasetJobStageData(dataset_job_stage_uuid=self._DATASET_JOB_STAGE_UUID)
+ self.data = TransactionData(launch_dataset_job_stage_data=launch_dataset_job_stage_data)
+
+ def test_prepare_no_dataset_job_stage(self):
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_illegal_state(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=0,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.FAILED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ dataset_job_stage = DatasetJobStage(id=self._DATASET_JOB_STAGE_ID,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=self._DATASET_JOB_ID,
+ uuid=self._DATASET_JOB_STAGE_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ data_batch_id=self._DATA_BATCH_ID,
+ state=DatasetJobState.FAILED)
+ session.add(dataset_job_stage)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_no_related_workflow(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=0,
+ state=DatasetJobState.PENDING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ dataset_job_stage = DatasetJobStage(id=self._DATASET_JOB_STAGE_ID,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=self._DATASET_JOB_ID,
+ uuid=self._DATASET_JOB_STAGE_UUID,
+ workflow_id=100,
+ data_batch_id=self._DATA_BATCH_ID,
+ state=DatasetJobState.PENDING)
+ session.add(dataset_job_stage)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_successfully(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.PENDING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ dataset_job_stage = DatasetJobStage(id=self._DATASET_JOB_STAGE_ID,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=self._DATASET_JOB_ID,
+ uuid=self._DATASET_JOB_STAGE_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ data_batch_id=self._DATA_BATCH_ID,
+ state=DatasetJobState.PENDING)
+ session.add(dataset_job_stage)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.RUNNING
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ dataset_job_stage.state = DatasetJobState.RUNNING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ @patch('fedlearner_webconsole.two_pc.dataset_job_stage_launcher.DatasetJobStageLocalController.start')
+ def test_commit(self, mock_start: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.RUNNING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ dataset_job_stage = DatasetJobStage(id=self._DATASET_JOB_STAGE_ID,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=self._DATASET_JOB_ID,
+ uuid=self._DATASET_JOB_STAGE_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ data_batch_id=self._DATA_BATCH_ID,
+ state=DatasetJobState.RUNNING)
+ session.add(dataset_job_stage)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ session.commit()
+ mock_start.assert_not_called()
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.state, DatasetJobState.RUNNING)
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.RUNNING)
+ self.assertIsNone(dataset_job_stage.started_at)
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.PENDING
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ dataset_job_stage.state = DatasetJobState.PENDING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ mock_start.assert_called_once_with(dataset_job_stage)
+ session.commit()
+
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.SUCCEEDED
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ dataset_job_stage.state = DatasetJobState.SUCCEEDED
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageLauncher(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertFalse(flag)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_stopper.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_stopper.py
new file mode 100644
index 000000000..45aa7f516
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_stopper.py
@@ -0,0 +1,68 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.dataset.models import DatasetJobStage, DatasetJobState
+from fedlearner_webconsole.dataset.local_controllers import DatasetJobStageLocalController
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+
+
+class DatasetJobStageStopper(ResourceManager):
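+    """Stops a dataset job stage via 2PC.
+
+    prepare() checks that the stage exists and is not already SUCCEEDED/FAILED;
+    commit() stops it through DatasetJobStageLocalController; abort() is a no-op.
+    """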
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.stop_dataset_job_stage_data is not None
+ self._data = data.stop_dataset_job_stage_data
+ self._session = session
+
+ def prepare(self) -> Tuple[bool, str]:
+ dataset_job_stage: DatasetJobStage = self._session.query(DatasetJobStage).filter_by(
+ uuid=self._data.dataset_job_stage_uuid).first()
+ if dataset_job_stage is None:
+ message = 'dataset_job_stage not found'
+ logging.warning(
+ f'[dataset_job_stage stop 2pc] prepare: {message}, uuid: {self._data.dataset_job_stage_uuid}')
+ return False, message
+ if dataset_job_stage.state in [DatasetJobState.SUCCEEDED, DatasetJobState.FAILED]:
+ message = 'dataset_job_stage state check failed! ' \
+                      f'current state {dataset_job_stage.state.value} cannot be stopped'
+ logging.warning(
+ f'[dataset_job_stage stop 2pc] prepare: {message}, uuid: {self._data.dataset_job_stage_uuid}')
+ return False, message
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+        # Take an exclusive (X) lock: it waits for any other lock on the row until that lock is released
+        # or the wait times out. We don't use a shared (S) lock because it may lead to deadlocks.
+ dataset_job_stage: DatasetJobStage = self._session.query(DatasetJobStage).populate_existing().with_for_update(
+ ).filter_by(uuid=self._data.dataset_job_stage_uuid).first()
+        # allow the STOPPED -> STOPPED transition so repeated stop requests stay idempotent
+ if dataset_job_stage.state == DatasetJobState.STOPPED:
+ return True, ''
+ try:
+ DatasetJobStageLocalController(session=self._session).stop(dataset_job_stage)
+ except RuntimeError as e:
+ logging.error(f'[dataset_job_stage stop 2pc] commit: {e}, uuid: {self._data.dataset_job_stage_uuid}')
+ raise
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[dataset_job_stage stop 2pc] abort')
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_stopper_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_stopper_test.py
new file mode 100644
index 000000000..9b8cb8f92
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stage_stopper_test.py
@@ -0,0 +1,156 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobKind, DatasetJobState, DatasetJobStage
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.two_pc.dataset_job_stage_stopper import DatasetJobStageStopper
+from fedlearner_webconsole.proto.two_pc_pb2 import StopDatasetJobStageData, \
+ TransactionData
+
+
+class DatasetJobStageStopperTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _DATASET_JOB_ID = 1
+ _DATA_BATCH_ID = 1
+ _DATASET_JOB_STAGE_ID = 1
+ _WORKFLOW_ID = 1
+ _DATASET_JOB_UUID = 'dataset_job uuid'
+ _DATASET_JOB_STAGE_UUID = 'dataset_job_stage uuid'
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=self._PROJECT_ID, name='test')
+ session.add(project)
+ workflow = Workflow(id=self._WORKFLOW_ID, uuid=self._DATASET_JOB_UUID)
+ session.add(workflow)
+ session.commit()
+ stop_dataset_job_stage_data = StopDatasetJobStageData(dataset_job_stage_uuid=self._DATASET_JOB_STAGE_UUID)
+ self.data = TransactionData(stop_dataset_job_stage_data=stop_dataset_job_stage_data)
+
+ def test_prepare_no_dataset_job_stage(self):
+ with db.session_scope() as session:
+ creator = DatasetJobStageStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_state(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.FAILED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ dataset_job_stage = DatasetJobStage(id=self._DATASET_JOB_STAGE_ID,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=self._DATASET_JOB_ID,
+ uuid=self._DATASET_JOB_STAGE_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ data_batch_id=self._DATA_BATCH_ID,
+ state=DatasetJobState.FAILED)
+ session.add(dataset_job_stage)
+ session.commit()
+ # test prepare state failed
+ with db.session_scope() as session:
+ creator = DatasetJobStageStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ # test prepare state succeeded
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.PENDING
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ dataset_job_stage.state = DatasetJobState.PENDING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ # test prepare state stop to stop
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.STOPPED
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ dataset_job_stage.state = DatasetJobState.STOPPED
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ @patch('fedlearner_webconsole.two_pc.dataset_job_stage_stopper.DatasetJobStageLocalController.stop')
+ def test_commit_state(self, mock_stop: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.STOPPED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ dataset_job_stage = DatasetJobStage(id=self._DATASET_JOB_STAGE_ID,
+ project_id=self._PROJECT_ID,
+ dataset_job_id=self._DATASET_JOB_ID,
+ uuid=self._DATASET_JOB_STAGE_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ data_batch_id=self._DATA_BATCH_ID,
+ state=DatasetJobState.STOPPED)
+ session.add(dataset_job_stage)
+ session.commit()
+
+ # test commit state stop to stop
+ with db.session_scope() as session:
+ creator = DatasetJobStageStopper(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ mock_stop.assert_not_called()
+ session.flush()
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.state, DatasetJobState.STOPPED)
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ self.assertEqual(dataset_job_stage.state, DatasetJobState.STOPPED)
+ self.assertIsNone(dataset_job_stage.finished_at)
+
+ # test commit state succeeded
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.PENDING
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ dataset_job_stage.state = DatasetJobState.PENDING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStageStopper(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ dataset_job_stage = session.query(DatasetJobStage).get(self._DATASET_JOB_STAGE_ID)
+ mock_stop.assert_called_once_with(dataset_job_stage)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stopper.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stopper.py
new file mode 100644
index 000000000..3c014e14f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stopper.py
@@ -0,0 +1,72 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.workflow.workflow_controller import stop_workflow_locally
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobState
+from fedlearner_webconsole.dataset.services import DatasetJobService
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+
+
+class DatasetJobStopper(ResourceManager):
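+    """Stops a dataset job via 2PC.
+
+    prepare() checks that the job exists and is not already SUCCEEDED/FAILED;
+    commit() stops the related workflow locally (if any) and marks the job
+    STOPPED; abort() is a no-op.
+    """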
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.stop_dataset_job_data is not None
+ self._data = data.stop_dataset_job_data
+ self._session = session
+
+ def prepare(self) -> Tuple[bool, str]:
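+        # with_for_update(read=True) takes a shared lock on the dataset_job row while its state is checked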
+ dataset_job = self._session.query(DatasetJob).populate_existing().with_for_update(read=True).filter_by(
+ uuid=self._data.dataset_job_uuid).first()
+ if dataset_job is None:
+ message = 'dataset_job not found'
+ logging.warning(f'[dataset_job stop 2pc] prepare: {message}, uuid: {self._data.dataset_job_uuid}')
+ return False, message
+ if dataset_job.state in [DatasetJobState.SUCCEEDED, DatasetJobState.FAILED]:
+            message = f'dataset_job state check failed! current state {dataset_job.state.value} cannot be stopped, ' \
+ f'expected: {DatasetJobState.PENDING.value}, {DatasetJobState.RUNNING.value} or ' \
+ f'{DatasetJobState.STOPPED.value}, uuid is {self._data.dataset_job_uuid}'
+ logging.warning(f'[dataset_job stop 2pc] prepare: {message}, uuid: {self._data.dataset_job_uuid}')
+ return False, message
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+ dataset_job = self._session.query(DatasetJob).populate_existing().with_for_update().filter_by(
+ uuid=self._data.dataset_job_uuid).first()
+        # allow the STOPPED -> STOPPED transition so repeated stop requests stay idempotent
+ if dataset_job.state == DatasetJobState.STOPPED:
+ return True, ''
+ try:
+ if dataset_job.workflow is not None:
+ stop_workflow_locally(self._session, dataset_job.workflow)
+ else:
+                logging.info(f'[dataset_job stop 2pc] commit: workflow not found, skip stopping it, '
+                             f'uuid: {self._data.dataset_job_uuid}')
+ except RuntimeError as e:
+ logging.error(f'[dataset_job stop 2pc] commit: {e}, uuid: {self._data.dataset_job_uuid}')
+ raise
+ DatasetJobService(self._session).finish_dataset_job(dataset_job=dataset_job,
+ finish_state=DatasetJobState.STOPPED)
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[dataset_job stop 2pc] abort')
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stopper_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stopper_test.py
new file mode 100644
index 000000000..5c385552f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/dataset_job_stopper_test.py
@@ -0,0 +1,152 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock, ANY
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.dataset.models import DatasetJob, DatasetJobKind, DatasetJobState
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.two_pc.dataset_job_stopper import DatasetJobStopper
+from fedlearner_webconsole.proto.two_pc_pb2 import StopDatasetJobData, \
+ TransactionData
+
+
+class DatasetJobStopperTest(NoWebServerTestCase):
+ _PROJECT_ID = 1
+ _DATASET_JOB_ID = 1
+ _WORKFLOW_ID = 1
+ _DATASET_JOB_UUID = 'test_uuid'
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=self._PROJECT_ID, name='test')
+ session.add(project)
+ workflow = Workflow(id=self._WORKFLOW_ID, uuid=self._DATASET_JOB_UUID)
+ session.add(workflow)
+ session.commit()
+ stop_dataset_job_data = StopDatasetJobData(dataset_job_uuid=self._DATASET_JOB_UUID)
+ self.data = TransactionData(stop_dataset_job_data=stop_dataset_job_data)
+
+ def test_prepare_no_dataset_job(self):
+ with db.session_scope() as session:
+ creator = DatasetJobStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ def test_prepare_state(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.FAILED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ session.commit()
+ # test prepare state failed
+ with db.session_scope() as session:
+ creator = DatasetJobStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertFalse(flag)
+
+ # test prepare state succeeded
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.PENDING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ # test prepare state stop to stop
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.STOPPED
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStopper(session, tid='1', data=self.data)
+ flag, _ = creator.prepare()
+ self.assertTrue(flag)
+
+ def test_commit_no_workflow(self):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ state=DatasetJobState.RUNNING,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStopper(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ session.flush()
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.state, DatasetJobState.STOPPED)
+
+ @patch('fedlearner_webconsole.two_pc.dataset_job_stopper.stop_workflow_locally')
+ def test_commit_state(self, mock_stop_workflow_locally: MagicMock):
+ with db.session_scope() as session:
+ dataset_job = DatasetJob(id=self._DATASET_JOB_ID,
+ project_id=self._PROJECT_ID,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ uuid=self._DATASET_JOB_UUID,
+ workflow_id=self._WORKFLOW_ID,
+ state=DatasetJobState.STOPPED,
+ kind=DatasetJobKind.DATA_ALIGNMENT)
+ session.add(dataset_job)
+ session.commit()
+
+ # test commit state stop to stop
+ with db.session_scope() as session:
+ creator = DatasetJobStopper(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ mock_stop_workflow_locally.assert_not_called()
+ session.flush()
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.state, DatasetJobState.STOPPED)
+ self.assertIsNone(dataset_job.finished_at)
+
+ # test commit state succeeded
+ with db.session_scope() as session:
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ dataset_job.state = DatasetJobState.PENDING
+ session.commit()
+ with db.session_scope() as session:
+ creator = DatasetJobStopper(session, tid='1', data=self.data)
+ flag, _ = creator.commit()
+ self.assertTrue(flag)
+ workflow = session.query(Workflow).get(self._WORKFLOW_ID)
+ mock_stop_workflow_locally.assert_called_once_with(ANY, workflow)
+ session.flush()
+ dataset_job = session.query(DatasetJob).get(self._DATASET_JOB_ID)
+ self.assertEqual(dataset_job.state, DatasetJobState.STOPPED)
+ self.assertIsNotNone(dataset_job.finished_at)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/handlers.py b/web_console_v2/api/fedlearner_webconsole/two_pc/handlers.py
new file mode 100644
index 000000000..b218341e7
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/handlers.py
@@ -0,0 +1,85 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TwoPcAction, TransactionData
+from fedlearner_webconsole.two_pc.dataset_job_stage_launcher import DatasetJobStageLauncher
+from fedlearner_webconsole.two_pc.dataset_job_stage_stopper import DatasetJobStageStopper
+from fedlearner_webconsole.two_pc.model_job_creator import ModelJobCreator
+from fedlearner_webconsole.two_pc.trusted_export_job_launcher import TrustedExportJobLauncher
+from fedlearner_webconsole.two_pc.workflow_state_controller import WorkflowStateController
+from fedlearner_webconsole.two_pc.model_job_group_creator import ModelJobGroupCreator
+from fedlearner_webconsole.two_pc.model_job_launcher import ModelJobLauncher
+from fedlearner_webconsole.two_pc.dataset_job_launcher import DatasetJobLauncher
+from fedlearner_webconsole.two_pc.dataset_job_stopper import DatasetJobStopper
+from fedlearner_webconsole.two_pc.trusted_job_group_creator import TrustedJobGroupCreator
+from fedlearner_webconsole.two_pc.trusted_job_launcher import TrustedJobLauncher
+from fedlearner_webconsole.two_pc.trusted_job_stopper import TrustedJobStopper
+from fedlearner_webconsole.two_pc.models import Transaction, TransactionState
+
+
+def run_two_pc_action(session: Session, tid: str, two_pc_type: TwoPcType, action: TwoPcAction,
+ data: TransactionData) -> Tuple[bool, str]:
+    # Idempotency check: reuse the existing transaction record if any, and return the
+    # stored result when this action has already been executed.
+ trans = session.query(Transaction).filter_by(uuid=tid).first()
+ if trans is None:
+ trans = Transaction(
+ uuid=tid,
+ state=TransactionState.NEW,
+ )
+ trans.set_type(two_pc_type)
+ session.add(trans)
+ executed, result, message = trans.check_idempotent(action)
+ if executed:
+ return result, message
+
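+    # Dispatch to the resource manager that implements prepare/commit/abort for this transaction type.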
+ rm = None
+ if two_pc_type == TwoPcType.CREATE_MODEL_JOB:
+ rm = ModelJobCreator(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.CONTROL_WORKFLOW_STATE:
+ rm = WorkflowStateController(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.CREATE_MODEL_JOB_GROUP:
+ rm = ModelJobGroupCreator(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.LAUNCH_MODEL_JOB:
+ rm = ModelJobLauncher(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.LAUNCH_DATASET_JOB:
+ rm = DatasetJobLauncher(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.STOP_DATASET_JOB:
+ rm = DatasetJobStopper(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.CREATE_TRUSTED_JOB_GROUP:
+ rm = TrustedJobGroupCreator(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.LAUNCH_TRUSTED_JOB:
+ rm = TrustedJobLauncher(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.STOP_TRUSTED_JOB:
+ rm = TrustedJobStopper(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.LAUNCH_TRUSTED_EXPORT_JOB:
+ rm = TrustedExportJobLauncher(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.LAUNCH_DATASET_JOB_STAGE:
+ rm = DatasetJobStageLauncher(session=session, tid=tid, data=data)
+ elif two_pc_type == TwoPcType.STOP_DATASET_JOB_STAGE:
+ rm = DatasetJobStageStopper(session=session, tid=tid, data=data)
+ if rm is None:
+ raise NotImplementedError()
+
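+    # Run the action only if it is a valid transition for this transaction; the outcome
+    # (or the exception message) is recorded on the Transaction either way.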
+ succeeded = False
+ try:
+ if trans.is_valid_action(action):
+ succeeded, message = rm.run_two_pc(action)
+ except Exception as e: # pylint: disable=broad-except
+ message = str(e)
+ return trans.update(action, succeeded, message)
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/handlers_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/handlers_test.py
new file mode 100644
index 000000000..3734dabd6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/handlers_test.py
@@ -0,0 +1,99 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest import mock
+from unittest.mock import patch, MagicMock
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TwoPcAction, TransactionData
+from fedlearner_webconsole.two_pc.handlers import run_two_pc_action
+from fedlearner_webconsole.two_pc.models import Transaction, TransactionState
+
+
+class HandlersTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.two_pc.handlers.ModelJobCreator')
+ def test_run_two_pc_action_new_transaction(self, mock_model_job_creator_class):
+ mock_model_job_creator = MagicMock()
+ mock_model_job_creator.run_two_pc = MagicMock(return_value=(True, 'aloha'))
+ mock_model_job_creator_class.return_value = mock_model_job_creator
+
+ tid = '123'
+ tdata = TransactionData()
+ with db.session_scope() as session:
+ succeeded, message = run_two_pc_action(session=session,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=tdata)
+ session.commit()
+ self.assertTrue(succeeded)
+ self.assertEqual(message, 'aloha')
+ mock_model_job_creator_class.assert_called_once_with(tid=tid, data=tdata, session=mock.ANY)
+ mock_model_job_creator.run_two_pc.assert_called_once_with(TwoPcAction.PREPARE)
+ with db.session_scope() as session:
+ trans: Transaction = session.query(Transaction).filter_by(uuid=tid).first()
+ self.assertEqual(trans.get_type(), TwoPcType.CREATE_MODEL_JOB)
+ self.assertEqual(trans.state, TransactionState.PREPARE_SUCCEEDED)
+ self.assertEqual(trans.message, 'aloha')
+
+ @patch('fedlearner_webconsole.two_pc.handlers.ModelJobCreator')
+ def test_run_two_pc_action_redundant_action_idempotent(self, mock_model_job_creator_class):
+ tid = '234234'
+ with db.session_scope() as session:
+ trans = Transaction(uuid=tid, state=TransactionState.PREPARE_SUCCEEDED, message='prepared')
+ session.add(trans)
+ session.commit()
+ with db.session_scope() as session:
+ succeeded, message = run_two_pc_action(session=session,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=TransactionData())
+ self.assertTrue(succeeded)
+ self.assertEqual(message, 'prepared')
+ mock_model_job_creator_class.assert_not_called()
+
+ @patch('fedlearner_webconsole.two_pc.handlers.ModelJobCreator')
+ def test_run_two_pc_action_exception(self, mock_model_job_creator_class):
+ mock_model_job_creator = MagicMock()
+ mock_model_job_creator.run_two_pc = MagicMock(side_effect=RuntimeError('Unknown error'))
+ mock_model_job_creator_class.return_value = mock_model_job_creator
+
+ tid = '123234234'
+ tdata = TransactionData()
+ with db.session_scope() as session:
+ trans = Transaction(uuid=tid, state=TransactionState.PREPARE_SUCCEEDED, message='prepared')
+ session.add(trans)
+ session.commit()
+ succeeded, message = run_two_pc_action(session=session,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.COMMIT,
+ data=tdata)
+ session.commit()
+ self.assertFalse(succeeded)
+ self.assertIn('Unknown error', message)
+ mock_model_job_creator.run_two_pc.assert_called_once_with(TwoPcAction.COMMIT)
+ with db.session_scope() as session:
+ trans: Transaction = session.query(Transaction).filter_by(uuid=tid).first()
+ self.assertEqual(trans.state, TransactionState.INVALID)
+ self.assertIn('Unknown error', trans.message)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_creator.py b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_creator.py
new file mode 100644
index 000000000..ce2618945
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_creator.py
@@ -0,0 +1,143 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+from typing import Tuple, Optional
+
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.dataset.models import Dataset, ResourceState
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobType, ModelJobGroup
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+
+
+class ModelJobCreator(ResourceManager):
+ """Create model job without configuration"""
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.create_model_job_data is not None
+ self._data = data.create_model_job_data
+ self._session = session
+
+ def _check_model_job(self) -> Tuple[bool, str]:
+ model_job_name = self._data.model_job_name
+ model_job = self._session.query(ModelJob).filter_by(name=model_job_name).first()
+ if model_job:
+            message = f'model job {model_job_name} already exists'
+ logging.info('[model-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_model_job_group(self) -> Tuple[bool, str]:
+ model_job_group_name = self._data.group_name
+        # eval/predict model jobs may not belong to a model job group
+ if model_job_group_name:
+ model_job_group = self._session.query(ModelJobGroup).filter_by(name=model_job_group_name).first()
+ if model_job_group is None:
+                message = f'model group {model_job_group_name} does not exist'
+ logging.info('[model-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_model(self) -> Tuple[bool, str]:
+ model_uuid = self._data.model_uuid
+        # training model jobs have no associated model
+ if model_uuid:
+ model = self._session.query(Model).filter_by(uuid=model_uuid).first()
+ if model is None:
+ message = f'model {self._data.model_uuid} not found'
+ logging.info('[model-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_dataset(self) -> Tuple[bool, str]:
+ if self._data.dataset_uuid:
+ dataset: Dataset = self._session.query(Dataset).filter_by(uuid=self._data.dataset_uuid).first()
+ if not dataset:
+                message = f'dataset {self._data.dataset_uuid} does not exist'
+ logging.info('[model-job-2pc] prepare failed: %s', message)
+ return False, message
+ if dataset.get_frontend_state() != ResourceState.SUCCEEDED:
+ message = f'dataset {self._data.dataset_uuid} is not succeeded'
+                logging.info('[model-job-2pc] prepare failed: %s', message)
+ return False, message
+ if not dataset.is_published:
+ message = f'dataset {self._data.dataset_uuid} is not published'
+ logging.info('[model-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def prepare(self) -> Tuple[bool, str]:
+ check_fn_list = [self._check_model_job, self._check_model_job_group, self._check_model, self._check_dataset]
+ for check_fn in check_fn_list:
+ succeeded, message = check_fn()
+ if not succeeded:
+ return False, message
+ logging.info('[model-job-2pc] prepare succeeded')
+ return True, ''
+
+ def _get_model_job_group_id(self) -> Optional[int]:
+ if self._data.group_name:
+ model_job_group = self._session.query(ModelJobGroup).filter_by(name=self._data.group_name).first()
+ return model_job_group.id
+ return None
+
+ def _get_model_id(self) -> Optional[int]:
+ if self._data.model_uuid:
+ model = self._session.query(Model).filter_by(uuid=self._data.model_uuid).first()
+ return model.id
+ return None
+
+ def _get_project_id(self) -> int:
+ project = self._session.query(Project).filter_by(name=self._data.project_name).first()
+ return project.id
+
+ def _get_dataset_id(self) -> Optional[int]:
+ if self._data.dataset_uuid:
+ dataset = self._session.query(Dataset).filter_by(uuid=self._data.dataset_uuid).first()
+ return dataset.id
+ return None
+
+ def commit(self) -> Tuple[bool, str]:
+ model_job_group_id = self._get_model_job_group_id()
+ model_id = self._get_model_id()
+ project_id = self._get_project_id()
+ dataset_id = self._get_dataset_id()
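+        # coordinator_id stays None when no participant matches the coordinator's pure domain
+        # name, e.g. when this platform itself is the coordinator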
+ coordinator = ParticipantService(
+ self._session).get_participant_by_pure_domain_name(pure_domain_name=self._data.coordinator_pure_domain_name)
+ coordinator_id = None
+ if coordinator is not None:
+ coordinator_id = coordinator.id
+ model_job = ModelJob(name=self._data.model_job_name,
+ model_job_type=ModelJobType[self._data.model_job_type],
+ project_id=project_id,
+ uuid=self._data.model_job_uuid,
+ group_id=model_job_group_id,
+ dataset_id=dataset_id,
+ model_id=model_id,
+ workflow_uuid=self._data.workflow_uuid,
+ algorithm_type=self._data.algorithm_type,
+ coordinator_id=coordinator_id)
+ self._session.add(model_job)
+ logging.info('[model-job-2pc] commit succeeded')
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[model-job-2pc] abort')
+        # As no resources were reserved in prepare, there is nothing to roll back
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_creator_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_creator_test.py
new file mode 100644
index 000000000..5d4b6ae77
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_creator_test.py
@@ -0,0 +1,124 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import Mock, patch
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, ResourceState
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.two_pc.model_job_creator import ModelJobCreator
+from fedlearner_webconsole.mmgr.models import Model, ModelJob, ModelJobType, ModelJobGroup
+from fedlearner_webconsole.proto.two_pc_pb2 import CreateModelJobData, \
+ TransactionData
+
+
+class ModelJobCreatorTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(name='project')
+ model = Model(name='model', uuid='model-uuid')
+ model_job_group = ModelJobGroup(name='group')
+ dataset = Dataset(name='dataset', uuid='dataset-uuid', is_published=True)
+ session.add_all([project, model, model_job_group, dataset])
+ session.commit()
+ create_model_job_data = CreateModelJobData(model_job_name='model-job',
+ model_job_type=ModelJobType.EVALUATION.name,
+ model_job_uuid='model-job-uuid',
+ workflow_uuid='workflow-uuid',
+ group_name=model_job_group.name,
+ algorithm_type=AlgorithmType.NN_VERTICAL.name,
+ model_uuid=model.uuid,
+ project_name=project.name,
+ dataset_uuid='dataset-uuid')
+ self.data = TransactionData(create_model_job_data=create_model_job_data)
+
+ @patch('fedlearner_webconsole.dataset.models.Dataset.get_frontend_state')
+ def test_prepare(self, mock_get_frontend_state: Mock):
+ mock_get_frontend_state.return_value = ResourceState.SUCCEEDED
+ with db.session_scope() as session:
+ creator = ModelJobCreator(session, tid='12', data=self.data)
+ flag, message = creator.prepare()
+ self.assertTrue(flag)
+ # fail due to model not found
+ self.data.create_model_job_data.model_uuid = 'uuid'
+ flag, message = creator.prepare()
+ self.assertFalse(flag)
+ self.assertEqual(message, 'model uuid not found')
+ with db.session_scope() as session:
+ self.data.create_model_job_data.model_uuid = 'model-uuid'
+ model_job = ModelJob(name='model-job')
+ session.add(model_job)
+ session.commit()
+ with db.session_scope() as session:
+ # fail due to model job with the same name
+ flag, message = ModelJobCreator(session, tid='12', data=self.data).prepare()
+ self.assertFalse(flag)
+            self.assertEqual(message, 'model job model-job already exists')
+ with db.session_scope() as session:
+ self.data.create_model_job_data.group_name = 'group-1'
+ self.data.create_model_job_data.model_job_name = 'model-job-1'
+ flag, message = ModelJobCreator(session, tid='12', data=self.data).prepare()
+ self.assertFalse(flag)
+            self.assertEqual(message, 'model group group-1 does not exist')
+ with db.session_scope() as session:
+ self.data.create_model_job_data.group_name = 'group'
+ # fail due to dataset is not found
+ self.data.create_model_job_data.dataset_uuid = 'dataset-uuid-1'
+ flag, message = ModelJobCreator(session, tid='12', data=self.data).prepare()
+ self.assertFalse(flag)
+            self.assertEqual(message, 'dataset dataset-uuid-1 does not exist')
+ with db.session_scope() as session:
+ dataset = Dataset(name='dataset-test-failed', uuid='dataset-uuid-1', is_published=False)
+ session.add(dataset)
+ session.commit()
+ # fail due to dataset is not published
+ flag, message = ModelJobCreator(session, tid='12', data=self.data).prepare()
+ self.assertFalse(flag)
+ self.assertEqual(message, 'dataset dataset-uuid-1 is not published')
+ with db.session_scope() as session:
+ mock_get_frontend_state.return_value = ResourceState.FAILED
+ # fail due to dataset is not succeeded
+ flag, message = ModelJobCreator(session, tid='12', data=self.data).prepare()
+ self.assertFalse(flag)
+ self.assertEqual(message, 'dataset dataset-uuid-1 is not succeeded')
+
+ def test_commit(self):
+ with db.session_scope() as session:
+ creator = ModelJobCreator(session, tid='12', data=self.data)
+ creator.commit()
+ session.commit()
+ with db.session_scope() as session:
+ model = session.query(Model).filter_by(uuid='model-uuid').first()
+ project = session.query(Project).filter_by(name='project').first()
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='model-job').first()
+ model_job_group = session.query(ModelJobGroup).filter_by(name='group').first()
+ dataset = session.query(Dataset).filter_by(uuid='dataset-uuid').first()
+ self.assertEqual(model_job.uuid, 'model-job-uuid')
+ self.assertEqual(model_job.model_job_type, ModelJobType.EVALUATION)
+ self.assertEqual(model_job.workflow_uuid, 'workflow-uuid')
+ self.assertEqual(model_job.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(model_job.model_id, model.id)
+ self.assertEqual(model_job.project_id, project.id)
+ self.assertEqual(model_job.group_id, model_job_group.id)
+ self.assertEqual(model_job.dataset_id, dataset.id)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_group_creator.py b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_group_creator.py
new file mode 100644
index 000000000..6f08afb44
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_group_creator.py
@@ -0,0 +1,127 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.dataset.models import Dataset, ResourceState
+from fedlearner_webconsole.mmgr.models import ModelJobGroup, ModelJobRole, GroupCreateStatus
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.proto.project_pb2 import ParticipantsInfo, ParticipantInfo
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+from fedlearner_webconsole.setting.service import SettingService
+
+
+class ModelJobGroupCreator(ResourceManager):
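+    """Creates a model job group initiated by a coordinator via 2PC.
+
+    The local group is created unauthorized with role PARTICIPANT and records
+    which participant acted as the coordinator.
+    """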
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.create_model_job_group_data is not None
+ self._data = data.create_model_job_group_data
+ self._session = session
+
+ def _check_project(self) -> Tuple[bool, str]:
+ project_name = self._data.project_name
+ project = self._session.query(Project).filter_by(name=project_name).first()
+ if not project:
+            message = f'project {project_name} not exists'
+ logging.info('[model-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_group(self) -> Tuple[bool, str]:
+ model_job_group_name = self._data.model_job_group_name
+ model_job_group = self._session.query(ModelJobGroup).filter_by(name=model_job_group_name).first()
+ if model_job_group:
+ if model_job_group.uuid != self._data.model_job_group_uuid:
+ message = f'model group {model_job_group_name} with different uuid already exist'
+ logging.info('[model-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_dataset(self) -> Tuple[bool, str]:
+ if self._data.dataset_uuid:
+ dataset = self._session.query(Dataset).filter_by(uuid=self._data.dataset_uuid).first()
+ if not dataset:
+ message = f'dataset {self._data.dataset_uuid} not exists'
+ logging.info('[model-group-2pc] prepare failed: %s', message)
+ return False, message
+ if dataset.get_frontend_state() != ResourceState.SUCCEEDED:
+ message = f'dataset {self._data.dataset_uuid} is not succeeded'
+ logging.info('[model-group-2pc] prepare failed: %s', message)
+ return False, message
+ if not dataset.is_published:
+ message = f'dataset {self._data.dataset_uuid} is not published'
+ logging.info('[model-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def prepare(self) -> Tuple[bool, str]:
+ check_fn_list = [self._check_project, self._check_group, self._check_dataset]
+ for check_fn in check_fn_list:
+ succeeded, message = check_fn()
+ if not succeeded:
+ return False, message
+ logging.info('[model-group-2pc] prepare succeeded')
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+ model_job_group_name = self._data.model_job_group_name
+ project = self._session.query(Project).filter_by(name=self._data.project_name).first()
+ group = self._session.query(ModelJobGroup).filter_by(name=model_job_group_name).first()
+ dataset_id = None
+ if self._data.dataset_uuid:
+ dataset_id = self._session.query(Dataset).filter_by(uuid=self._data.dataset_uuid).first().id
+ coordinator = ParticipantService(
+ self._session).get_participant_by_pure_domain_name(pure_domain_name=self._data.coordinator_pure_domain_name)
+ coordinator_id = None
+ if coordinator is not None:
+ coordinator_id = coordinator.id
+ if not group:
+ group = ModelJobGroup(name=model_job_group_name,
+ uuid=self._data.model_job_group_uuid,
+ project_id=project.id,
+ dataset_id=dataset_id,
+ authorized=False,
+ role=ModelJobRole.PARTICIPANT,
+ algorithm_type=AlgorithmType[self._data.algorithm_type],
+ coordinator_id=coordinator_id)
+ participants = ParticipantService(self._session).get_participants_by_project(project.id)
+ participants_info = ParticipantsInfo(participants_map={
+ p.pure_domain_name(): ParticipantInfo(auth_status=AuthStatus.PENDING.name) for p in participants
+ })
+ participants_info.participants_map[
+ self._data.coordinator_pure_domain_name].auth_status = AuthStatus.AUTHORIZED.name
+ pure_domain_name = SettingService.get_system_info().pure_domain_name
+ participants_info.participants_map[pure_domain_name].auth_status = AuthStatus.PENDING.name
+ group.set_participants_info(participants_info)
+ self._session.add(group)
+ group.status = GroupCreateStatus.SUCCEEDED
+ logging.info('[model-group-2pc] commit succeeded')
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[model-group-2pc] abort')
+ group = self._session.query(ModelJobGroup).filter_by(name=self._data.model_job_group_name).first()
+ if group is not None:
+ group.status = GroupCreateStatus.FAILED
+ # No resources were reserved in prepare, so there is nothing to roll back
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_group_creator_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_group_creator_test.py
new file mode 100644
index 000000000..484305e1f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_group_creator_test.py
@@ -0,0 +1,157 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import Mock, patch
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset, ResourceState
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.two_pc.model_job_group_creator import ModelJobGroupCreator
+from fedlearner_webconsole.mmgr.models import ModelJobGroup, GroupCreateStatus
+from fedlearner_webconsole.proto.two_pc_pb2 import CreateModelJobGroupData, \
+ TransactionData
+from fedlearner_webconsole.proto.project_pb2 import ParticipantInfo, ParticipantsInfo
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+
+
+class ModelJobGroupCreatorTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=1, name='project')
+ participant = Participant(id=123, name='party', domain_name='fl-demo.com')
+ relationship = ProjectParticipant(project_id=1, participant_id=123)
+ dataset = Dataset(name='dataset', uuid='dataset_uuid', is_published=True)
+ with db.session_scope() as session:
+ session.add_all([project, participant, dataset, relationship])
+ session.commit()
+
+ @staticmethod
+ def get_transaction_data(group_name: str, group_uuid: str, project_name: str, dataset_uuid: str):
+ return TransactionData(
+ create_model_job_group_data=CreateModelJobGroupData(model_job_group_name=group_name,
+ model_job_group_uuid=group_uuid,
+ project_name=project_name,
+ algorithm_type=AlgorithmType.NN_VERTICAL.name,
+ coordinator_pure_domain_name='demo',
+ dataset_uuid=dataset_uuid))
+
+ @patch('fedlearner_webconsole.dataset.models.Dataset.get_frontend_state')
+ def test_prepare(self, mock_get_frontend_state: Mock):
+ mock_get_frontend_state.return_value = ResourceState.SUCCEEDED
+ data = self.get_transaction_data('group', 'uuid', 'project', 'dataset_uuid')
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertTrue(flag)
+ with db.session_scope() as session:
+ model_job_group = ModelJobGroup(name='group', uuid='uuid')
+ session.add(model_job_group)
+ session.commit()
+ # test idempotence of creating a group with the same name and uuid
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertTrue(flag)
+ # fail because the uuid is inconsistent with the existing group
+ data = self.get_transaction_data('group', 'uuid-1', 'project', 'dataset_uuid')
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'model group group with different uuid already exist')
+ # fail due to project not found
+ data = self.get_transaction_data('group', 'uuid', 'project-1', 'dataset_uuid')
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'project project-1 not exists')
+ # fail due to dataset not found
+ data = self.get_transaction_data('group', 'uuid', 'project', 'dataset_uuid-1')
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'dataset dataset_uuid-1 not exists')
+ # fail because the dataset is not published
+ with db.session_scope() as session:
+ dataset = session.query(Dataset).filter_by(uuid='dataset_uuid').first()
+ dataset.is_published = False
+ session.add(dataset)
+ session.commit()
+ data = self.get_transaction_data('group', 'uuid', 'project', 'dataset_uuid')
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'dataset dataset_uuid is not published')
+ # fail because the dataset is not in the succeeded state
+ mock_get_frontend_state.return_value = ResourceState.PROCESSING
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'dataset dataset_uuid is not succeeded')
+
+ @patch('fedlearner_webconsole.setting.service.get_pure_domain_name')
+ def test_commit(self, mock_pure_domain_name: Mock):
+ mock_pure_domain_name.return_value = 'test'
+ data = self.get_transaction_data('group', 'uuid', 'project', 'dataset_uuid')
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.commit()
+ session.commit()
+ self.assertTrue(flag)
+ with db.session_scope() as session:
+ model_job_group = session.query(ModelJobGroup).filter_by(name='group').first()
+ project = session.query(Project).filter_by(name='project').first()
+ dataset = session.query(Dataset).filter_by(uuid='dataset_uuid').first()
+ self.assertEqual(model_job_group.uuid, 'uuid')
+ self.assertEqual(model_job_group.project_id, project.id)
+ self.assertEqual(model_job_group.algorithm_type, AlgorithmType.NN_VERTICAL)
+ self.assertEqual(model_job_group.coordinator_id, 123)
+ self.assertEqual(model_job_group.dataset_id, dataset.id)
+ self.assertEqual(
+ model_job_group.get_participants_info(),
+ ParticipantsInfo(
+ participants_map={
+ 'test': ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'demo': ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ }))
+
+ def test_abort(self):
+ data = self.get_transaction_data('group', 'uuid', 'project', 'dataset_uuid')
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ flag, msg = creator.abort()
+ session.commit()
+ self.assertTrue(flag)
+ with db.session_scope() as session:
+ creator = ModelJobGroupCreator(session, '12', data)
+ group = ModelJobGroup(name='group', uuid='uuid', status=GroupCreateStatus.PENDING)
+ session.add(group)
+ session.flush()
+ flag, msg = creator.abort()
+ self.assertTrue(flag)
+ self.assertEqual(group.status, GroupCreateStatus.FAILED)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_launcher.py b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_launcher.py
new file mode 100644
index 000000000..8bfcdffa3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_launcher.py
@@ -0,0 +1,77 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.mmgr.service import ModelJobGroupService
+from fedlearner_webconsole.mmgr.models import ModelJobGroup, ModelJobRole
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+
+
+class ModelJobLauncher(ResourceManager):
+ """Launch a configured model job based on the config of model job group"""
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.create_model_job_data is not None
+ self._data = data.create_model_job_data
+ self._session = session
+
+ def prepare(self) -> Tuple[bool, str]:
+ if not self._data.group_uuid:
+ message = 'group_uuid not found in create_model_job_data'
+ logging.info('[launch-model-job-2pc] prepare failed: %s', message)
+ return False, message
+ group: ModelJobGroup = self._session.query(ModelJobGroup).filter_by(uuid=self._data.group_uuid).first()
+ if group is None:
+ message = f'model group not found by uuid {self._data.group_uuid}'
+ logging.info('[launch-model-job-2pc] prepare failed: %s', message)
+ return False, message
+ if group.role == ModelJobRole.PARTICIPANT and not group.authorized:
+ message = f'model group {self._data.group_uuid} not authorized to coordinator'
+ logging.info('[launch-model-job-2pc] prepare failed: %s', message)
+ return False, message
+ if group.config is None:
+ message = f'the config of model group {group.name} not found'
+ logging.info('[launch-model-job-2pc] prepare failed: %s', message)
+ return False, message
+ # at the coordinator the group's latest version has already been bumped to the given version, so this check only applies to participants
+ if group.latest_version >= self._data.version and group.role == ModelJobRole.PARTICIPANT:
+ message = f'the latest version of model group {group.name} is larger than or equal to the given version'
+ logging.info('[launch-model-job-2pc] prepare failed: %s', message)
+ return False, message
+ if group.algorithm_id is not None and group.algorithm is None:
+ message = f'the algorithm {group.algorithm_id} of group {group.name} is not found'
+ logging.warning('[launch-model-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+ group: ModelJobGroup = self._session.query(ModelJobGroup).filter_by(uuid=self._data.group_uuid).first()
+
+ ModelJobGroupService(self._session).launch_model_job(group=group,
+ name=self._data.model_job_name,
+ uuid=self._data.model_job_uuid,
+ version=self._data.version)
+ logging.info('[launch-model-job-2pc] commit succeeded for group %s', group.name)
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[model-job-2pc] abort')
+ # No resources were reserved in prepare, so there is nothing to roll back
+ return True, ''
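+
+
+ # Illustrative sketch only (the literal values below are placeholders, not part
+ # of this change): the coordinator wraps a launch request as
+ #   TransactionData(create_model_job_data=CreateModelJobData(
+ #       model_job_name='group-v3', model_job_uuid='some-uuid',
+ #       group_uuid='group-uuid', version=3))
+ # and prepare() only succeeds on a participant when the requested version is
+ # strictly greater than the group's latest_version.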
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_launcher_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_launcher_test.py
new file mode 100644
index 000000000..7b8e5c3db
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/model_job_launcher_test.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from google.protobuf.struct_pb2 import Value
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.initial_db import _insert_or_update_templates
+from fedlearner_webconsole.algorithm.models import AlgorithmType, Algorithm
+from fedlearner_webconsole.two_pc.model_job_launcher import ModelJobLauncher
+from fedlearner_webconsole.workflow.models import WorkflowState
+from fedlearner_webconsole.mmgr.models import ModelJob, ModelJobGroup, ModelJobType, ModelJobRole
+from fedlearner_webconsole.dataset.models import Dataset, DatasetJob, DatasetJobState, DatasetJobKind, DatasetType, \
+ DatasetJobStage
+from fedlearner_webconsole.proto.two_pc_pb2 import CreateModelJobData, \
+ TransactionData
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.proto.common_pb2 import Variable
+
+
+def _get_workflow_config():
+ return WorkflowDefinition(job_definitions=[
+ JobDefinition(name='train-job',
+ job_type=JobDefinition.JobType.NN_MODEL_TRANINING,
+ variables=[
+ Variable(name='mode', value='train'),
+ Variable(name='data_source',
+ value='dataset-job-stage-uuid-psi-data-join-job',
+ typed_value=Value(string_value='dataset-job-stage-uuid-psi-data-join-job')),
+ Variable(name='data_path', typed_value=Value(string_value='')),
+ ])
+ ])
+
+
+class ModelJobCreatorTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ _insert_or_update_templates(session)
+ dataset_job = DatasetJob(id=1,
+ name='datasetjob',
+ uuid='dataset-job-uuid',
+ state=DatasetJobState.SUCCEEDED,
+ project_id=1,
+ input_dataset_id=1,
+ output_dataset_id=2,
+ kind=DatasetJobKind.RSA_PSI_DATA_JOIN)
+ dataset = Dataset(id=2,
+ uuid='uuid',
+ name='datasetjob',
+ dataset_type=DatasetType.PSI,
+ path='/data/dataset/haha')
+ dataset_job_stage = DatasetJobStage(id=1,
+ name='data-join',
+ uuid='dataset-job-stage-uuid',
+ project_id=1,
+ state=DatasetJobState.SUCCEEDED,
+ dataset_job_id=1,
+ data_batch_id=1)
+ algorithm = Algorithm(id=2, name='algo')
+ model_job_group = ModelJobGroup(name='group',
+ uuid='uuid',
+ project_id=1,
+ algorithm_type=AlgorithmType.NN_VERTICAL,
+ algorithm_id=2,
+ role=ModelJobRole.PARTICIPANT,
+ authorized=True,
+ dataset_id=2,
+ latest_version=2)
+ model_job_group.set_config(_get_workflow_config())
+ session.add_all([dataset_job, dataset_job_stage, dataset, model_job_group, algorithm])
+ session.commit()
+
+ def test_prepare(self):
+ create_model_job_data = CreateModelJobData(model_job_name='model-job',
+ model_job_uuid='model-job-uuid',
+ group_uuid='uuid',
+ version=3)
+ data = TransactionData(create_model_job_data=create_model_job_data)
+ with db.session_scope() as session:
+ # succeeded
+ flag, _ = ModelJobLauncher(session, tid='12', data=data).prepare()
+ self.assertTrue(flag)
+ with db.session_scope() as session:
+ # fail because the group is not authorized
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(uuid='uuid').first()
+ group.authorized = False
+ flag, msg = ModelJobLauncher(session, tid='12', data=data).prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'model group uuid not authorized to coordinator')
+ with db.session_scope() as session:
+ # fail due to algorithm not found
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(uuid='uuid').first()
+ group.algorithm_id = 3
+ data = TransactionData(create_model_job_data=create_model_job_data)
+ flag, msg = ModelJobLauncher(session, tid='12', data=data).prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'the algorithm 3 of group group is not found')
+ with db.session_scope() as session:
+ # fail because the group is not configured
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(uuid='uuid').first()
+ group.config = None
+ flag, msg = ModelJobLauncher(session, tid='12', data=data).prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'the config of model group group not found')
+ with db.session_scope() as session:
+ # fail because the group is not found
+ data.create_model_job_data.group_uuid = '1'
+ flag, msg = ModelJobLauncher(session, tid='12', data=data).prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg, 'model group not found by uuid 1')
+ data.create_model_job_data.group_uuid = 'uuid'
+ # fail due to version mismatch
+ data.create_model_job_data.version = 2
+ flag, msg = ModelJobLauncher(session, tid='12', data=data).prepare()
+ self.assertFalse(flag)
+ self.assertEqual(msg,
+ 'the latest version of model group group is larger than or equal to the given version')
+
+ def test_commit(self):
+ create_model_job_data = CreateModelJobData(model_job_name='model-job',
+ model_job_uuid='model-job-uuid',
+ group_uuid='uuid',
+ version=2)
+ data = TransactionData(create_model_job_data=create_model_job_data)
+ with db.session_scope() as session:
+ flag, msg = ModelJobLauncher(session, tid='12', data=data).commit()
+ self.assertTrue(flag)
+ session.commit()
+ with db.session_scope() as session:
+ group: ModelJobGroup = session.query(ModelJobGroup).filter_by(uuid='uuid').first()
+ model_job: ModelJob = session.query(ModelJob).filter_by(name='model-job').first()
+ self.assertEqual(model_job.group_id, group.id)
+ self.assertEqual(model_job.project_id, group.project_id)
+ self.assertEqual(model_job.algorithm_type, group.algorithm_type)
+ self.assertEqual(model_job.model_job_type, ModelJobType.TRAINING)
+ self.assertEqual(model_job.dataset_id, group.dataset_id)
+ self.assertEqual(model_job.workflow.get_config(), group.get_config())
+ self.assertEqual(model_job.workflow.state, WorkflowState.READY)
+ self.assertEqual(model_job.version, 2)
+ self.assertEqual(group.latest_version, 2)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/models.py b/web_console_v2/api/fedlearner_webconsole/two_pc/models.py
new file mode 100644
index 000000000..213f7b0ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/models.py
@@ -0,0 +1,113 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import enum
+from typing import Tuple, Optional
+
+from sqlalchemy import func, Index
+
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TwoPcAction
+from fedlearner_webconsole.utils.mixins import to_dict_mixin
+
+
+class TransactionState(enum.Enum):
+ NEW = 'NEW'
+ PREPARE_SUCCEEDED = 'PREPARE_SUCCEEDED'
+ PREPARE_FAILED = 'PREPARE_FAILED'
+ COMMITTED = 'COMMITTED'
+ ABORTED = 'ABORTED'
+ INVALID = 'INVALID'
+
+
+# Valid transition mappings:
+# Current state - action - result - new state
+_VALID_TRANSITIONS = {
+ TransactionState.NEW: {
+ TwoPcAction.PREPARE: {
+ True: TransactionState.PREPARE_SUCCEEDED,
+ False: TransactionState.PREPARE_FAILED,
+ }
+ },
+ TransactionState.PREPARE_SUCCEEDED: {
+ TwoPcAction.COMMIT: {
+ True: TransactionState.COMMITTED,
+ },
+ TwoPcAction.ABORT: {
+ True: TransactionState.ABORTED,
+ }
+ },
+ TransactionState.PREPARE_FAILED: {
+ TwoPcAction.ABORT: {
+ True: TransactionState.ABORTED,
+ }
+ }
+}
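+ # For example, a NEW transaction that handles PREPARE successfully moves to
+ # PREPARE_SUCCEEDED; from there COMMIT moves it to COMMITTED and ABORT moves it
+ # to ABORTED. Any (state, action) pair not listed above is invalid and causes
+ # update() below to mark the transaction INVALID.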
+
+
+@to_dict_mixin(ignores=['_type'], extras={'type': lambda t: t.get_type()})
+class Transaction(db.Model):
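+ """Local bookkeeping record of one 2PC transaction, identified by uuid."""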
+ __tablename__ = 'transactions_v2'
+ __table_args__ = (Index('uniq_uuid', 'uuid', unique=True), default_table_args('2pc transactions'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True, comment='id')
+ uuid = db.Column(db.String(64), comment='uuid')
+ # 2PC type, consistent with TwoPcType in proto
+ _two_pc_type = db.Column('type', db.String(32), comment='2pc type name')
+ state = db.Column(db.Enum(TransactionState, native_enum=False, create_constraint=False, length=32),
+ default=TransactionState.NEW,
+ comment='state')
+ message = db.Column(db.Text(), comment='message of the last action')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created_at')
+ updated_at = db.Column(db.DateTime(timezone=True),
+ onupdate=func.now(),
+ server_default=func.now(),
+ comment='update_at')
+
+ def get_type(self) -> TwoPcType:
+ return TwoPcType.Value(self._two_pc_type)
+
+ def set_type(self, t: TwoPcType):
+ self._two_pc_type = TwoPcType.Name(t)
+
+ def is_valid_action(self, action: TwoPcAction) -> bool:
+ """Checks if the action is valid for current state or not."""
+ possible_results = _VALID_TRANSITIONS.get(self.state, {}).get(action, None)
+ return possible_results is not None
+
+ def check_idempotent(self, current_action: TwoPcAction) -> Tuple[bool, Optional[bool], Optional[str]]:
+ """Checks if the action executed and the result.
+
+ Returns:
+ (executed or not, result, message)
+ """
+ if self.state == TransactionState.INVALID:
+ return True, False, self.message
+ for current_state, actions in _VALID_TRANSITIONS.items():
+ for action, results in actions.items():
+ for result, new_state in results.items():
+ if new_state == self.state and action == current_action:
+ # Hits the history
+ return True, result, self.message
+ return False, None, None
+
+ def update(self, action: TwoPcAction, succeeded: bool, message: str) -> Tuple[bool, str]:
+ new_state = _VALID_TRANSITIONS.get(self.state, {}).get(action, {}).get(succeeded, None)
+ if new_state is None:
+ self.message = f'[2pc] Invalid transition: [{self.state} - {action} - {succeeded}], extra: {message}'
+ self.state = TransactionState.INVALID
+ return False, self.message
+ self.state = new_state
+ self.message = message
+ return succeeded, message
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/models_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/models_test.py
new file mode 100644
index 000000000..13b3b47fe
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/models_test.py
@@ -0,0 +1,79 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TwoPcAction
+from fedlearner_webconsole.two_pc.models import Transaction, TransactionState
+
+
+class TransactionTest(unittest.TestCase):
+
+ def test_two_pc_type(self):
+ trans = Transaction()
+ trans.set_type(TwoPcType.CREATE_MODEL_JOB)
+ self.assertEqual(trans._two_pc_type, 'CREATE_MODEL_JOB') # pylint: disable=protected-access
+ self.assertEqual(trans.get_type(), TwoPcType.CREATE_MODEL_JOB)
+
+ def test_is_valid_action(self):
+ trans = Transaction(state=TransactionState.NEW)
+ self.assertTrue(trans.is_valid_action(TwoPcAction.PREPARE))
+ self.assertFalse(trans.is_valid_action(TwoPcAction.COMMIT))
+ trans.state = TransactionState.PREPARE_FAILED
+ self.assertTrue(trans.is_valid_action(TwoPcAction.ABORT))
+ self.assertFalse(trans.is_valid_action(TwoPcAction.COMMIT))
+ trans.state = TransactionState.INVALID
+ self.assertFalse(trans.is_valid_action(TwoPcAction.ABORT))
+
+ def test_check_idempotent_invalid(self):
+ trans = Transaction(state=TransactionState.INVALID, message='invalid')
+ executed, result, message = trans.check_idempotent(TwoPcAction.COMMIT)
+ self.assertTrue(executed)
+ self.assertFalse(result)
+ self.assertEqual(message, 'invalid')
+
+ def test_check_idempotent_executed(self):
+ trans = Transaction(state=TransactionState.PREPARE_SUCCEEDED, message='prepared')
+ executed, result, message = trans.check_idempotent(TwoPcAction.PREPARE)
+ self.assertTrue(executed)
+ self.assertTrue(result)
+ self.assertEqual(message, 'prepared')
+
+ def test_check_idempotent_has_not_executed(self):
+ trans = Transaction(state=TransactionState.PREPARE_SUCCEEDED, message='prepared')
+ self.assertEqual(trans.check_idempotent(TwoPcAction.COMMIT), (False, None, None))
+
+ def test_update_failed(self):
+ trans = Transaction(state=TransactionState.PREPARE_SUCCEEDED)
+ trans.update(TwoPcAction.COMMIT, False, 'failed to abort')
+ self.assertEqual(trans.state, TransactionState.INVALID)
+ self.assertEqual(
+ trans.message,
+ '[2pc] Invalid transition: [TransactionState.PREPARE_SUCCEEDED - 1 - False], extra: failed to abort')
+
+ def test_update_successfully(self):
+ trans = Transaction(state=TransactionState.PREPARE_SUCCEEDED)
+ trans.update(TwoPcAction.COMMIT, True, 'yeah')
+ self.assertEqual(trans.state, TransactionState.COMMITTED)
+ self.assertEqual(trans.message, 'yeah')
+
+ trans = Transaction(state=TransactionState.PREPARE_SUCCEEDED)
+ trans.update(action=TwoPcAction.ABORT, succeeded=True, message='yep')
+ self.assertEqual(trans.state, TransactionState.ABORTED)
+ self.assertEqual(trans.message, 'yep')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/resource_manager.py b/web_console_v2/api/fedlearner_webconsole/two_pc/resource_manager.py
new file mode 100644
index 000000000..bb11f4a9a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/resource_manager.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from abc import abstractmethod
+from typing import Tuple
+
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData, TwoPcAction
+
+
+class ResourceManager(object):
+ """An abstract class to manage resource in 2pc.
+
+ The recommendation practice is to keep those methods idempotent.
+ """
+
+ def __init__(self, tid: str, data: TransactionData):
+ self.tid = tid
+ self.data = data
+
+ @abstractmethod
+ def prepare(self) -> Tuple[bool, str]:
+ pass
+
+ @abstractmethod
+ def commit(self) -> Tuple[bool, str]:
+ pass
+
+ @abstractmethod
+ def abort(self) -> Tuple[bool, str]:
+ pass
+
+ def run_two_pc(self, action: TwoPcAction) -> Tuple[bool, str]:
+ if action == TwoPcAction.PREPARE:
+ return self.prepare()
+ if action == TwoPcAction.COMMIT:
+ return self.commit()
+ assert action == TwoPcAction.ABORT
+ return self.abort()
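+
+
+ # A minimal sketch of a concrete manager (illustrative only; NoopManager is not
+ # part of the codebase). Each phase returns a (succeeded, message) tuple and
+ # run_two_pc() dispatches the requested action to the matching method:
+ #
+ #   class NoopManager(ResourceManager):
+ #
+ #       def prepare(self) -> Tuple[bool, str]:
+ #           return True, ''
+ #
+ #       def commit(self) -> Tuple[bool, str]:
+ #           return True, ''
+ #
+ #       def abort(self) -> Tuple[bool, str]:
+ #           return True, ''
+ #
+ #   NoopManager(tid='tid', data=TransactionData()).run_two_pc(TwoPcAction.PREPARE)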
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/transaction_manager.py b/web_console_v2/api/fedlearner_webconsole/two_pc/transaction_manager.py
new file mode 100644
index 000000000..6cb81be88
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/transaction_manager.py
@@ -0,0 +1,95 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+from typing import List, Tuple
+from uuid import uuid4
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcAction, TwoPcType, TransactionData
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.two_pc.handlers import run_two_pc_action
+from fedlearner_webconsole.utils.metrics import emit_store
+
+
+class TransactionManager(object):
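+ """Drives a full 2PC round: prepares on every participant and locally, then
+ commits if all prepares succeeded, otherwise aborts."""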
+
+ def __init__(self, project_name: str, project_token: str, participants: List[str], two_pc_type: TwoPcType):
+ self.type = two_pc_type
+ self._project_name = project_name
+ self._project_token = project_token
+ self._clients = []
+ for domain_name in participants:
+ self._clients.append(
+ RpcClient.from_project_and_participant(project_name=self._project_name,
+ project_token=self._project_token,
+ domain_name=domain_name))
+
+ def run(self, data: TransactionData) -> Tuple[bool, str]:
+ tid = str(uuid4())
+ prepared, pre_message = self.do_two_pc_action(tid, TwoPcAction.PREPARE, data)
+ # TODO(hangweiqiang): catch exception and maybe retry sometime?
+ if prepared:
+ succeeded, act_message = self.do_two_pc_action(tid, TwoPcAction.COMMIT, data)
+ else:
+ succeeded, act_message = self.do_two_pc_action(tid, TwoPcAction.ABORT, data)
+ if not succeeded:
+ emit_store('2pc.transaction_failure', 1)
+ return (prepared, pre_message) if not prepared else (succeeded, act_message)
+
+ def do_two_pc_action(self, tid: str, action: TwoPcAction, data: TransactionData) -> Tuple[bool, str]:
+ # TODO(hangweiqiang): using multi-thread
+ succeeded = True
+ message = ''
+ for client in self._clients:
+ result, res_message = self._remote_do_two_pc(client, tid, action, data)
+ if not result and succeeded:
+ succeeded = False
+ message = res_message
+ result, res_message = self._local_do_two_pc(tid, action, data)
+ if not result and succeeded:
+ succeeded = False
+ message = res_message
+ return succeeded, message
+
+ def _remote_do_two_pc(self, client: RpcClient, tid: str, action: TwoPcAction,
+ data: TransactionData) -> Tuple[bool, str]:
+ response = client.run_two_pc(transaction_uuid=tid, two_pc_type=self.type, action=action, data=data)
+ if response.status.code != common_pb2.STATUS_SUCCESS:
+ # Something wrong during rpc call
+ logging.info('[%s] 2pc [%s] error [%s]: %s', self.type, action, tid, response.status.msg)
+ return False, response.message
+ if not response.succeeded:
+ # Failed
+ logging.info('[%s] 2pc [%s] failed [%s]: %s', self.type, action, tid, response.message)
+ return False, response.message
+ return True, response.message
+
+ def _local_do_two_pc(self, tid: str, action: TwoPcAction, data: TransactionData) -> Tuple[bool, str]:
+ try:
+ with db.session_scope() as session:
+ succeeded, message = run_two_pc_action(session=session,
+ tid=tid,
+ two_pc_type=self.type,
+ action=action,
+ data=data)
+ session.commit()
+ except Exception as e: # pylint: disable=broad-except
+ succeeded = False
+ message = str(e)
+ if not succeeded:
+ logging.info('[%s] 2pc [%s] failed locally [%s]: %s', self.type, action, tid, message)
+ return succeeded, message
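+
+
+ # Usage sketch (the values below are placeholders): one TransactionManager per
+ # transaction drives the whole round with a single run() call.
+ #
+ #   tm = TransactionManager(project_name='proj', project_token='token',
+ #                           participants=['fl-peer.com'],
+ #                           two_pc_type=TwoPcType.CREATE_MODEL_JOB)
+ #   succeeded, message = tm.run(TransactionData(
+ #       create_model_job_data=CreateModelJobData(model_job_name='job')))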
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/transaction_manager_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/transaction_manager_test.py
new file mode 100644
index 000000000..24df28298
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/transaction_manager_test.py
@@ -0,0 +1,172 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+# pylint: disable=protected-access
+import unittest
+from unittest import mock
+from unittest.mock import patch, MagicMock, call
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.service_pb2 import TwoPcResponse
+from fedlearner_webconsole.proto.two_pc_pb2 import TwoPcType, TwoPcAction, TransactionData, \
+ CreateModelJobData
+from fedlearner_webconsole.two_pc.transaction_manager import TransactionManager
+
+
+class TransactionManagerTest(NoWebServerTestCase):
+ _PROJECT_NAME = 'test-project'
+ _PROJECT_TOKEN = 'testtoken'
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.RpcClient.from_project_and_participant')
+ def test_init(self, mock_rpc_client_factory):
+ mock_rpc_client_factory.return_value = MagicMock()
+ tm = TransactionManager(project_name=self._PROJECT_NAME,
+ project_token=self._PROJECT_TOKEN,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ participants=['fl1.com', 'fl2.com'])
+ self.assertEqual(tm.type, TwoPcType.CREATE_MODEL_JOB)
+ self.assertEqual(len(tm._clients), 2)
+
+ calls = [
+ call(project_name=self._PROJECT_NAME, project_token=self._PROJECT_TOKEN, domain_name='fl1.com'),
+ call(project_name=self._PROJECT_NAME, project_token=self._PROJECT_TOKEN, domain_name='fl2.com')
+ ]
+ mock_rpc_client_factory.assert_has_calls(calls)
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.run_two_pc_action')
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.uuid4')
+ def test_run(self, mock_uuid4, mock_local_run_two_pc_action):
+ tid = 'testttttt'
+ transaction_data = TransactionData(create_model_job_data=CreateModelJobData(model_job_name='test model name'))
+ mock_uuid4.return_value = tid
+ # Two participants
+ p1 = MagicMock()
+ p1.run_two_pc = MagicMock()
+ p2 = MagicMock()
+ p2.run_two_pc = MagicMock()
+ # A hack to avoid mocking RpcClient.from_project_and_participant
+ tm = TransactionManager(project_name=self._PROJECT_NAME,
+ project_token=self._PROJECT_TOKEN,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ participants=[])
+ tm._clients = [p1, p2]
+
+ # Test successfully
+ p1.run_two_pc.return_value = TwoPcResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ succeeded=True)
+ p2.run_two_pc.return_value = TwoPcResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS),
+ succeeded=True)
+ mock_local_run_two_pc_action.return_value = (True, '')
+ succeeded, _ = tm.run(transaction_data)
+ self.assertTrue(succeeded)
+ mock_uuid4.assert_called_once()
+ calls = [
+ call(transaction_uuid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data),
+ call(transaction_uuid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.COMMIT,
+ data=transaction_data),
+ ]
+ p1.run_two_pc.assert_has_calls(calls)
+ p2.run_two_pc.assert_has_calls(calls)
+ mock_local_run_two_pc_action.assert_has_calls([
+ call(session=mock.ANY,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data),
+ call(session=mock.ANY,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.COMMIT,
+ data=transaction_data),
+ ])
+
+ # Test failed
+ def p2_run_two_pc(action: TwoPcAction, *args, **kwargs) -> TwoPcResponse:
+ if action == TwoPcAction.PREPARE:
+ return TwoPcResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), succeeded=False)
+ return TwoPcResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), succeeded=True)
+
+ p2.run_two_pc.side_effect = p2_run_two_pc
+ mock_uuid4.reset_mock()
+ succeeded, _ = tm.run(transaction_data)
+ self.assertFalse(succeeded)
+ mock_uuid4.assert_called_once()
+ calls = [
+ call(transaction_uuid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data),
+ call(transaction_uuid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.ABORT,
+ data=transaction_data),
+ ]
+ p1.run_two_pc.assert_has_calls(calls)
+ p2.run_two_pc.assert_has_calls(calls)
+ mock_local_run_two_pc_action.assert_has_calls([
+ call(session=mock.ANY,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data),
+ call(session=mock.ANY,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.ABORT,
+ data=transaction_data),
+ ])
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.run_two_pc_action')
+ def test_do_two_pc_action(self, mock_local_run_two_pc_action):
+ tid = 'test-id'
+ transaction_data = TransactionData(create_model_job_data=CreateModelJobData(model_job_name='test model name'))
+ # Two participants
+ p1 = MagicMock()
+ p1.run_two_pc = MagicMock(
+ return_value=TwoPcResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), succeeded=True))
+ p2 = MagicMock()
+ p2.run_two_pc = MagicMock(
+ return_value=TwoPcResponse(status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS), succeeded=True))
+ mock_local_run_two_pc_action.return_value = (True, '')
+ # A hack to avoid mocking RpcClient.from_project_and_participant
+ tm = TransactionManager(project_name=self._PROJECT_NAME,
+ project_token=self._PROJECT_TOKEN,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ participants=[])
+ tm._clients = [p1, p2]
+ self.assertTrue(tm.do_two_pc_action(tid=tid, action=TwoPcAction.PREPARE, data=transaction_data))
+ p1.run_two_pc.assert_called_once_with(transaction_uuid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data)
+ p2.run_two_pc.assert_called_once_with(transaction_uuid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data)
+ mock_local_run_two_pc_action.assert_called_once_with(session=mock.ANY,
+ tid=tid,
+ two_pc_type=TwoPcType.CREATE_MODEL_JOB,
+ action=TwoPcAction.PREPARE,
+ data=transaction_data)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_export_job_launcher.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_export_job_launcher.py
new file mode 100644
index 000000000..7dd0ac5ea
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_export_job_launcher.py
@@ -0,0 +1,76 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.tee.services import TrustedJobService
+from fedlearner_webconsole.tee.models import TrustedJob, TrustedJobStatus
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+
+
+class TrustedExportJobLauncher(ResourceManager):
+ """Launch a configured trusted export job"""
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.launch_trusted_export_job_data is not None
+ self._data = data.launch_trusted_export_job_data
+ self._session = session
+ self._tee_export_job = None
+
+ def _check_trusted_export_job(self) -> Tuple[bool, str]:
+ self._tee_export_job = self._session.query(TrustedJob).filter_by(uuid=self._data.uuid).first()
+ if self._tee_export_job is None:
+ message = f'trusted export job {self._data.uuid} not found'
+ logging.info('[launch-trusted-export-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_auth(self) -> Tuple[bool, str]:
+ if self._tee_export_job.auth_status != AuthStatus.AUTHORIZED:
+ message = f'trusted export job {self._data.uuid} not authorized'
+ logging.info('[launch-trusted-export-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def prepare(self) -> Tuple[bool, str]:
+ # _check_trusted_export_job must run first because the later checks use self._tee_export_job
+ check_fn_list = [
+ self._check_trusted_export_job,
+ self._check_auth,
+ ]
+ for check_fn in check_fn_list:
+ succeeded, message = check_fn()
+ if not succeeded:
+ return False, message
+ logging.info('[launch-trusted-export-job-2pc] prepare succeeded')
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+ tee_export_job = self._session.query(TrustedJob).filter_by(uuid=self._data.uuid).first()
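+ # an id of 0 means this participant plays that role itself, so the export job
+ # is only launched locally when this side is the coordinator or the analyzer;
+ # otherwise it is simply marked as succeeded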
+ if tee_export_job.coordinator_id == 0 or tee_export_job.group.analyzer_id == 0:
+ TrustedJobService(self._session).launch_trusted_export_job(tee_export_job)
+ else:
+ tee_export_job.status = TrustedJobStatus.SUCCEEDED
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[launch-trusted-export-job-2pc] abort')
+ # No resources were reserved in prepare, so there is nothing to roll back
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_export_job_launcher_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_export_job_launcher_test.py
new file mode 100644
index 000000000..8298a394e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_export_job_launcher_test.py
@@ -0,0 +1,117 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, MagicMock
+from testing.no_web_server_test_case import NoWebServerTestCase
+from google.protobuf.text_format import MessageToString
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.two_pc.trusted_export_job_launcher import TrustedExportJobLauncher
+from fedlearner_webconsole.tee.models import TrustedJobGroup, TrustedJob, TrustedJobType, TrustedJobStatus
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.tee_pb2 import Resource
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData, LaunchTrustedExportJobData
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+
+
+class TrustedExportJobLauncherTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-name')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ group = TrustedJobGroup(id=1, analyzer_id=0)
+ tee_export_job = TrustedJob(id=1,
+ uuid='uuid1',
+ name='V1-domain1-1',
+ type=TrustedJobType.EXPORT,
+ version=1,
+ project_id=1,
+ trusted_job_group_id=1,
+ auth_status=AuthStatus.AUTHORIZED,
+ status=TrustedJobStatus.CREATED,
+ export_count=1,
+ coordinator_id=1,
+ resource=MessageToString(Resource(cpu=1000, memory=1, replicas=1)))
+ tee_analyze_job = TrustedJob(id=2,
+ uuid='uuid2',
+ type=TrustedJobType.ANALYZE,
+ version=1,
+ trusted_job_group_id=1,
+ job_id=1,
+ status=TrustedJobStatus.SUCCEEDED)
+ job = Job(id=1,
+ name='trusted-job-1-uuid2',
+ job_type=JobType.CUSTOMIZED,
+ state=JobState.COMPLETED,
+ workflow_id=0,
+ project_id=1)
+ sys_var = SettingService(session).get_system_variables_dict()
+ session.add_all([project, participant1, proj_part1, group, tee_export_job, tee_analyze_job, job])
+ session.commit()
+ sys_var['sgx_image'] = 'artifact.bytedance.com/fedlearner/pp_bioinformatics:e13eb8a1d96ad046ca7354b8197d41fd'
+ self.sys_var = sys_var
+
+ @staticmethod
+ def get_transaction_data(uuid: str):
+ return TransactionData(launch_trusted_export_job_data=LaunchTrustedExportJobData(uuid=uuid))
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_prepare(self, mock_get_system_info: MagicMock):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).get(1)
+ # successful
+ data = self.get_transaction_data('uuid1')
+ launcher = TrustedExportJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertTrue(flag)
+ # fail because the tee_export_job does not exist
+ data = self.get_transaction_data('not-exist')
+ launcher = TrustedExportJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+ # fail due to auth
+ tee_export_job.auth_status = AuthStatus.WITHDRAW
+ data = self.get_transaction_data('uuid1')
+ launcher = TrustedExportJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_variables_dict')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_commit(self, mock_get_system_info: MagicMock, mock_sys_dict: MagicMock):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ mock_sys_dict.return_value = self.sys_var
+ with db.session_scope() as session:
+ data = self.get_transaction_data('uuid1')
+ launcher = TrustedExportJobLauncher(session, '13', data)
+ flag, msg = launcher.commit()
+ self.assertTrue(flag)
+ session.commit()
+ with db.session_scope() as session:
+ tee_export_job = session.query(TrustedJob).get(1)
+ self.assertIsNotNone(tee_export_job.job_id)
+ self.assertEqual(tee_export_job.get_status(), TrustedJobStatus.PENDING)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_group_creator.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_group_creator.py
new file mode 100644
index 000000000..ce1b32349
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_group_creator.py
@@ -0,0 +1,180 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.review.ticket_helper import get_ticket_helper
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.tee.models import TrustedJobGroup, GroupCreateStatus
+from fedlearner_webconsole.proto.tee_pb2 import ParticipantDatasetList, ParticipantDataset
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.exceptions import NotFoundException
+
+
+class TrustedJobGroupCreator(ResourceManager):
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.create_trusted_job_group_data is not None
+ self._data = data.create_trusted_job_group_data
+ self._session = session
+ self._project_id = None
+ self.pure_domain_name = SettingService.get_system_info().pure_domain_name
+
+ def _check_ticket(self) -> Tuple[bool, str]:
+ validate = get_ticket_helper(self._session).validate_ticket(
+ self._data.ticket_uuid, lambda ticket: ticket.details.uuid == self._data.uuid)
+ if not validate:
+ message = f'ticket {self._data.ticket_uuid} is not valid'
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_group(self) -> Tuple[bool, str]:
+ name = self._data.name
+ project_name = self._data.project_name
+ project = self._session.query(Project).filter_by(name=project_name).first()
+ self._project_id = project.id
+ group = self._session.query(TrustedJobGroup).filter_by(name=name, project_id=project.id).first()
+ if group is not None and group.uuid != self._data.uuid:
+ message = f'trusted job group {name} in project {project_name} with different uuid already exists'
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_algorithm(self) -> Tuple[bool, str]:
+ try:
+ algorithm = AlgorithmFetcher(self._project_id).get_algorithm(self._data.algorithm_uuid)
+ if algorithm.type != AlgorithmType.TRUSTED_COMPUTING.name:
+ message = f'algorithm {self._data.algorithm_uuid} is not TRUSTED_COMPUTING type'
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ except NotFoundException as e:
+ message = e.message
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_participant_dataset(self) -> Tuple[bool, str]:
+ for dnd in self._data.domain_name_datasets:
+ if dnd.pure_domain_name == self.pure_domain_name:
+ dataset = self._session.query(Dataset).filter_by(uuid=dnd.dataset_uuid, name=dnd.dataset_name).first()
+ if dataset is None:
+ message = f'dataset {dnd.dataset_uuid} not exists'
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ if not dataset.is_published:
+ message = f'dataset {dnd.dataset_uuid} is not published'
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ else:
+ participant = ParticipantService(
+ self._session).get_participant_by_pure_domain_name(pure_domain_name=dnd.pure_domain_name)
+ if participant is None:
+ message = f'participant with pure domain name {dnd.pure_domain_name} not exists'
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_participant(self) -> Tuple[bool, str]:
+ for pure_domain_name in [self._data.coordinator_pure_domain_name, self._data.analyzer_pure_domain_name]:
+ if pure_domain_name != self.pure_domain_name:
+ participant = ParticipantService(
+ self._session).get_participant_by_pure_domain_name(pure_domain_name=pure_domain_name)
+ if participant is None:
+ message = f'participant with pure domain name {pure_domain_name} not exists'
+ logging.info('[trusted-group-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def prepare(self) -> Tuple[bool, str]:
+ # _check_algorithm must run after _check_group because it relies on self._project_id
+ check_fn_list = [
+ self._check_ticket,
+ self._check_group,
+ self._check_algorithm,
+ self._check_participant_dataset,
+ self._check_participant,
+ ]
+ for check_fn in check_fn_list:
+ succeeded, message = check_fn()
+ if not succeeded:
+ return False, message
+ logging.info('[trusted-group-2pc] prepare succeeded')
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
+ coordinator_pure_domain_name = self._data.coordinator_pure_domain_name
+ # The coordinator has already created the group via the POST API, so do nothing
+ if self.pure_domain_name == coordinator_pure_domain_name:
+ logging.info('[trusted-group-2pc] commit succeeded')
+ return True, ''
+ name = self._data.name
+ uuid = self._data.uuid
+ project = self._session.query(Project).filter_by(name=self._data.project_name).first()
+ coordinator_id = ParticipantService(
+ self._session).get_participant_by_pure_domain_name(pure_domain_name=coordinator_pure_domain_name).id
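+ # an analyzer_id of 0 means this participant is the analyzer itself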
+ if self.pure_domain_name == self._data.analyzer_pure_domain_name:
+ analyzer_id = 0
+ else:
+ analyzer_id = ParticipantService(self._session).get_participant_by_pure_domain_name(
+ pure_domain_name=self._data.analyzer_pure_domain_name).id
+ dataset_id = None
+ participant_datasets = ParticipantDatasetList()
+ for dnd in self._data.domain_name_datasets:
+ if dnd.pure_domain_name == self.pure_domain_name:
+ dataset = self._session.query(Dataset).filter_by(uuid=dnd.dataset_uuid).first()
+ dataset_id = dataset.id
+ else:
+ participant = ParticipantService(
+ self._session).get_participant_by_pure_domain_name(pure_domain_name=dnd.pure_domain_name)
+ participant_datasets.items.append(
+ ParticipantDataset(participant_id=participant.id, uuid=dnd.dataset_uuid, name=dnd.dataset_name))
+ participants = ParticipantService(self._session).get_participants_by_project(project.id)
+ unauth_participant_ids = [p.id for p in participants if p.id != coordinator_id]
+ group = self._session.query(TrustedJobGroup).filter_by(uuid=uuid).first()
+ if not group:
+ group = TrustedJobGroup(
+ name=name,
+ uuid=uuid,
+ latest_version=0,
+ creator_username=self._data.creator_username,
+ project_id=project.id,
+ coordinator_id=coordinator_id,
+ analyzer_id=analyzer_id,
+ ticket_uuid=self._data.ticket_uuid,
+ ticket_status=TicketStatus.APPROVED,
+ status=GroupCreateStatus.SUCCEEDED,
+ algorithm_uuid=self._data.algorithm_uuid,
+ dataset_id=dataset_id,
+ )
+ group.set_participant_datasets(participant_datasets)
+ group.set_unauth_participant_ids(unauth_participant_ids)
+ self._session.add(group)
+ logging.info('[trusted-group-2pc] commit succeeded')
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[trusted-group-2pc] abort')
+        # As we did not reserve any resources, do nothing
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_group_creator_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_group_creator_test.py
new file mode 100644
index 000000000..bfc587dd9
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_group_creator_test.py
@@ -0,0 +1,245 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch
+from typing import List
+import grpc
+from google.protobuf.text_format import MessageToString
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.rpc.client import FakeRpcError
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.algorithm.models import AlgorithmProject, Algorithm, AlgorithmType
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData, CreateTrustedJobGroupData
+from fedlearner_webconsole.proto.tee_pb2 import DomainNameDataset, ParticipantDataset, ParticipantDatasetList
+from fedlearner_webconsole.tee.models import TrustedJobGroup, GroupCreateStatus
+from fedlearner_webconsole.two_pc.trusted_job_group_creator import TrustedJobGroupCreator
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.review.common import NO_CENTRAL_SERVER_UUID
+
+
+class TrustedJobGroupCreatorTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=1, name='project')
+ participant1 = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ participant2 = Participant(id=2, name='part3', domain_name='fl-domain3.com')
+ proj_part1 = ProjectParticipant(project_id=1, participant_id=1)
+ proj_part2 = ProjectParticipant(project_id=1, participant_id=2)
+ dataset1 = Dataset(id=1, name='dataset-name1', uuid='dataset-uuid1', is_published=True)
+ dataset2 = Dataset(id=2, name='dataset-name3', uuid='dataset-uuid3', is_published=False)
+ algorithm_proj1 = AlgorithmProject(id=1, uuid='algorithm-proj-uuid1', type=AlgorithmType.TRUSTED_COMPUTING)
+ algorithm1 = Algorithm(id=1,
+ uuid='algorithm-uuid1',
+ type=AlgorithmType.TRUSTED_COMPUTING,
+ algorithm_project_id=1)
+ algorithm2 = Algorithm(id=2, uuid='algorithm-uuid2', algorithm_project_id=1)
+ with db.session_scope() as session:
+ session.add_all([
+ project,
+ participant1,
+ participant2,
+ proj_part1,
+ proj_part2,
+ dataset1,
+ dataset2,
+ algorithm1,
+ algorithm2,
+ algorithm_proj1,
+ ])
+ session.commit()
+
+ @staticmethod
+ def get_transaction_data(name: str, uuid: str, ticket_uuid: str, project_name: str, algorithm_project_uuid: str,
+ algorithm_uuid: str, domain_name_datasets: List[DomainNameDataset],
+ coordinator_pure_domain_name: str, analyzer_pure_domain_name: str):
+ return TransactionData(create_trusted_job_group_data=CreateTrustedJobGroupData(
+ name=name,
+ uuid=uuid,
+ ticket_uuid=ticket_uuid,
+ project_name=project_name,
+ algorithm_project_uuid=algorithm_project_uuid,
+ algorithm_uuid=algorithm_uuid,
+ domain_name_datasets=domain_name_datasets,
+ coordinator_pure_domain_name=coordinator_pure_domain_name,
+ analyzer_pure_domain_name=analyzer_pure_domain_name,
+ ))
+
+ @patch('fedlearner_webconsole.algorithm.fetcher.AlgorithmFetcher.get_algorithm_from_participant')
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_prepare(self, mock_get_system_info, mock_get_algorithm):
+ # successful case
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ mock_get_algorithm.side_effect = FakeRpcError(grpc.StatusCode.NOT_FOUND, 'not found')
+ data = self.get_transaction_data(
+ 'group-name',
+ 'group-uuid',
+ NO_CENTRAL_SERVER_UUID,
+ 'project',
+ 'algorithm-proj-uuid1',
+ 'algorithm-uuid1',
+ [
+ DomainNameDataset(
+ pure_domain_name='domain1', dataset_uuid='dataset-uuid1', dataset_name='dataset-name1'),
+ DomainNameDataset(
+ pure_domain_name='domain2', dataset_uuid='dataset-uuid2', dataset_name='dataset-name2'),
+ ],
+ 'domain2',
+ 'domain2',
+ )
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertTrue(flag)
+ # fail due to algorithm not found
+ data.create_trusted_job_group_data.algorithm_uuid = 'algorithm-not-exist'
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ # fail due to algorithm type invalid
+ data.create_trusted_job_group_data.algorithm_uuid = 'algorithm-uuid2'
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ # fail due to participant not found
+ data = self.get_transaction_data(
+ 'group-name',
+ 'group-uuid',
+ NO_CENTRAL_SERVER_UUID,
+ 'project',
+ 'algorithm-proj-uuid1',
+ 'algorithm-uuid1',
+ [
+ DomainNameDataset(
+ pure_domain_name='domain1', dataset_uuid='dataset-uuid1', dataset_name='dataset-name1'),
+ DomainNameDataset(
+ pure_domain_name='domain-not-exist', dataset_uuid='dataset-uuid2', dataset_name='dataset-name2'),
+ ],
+ 'domain2',
+ 'domain2',
+ )
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ # fail due to dataset not found
+ data = self.get_transaction_data(
+ 'group-name',
+ 'group-uuid',
+ NO_CENTRAL_SERVER_UUID,
+ 'project',
+ 'algorithm-proj-uuid1',
+ 'algorithm-uuid1',
+ [
+ DomainNameDataset(
+ pure_domain_name='domain1', dataset_uuid='dataset-uuid-not-exist', dataset_name='dataset-name1'),
+ DomainNameDataset(
+ pure_domain_name='domain2', dataset_uuid='dataset-uuid2', dataset_name='dataset-name2'),
+ ],
+ 'domain2',
+ 'domain2',
+ )
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ # fail due to dataset not published
+ data = self.get_transaction_data(
+ 'group-name',
+ 'group-uuid',
+ NO_CENTRAL_SERVER_UUID,
+ 'project',
+ 'algorithm-proj-uuid1',
+ 'algorithm-uuid1',
+ [
+ DomainNameDataset(
+ pure_domain_name='domain1', dataset_uuid='dataset-uuid3', dataset_name='dataset-name3'),
+ DomainNameDataset(
+ pure_domain_name='domain2', dataset_uuid='dataset-uuid2', dataset_name='dataset-name2'),
+ ],
+ 'domain2',
+ 'domain2',
+ )
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+ # fail due to same trusted job group name with different uuid in project
+ with db.session_scope() as session:
+ group = TrustedJobGroup(name='group-name', uuid='other-group-uuid', project_id=1)
+ session.add(group)
+ session.commit()
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.prepare()
+ self.assertFalse(flag)
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_commit(self, mock_get_system_info):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ data = self.get_transaction_data(
+ 'group-name',
+ 'group-uuid',
+ NO_CENTRAL_SERVER_UUID,
+ 'project',
+ 'algorithm-proj-uuid1',
+ 'algorithm-uuid1',
+ [
+ DomainNameDataset(
+ pure_domain_name='domain1', dataset_uuid='dataset-uuid1', dataset_name='dataset-name1'),
+ DomainNameDataset(
+ pure_domain_name='domain2', dataset_uuid='dataset-uuid2', dataset_name='dataset-name2'),
+ ],
+ 'domain2',
+ 'domain2',
+ )
+ with db.session_scope() as session:
+ creator = TrustedJobGroupCreator(session, '12', data)
+ flag, msg = creator.commit()
+ session.commit()
+ self.assertTrue(flag)
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).filter_by(name='group-name').first()
+ self.assertEqual(group.uuid, 'group-uuid')
+ self.assertEqual(group.latest_version, 0)
+ self.assertEqual(group.project_id, 1)
+ self.assertEqual(group.coordinator_id, 1)
+ self.assertEqual(group.analyzer_id, 1)
+ self.assertEqual(group.ticket_uuid, NO_CENTRAL_SERVER_UUID)
+ self.assertEqual(group.ticket_status, TicketStatus.APPROVED)
+ self.assertEqual(group.status, GroupCreateStatus.SUCCEEDED)
+ self.assertEqual(group.auth_status, AuthStatus.PENDING)
+ self.assertEqual(group.unauth_participant_ids, '2')
+ self.assertEqual(group.algorithm_uuid, 'algorithm-uuid1')
+ self.assertEqual(group.dataset_id, 1)
+ participant_datasets = ParticipantDatasetList(
+ items=[ParticipantDataset(
+ participant_id=1,
+ uuid='dataset-uuid2',
+ name='dataset-name2',
+ )])
+ self.assertEqual(group.participant_datasets, MessageToString(participant_datasets))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_launcher.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_launcher.py
new file mode 100644
index 000000000..f56eec725
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_launcher.py
@@ -0,0 +1,136 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.tee.services import TrustedJobGroupService
+from fedlearner_webconsole.tee.models import TrustedJobGroup
+from fedlearner_webconsole.dataset.models import Dataset
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.algorithm.fetcher import AlgorithmFetcher
+from fedlearner_webconsole.algorithm.models import AlgorithmType
+from fedlearner_webconsole.exceptions import NotFoundException
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.participant.services import ParticipantService
+
+
+class TrustedJobLauncher(ResourceManager):
+    """Launches a configured trusted job based on the config of the trusted job group."""
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.launch_trusted_job_data is not None
+ self._data = data.launch_trusted_job_data
+ self._session = session
+ self._group = None
+
+ def _check_group(self) -> Tuple[bool, str]:
+ if self._data.group_uuid:
+ group: TrustedJobGroup = self._session.query(TrustedJobGroup).filter_by(uuid=self._data.group_uuid).first()
+ if group is not None:
+ self._group = group
+ return True, ''
+ message = f'trusted job group {self._data.group_uuid} not found'
+ logging.info('[launch-trusted-job-2pc] prepare failed: %s', message)
+ return False, message
+
+ def _check_version(self) -> Tuple[bool, str]:
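+        """The requested version must be newer than the group's latest version unless
+        this party is the initiator of the launch."""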
+ self_pure_domain_name = SettingService.get_system_info().pure_domain_name
+ if (self._group.latest_version >= self._data.version and
+ self._data.initiator_pure_domain_name != self_pure_domain_name):
+ message = (f'the latest version of trusted job group {self._data.group_uuid} '
+ f'is greater than or equal to the given version')
+ logging.info('[launch-trusted-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_auth(self) -> Tuple[bool, str]:
+ if self._group.auth_status != AuthStatus.AUTHORIZED:
+ message = f'trusted job group {self._data.group_uuid} not authorized'
+ logging.info('[launch-trusted-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_algorithm(self) -> Tuple[bool, str]:
+ try:
+ algorithm = AlgorithmFetcher(self._group.project_id).get_algorithm(self._group.algorithm_uuid)
+ if algorithm.type != AlgorithmType.TRUSTED_COMPUTING.name:
+ message = f'algorithm {self._group.algorithm_uuid} is not TRUSTED_COMPUTING type'
+ logging.info('[launch-trusted-job-2pc] prepare failed: %s', message)
+ return False, message
+ except NotFoundException as e:
+ message = e.message
+ logging.info('[launch-trusted-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_dataset(self) -> Tuple[bool, str]:
+ if self._group.dataset_id is not None:
+ dataset: Dataset = self._session.query(Dataset).get(self._group.dataset_id)
+ if dataset is None or not dataset.is_published:
+                message = f'dataset {self._group.dataset_id} is not found or not published'
+ logging.info('[launch-trusted-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def _check_initiator(self) -> Tuple[bool, str]:
+ init_pure_dn = self._data.initiator_pure_domain_name
+ if SettingService.get_system_info().pure_domain_name == init_pure_dn:
+ return True, ''
+ participant = ParticipantService(self._session).get_participant_by_pure_domain_name(init_pure_dn)
+ if participant is None:
+ message = f'initiator {self._data.initiator_pure_domain_name} is not found'
+ logging.info('[launch-trusted-job-2pc] prepare failed: %s', message)
+ return False, message
+ return True, ''
+
+ def prepare(self) -> Tuple[bool, str]:
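+        """Prepare phase: validates group, version, authorization, algorithm, dataset
+        and initiator before any trusted job is launched."""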
+        # _check_group must run first because it sets self._group for the later checks
+ check_fn_list = [
+ self._check_group,
+ self._check_version,
+ self._check_auth,
+ self._check_algorithm,
+ self._check_dataset,
+ self._check_initiator,
+ ]
+ for check_fn in check_fn_list:
+ succeeded, message = check_fn()
+ if not succeeded:
+ return False, message
+ logging.info('[launch-trusted-job-2pc] prepare succeeded')
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
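+        """Commit phase: resolves the coordinator id and launches the trusted job via
+        TrustedJobGroupService."""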
+ group: TrustedJobGroup = self._session.query(TrustedJobGroup).filter_by(uuid=self._data.group_uuid).first()
+ pure_dn = self._data.initiator_pure_domain_name
+ if SettingService.get_system_info().pure_domain_name == pure_dn:
+ coordinator_id = 0
+ else:
+ participant = ParticipantService(self._session).get_participant_by_pure_domain_name(pure_dn)
+ coordinator_id = participant.id
+ TrustedJobGroupService(self._session).launch_trusted_job(group, self._data.uuid, self._data.version,
+ coordinator_id)
+ logging.info(f'[launch-trusted-job-2pc] commit succeeded for group {group.name}')
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[launch-trusted-job-2pc] abort')
+        # As we did not reserve any resources, do nothing
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_launcher_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_launcher_test.py
new file mode 100644
index 000000000..5e774c7dd
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_launcher_test.py
@@ -0,0 +1,130 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch
+from testing.common import NoWebServerTestCase
+from google.protobuf.text_format import MessageToString
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.tee.models import TrustedJobGroup, GroupCreateStatus, TrustedJob
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.algorithm.models import Algorithm, AlgorithmType
+from fedlearner_webconsole.dataset.models import Dataset, DataBatch
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData, LaunchTrustedJobData
+from fedlearner_webconsole.two_pc.trusted_job_launcher import TrustedJobLauncher
+from fedlearner_webconsole.proto.setting_pb2 import SystemInfo
+from fedlearner_webconsole.proto.tee_pb2 import Resource
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.participant.models import Participant
+
+
+class TrustedJobLauncherTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project-name')
+ participant = Participant(id=1, name='part2', domain_name='fl-domain2.com')
+ algorithm = Algorithm(id=1,
+ uuid='algorithm-uuid1',
+ type=AlgorithmType.TRUSTED_COMPUTING,
+ path='file:///data/algorithm/test/run.sh')
+ dataset1 = Dataset(id=1, name='dataset-name1', uuid='dataset-uuid1', is_published=True)
+ data_batch1 = DataBatch(id=1, dataset_id=1)
+ dataset2 = Dataset(id=2, name='dataset-name2', uuid='dataset-uuid2', is_published=False)
+ group = TrustedJobGroup(id=1,
+ uuid='group-uuid',
+ project_id=1,
+ latest_version=1,
+ coordinator_id=1,
+ status=GroupCreateStatus.SUCCEEDED,
+ auth_status=AuthStatus.AUTHORIZED,
+ algorithm_uuid='algorithm-uuid1',
+ dataset_id=1,
+ resource=MessageToString(Resource(cpu=2, memory=2, replicas=1)))
+ session.add_all([project, participant, algorithm, dataset1, data_batch1, dataset2, group])
+ session.commit()
+
+ @staticmethod
+ def get_transaction_data(uuid: str, version: int, group_uuid: str, initiator_pure_domain_name: str):
+ return TransactionData(launch_trusted_job_data=LaunchTrustedJobData(
+ uuid=uuid, version=version, group_uuid=group_uuid, initiator_pure_domain_name=initiator_pure_domain_name))
+
+ @patch('fedlearner_webconsole.setting.service.SettingService.get_system_info')
+ def test_prepare(self, mock_get_system_info):
+ mock_get_system_info.return_value = SystemInfo(pure_domain_name='domain1')
+ with db.session_scope() as session:
+ group: TrustedJobGroup = session.query(TrustedJobGroup).get(1)
+ # successful
+ data = self.get_transaction_data('trusted-job-uuid', 2, 'group-uuid', 'domain2')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertTrue(flag)
+ # fail due to initiator not found
+ data = self.get_transaction_data('trusted-job-uuid', 2, 'group-uuid', 'domain3')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+ # fail due to group not found
+ data = self.get_transaction_data('trusted-job-uuid', 2, 'not-exist', 'domain2')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+ # fail due to version conflict
+ data = self.get_transaction_data('trusted-job-uuid', 1, 'group-uuid', 'domain2')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+ # fail due to auth
+ group.auth_status = AuthStatus.PENDING
+ data = self.get_transaction_data('trusted-job-uuid', 2, 'group-uuid', 'domain2')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+ # fail due to dataset unpublished
+ group.auth_status = AuthStatus.AUTHORIZED
+ group.dataset_id = 2
+ data = self.get_transaction_data('trusted-job-uuid', 2, 'group-uuid', 'domain2')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+ # fail due to algorithm not found
+ group.dataset_id = 1
+ group.algorithm_uuid = 'algorithm-not-exist'
+ data = self.get_transaction_data('trusted-job-uuid', 2, 'group-uuid', 'domain2')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.prepare()
+ self.assertFalse(flag)
+
+ @patch('fedlearner_webconsole.tee.services.get_batch_data_path')
+ def test_commit(self, mock_get_batch_data_path):
+ mock_get_batch_data_path.return_value = 'file:///data/test'
+ with db.session_scope() as session:
+ data = self.get_transaction_data('trusted-job-uuid', 2, 'group-uuid', 'domain2')
+ launcher = TrustedJobLauncher(session, '13', data)
+ flag, msg = launcher.commit()
+ session.commit()
+ self.assertTrue(flag)
+ with db.session_scope() as session:
+ trusted_job = session.query(TrustedJob).filter_by(trusted_job_group_id=1, version=2).first()
+ self.assertIsNotNone(trusted_job)
+ self.assertEqual(trusted_job.name, 'V2')
+ self.assertEqual(trusted_job.coordinator_id, 1)
+ group = session.query(TrustedJobGroup).get(1)
+ self.assertEqual(group.latest_version, 2)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_stopper.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_stopper.py
new file mode 100644
index 000000000..e2977fb6e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_stopper.py
@@ -0,0 +1,60 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+from fedlearner_webconsole.tee.models import TrustedJob, TrustedJobStatus
+from fedlearner_webconsole.tee.services import TrustedJobService
+
+
+class TrustedJobStopper(ResourceManager):
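+    """Stops a trusted job as part of a 2PC transaction."""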
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.stop_trusted_job_data is not None
+ self._data = data.stop_trusted_job_data
+ self._session = session
+
+ def prepare(self) -> Tuple[bool, str]:
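+        """Prepare phase: the trusted job must exist and must not be in PENDING status."""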
+ trusted_job = self._session.query(TrustedJob).filter_by(uuid=self._data.uuid).first()
+ if trusted_job is None:
+ message = f'failed to find trusted job by uuid {self._data.uuid}'
+ logging.info(f'[stop-trusted-job-2pc] prepare: {message}')
+ return False, message
+ if trusted_job.get_status() == TrustedJobStatus.PENDING:
+            message = 'trusted job in PENDING status cannot be stopped'
+ logging.info(f'[stop-trusted-job-2pc] prepare: {message}')
+ return False, message
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
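+        """Commit phase: stops the trusted job; a missing job or an unexpected final
+        status is logged but does not fail the transaction."""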
+ trusted_job = self._session.query(TrustedJob).filter_by(uuid=self._data.uuid).first()
+ if trusted_job is None:
+            logging.error(f'[stop-trusted-job-2pc] commit: trusted job with uuid {self._data.uuid} not found')
+ return True, ''
+ TrustedJobService(self._session).stop_trusted_job(trusted_job)
+ if trusted_job.get_status() != TrustedJobStatus.STOPPED:
+            logging.warning(f'[stop-trusted-job-2pc] commit: stop trusted job with uuid {self._data.uuid} '
+                            f'ending with status {trusted_job.get_status()}')
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+        logging.info('[stop-trusted-job-2pc] abort')
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_stopper_test.py b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_stopper_test.py
new file mode 100644
index 000000000..f6cadfec0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/trusted_job_stopper_test.py
@@ -0,0 +1,88 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData, StopTrustedJobData
+from fedlearner_webconsole.tee.models import TrustedJob, TrustedJobStatus
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.two_pc.trusted_job_stopper import TrustedJobStopper
+
+
+class TrustedJobStopperTest(NoWebServerTestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ with db.session_scope() as session:
+ trusted_job1 = TrustedJob(id=1, uuid='uuid1', job_id=1)
+ job1 = Job(id=1,
+ name='job-name1',
+ job_type=JobType.CUSTOMIZED,
+ project_id=1,
+ workflow_id=0,
+ state=JobState.STARTED)
+ trusted_job2 = TrustedJob(id=2, uuid='uuid2', job_id=2)
+ job2 = Job(id=2,
+ name='job-name2',
+ job_type=JobType.CUSTOMIZED,
+ project_id=1,
+ workflow_id=0,
+ state=JobState.WAITING)
+ session.add_all([trusted_job1, trusted_job2, job1, job2])
+ session.commit()
+
+ @staticmethod
+ def get_transaction_data(uuid: str):
+ return TransactionData(stop_trusted_job_data=StopTrustedJobData(uuid=uuid))
+
+ def test_prepare(self):
+ with db.session_scope() as session:
+ # successful
+ data = self.get_transaction_data(uuid='uuid1')
+ stopper = TrustedJobStopper(session, '13', data)
+ flag, msg = stopper.prepare()
+ self.assertTrue(flag)
+ # fail due to trusted job not found
+ data = self.get_transaction_data(uuid='not-exist')
+ stopper = TrustedJobStopper(session, '13', data)
+ flag, msg = stopper.prepare()
+ self.assertFalse(flag)
+ # fail due to status not valid
+ data = self.get_transaction_data(uuid='uuid2')
+ stopper = TrustedJobStopper(session, '13', data)
+ flag, msg = stopper.prepare()
+ self.assertFalse(flag)
+
+ def test_commit(self):
+ with db.session_scope() as session:
+ # successful
+ data = self.get_transaction_data(uuid='uuid1')
+ stopper = TrustedJobStopper(session, '13', data)
+ stopper.commit()
+ # status not valid
+ data = self.get_transaction_data(uuid='uuid2')
+ stopper = TrustedJobStopper(session, '13', data)
+ stopper.commit()
+ session.commit()
+ with db.session_scope() as session:
+ trusted_job1 = session.query(TrustedJob).get(1)
+ self.assertEqual(trusted_job1.status, TrustedJobStatus.STOPPED)
+ trusted_job2 = session.query(TrustedJob).get(2)
+ self.assertEqual(trusted_job2.status, TrustedJobStatus.PENDING)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/two_pc/workflow_state_controller.py b/web_console_v2/api/fedlearner_webconsole/two_pc/workflow_state_controller.py
new file mode 100644
index 000000000..2892321b0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/two_pc/workflow_state_controller.py
@@ -0,0 +1,86 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+from typing import Tuple
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.workflow.service import WorkflowService
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.proto.two_pc_pb2 import TransactionData
+from fedlearner_webconsole.two_pc.resource_manager import ResourceManager
+from fedlearner_webconsole.workflow.workflow_controller import start_workflow_locally, stop_workflow_locally
+
+
+class WorkflowStateController(ResourceManager):
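+    """Transitions a workflow to RUNNING or STOPPED locally as part of a 2PC transaction."""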
+
+ def __init__(self, session: Session, tid: str, data: TransactionData):
+ super().__init__(tid, data)
+ assert data.transit_workflow_state_data is not None
+ self._data = data.transit_workflow_state_data
+ self._session = session
+ self._state_convert_map = {
+ WorkflowState.RUNNING: lambda workflow: start_workflow_locally(self._session, workflow),
+ WorkflowState.STOPPED: lambda workflow: stop_workflow_locally(self._session, workflow),
+ }
+
+ def prepare(self) -> Tuple[bool, str]:
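+        """Prepare phase: the workflow must exist, the target state must be supported,
+        the transition must be allowed, and, unless the target is STOPPED, the workflow
+        config must validate."""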
+ workflow = self._session.query(Workflow).filter_by(uuid=self._data.workflow_uuid).first()
+ if workflow is None:
+ message = f'failed to find workflow, uuid is {self._data.workflow_uuid}'
+ logging.warning(f'[workflow state 2pc] prepare: {message}, uuid: {self._data.workflow_uuid}')
+ return False, message
+
+ if WorkflowState[self._data.target_state] not in self._state_convert_map:
+ message = f'illegal target state {self._data.target_state}, uuid: {self._data.workflow_uuid}'
+ logging.warning(f'[workflow state 2pc] prepare: {message}')
+ return False, message
+ if not workflow.can_transit_to(WorkflowState[self._data.target_state]):
+            message = (f'changing workflow state from {workflow.state.name} to '
+                       f'{self._data.target_state} is forbidden, uuid: {self._data.workflow_uuid}')
+
+ logging.warning(f'[workflow state 2pc] prepare: {message}')
+ return False, message
+
+ if WorkflowState[self._data.target_state] == WorkflowState.STOPPED:
+ return True, ''
+
+ is_valid, info = WorkflowService(self._session).validate_workflow(workflow)
+ if not is_valid:
+ job_name, validate_e = info
+            message = (f'Invalid variable when trying to format job {job_name}: {str(validate_e)}, '
+                       f'uuid: {self._data.workflow_uuid}')
+
+ logging.warning(f'[workflow state 2pc] prepare: {message}')
+ return False, message
+ return True, ''
+
+ def commit(self) -> Tuple[bool, str]:
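+        """Commit phase: applies the state transition locally; raises if the workflow
+        was invalidated by a participant or the local transition fails."""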
+ workflow = self._session.query(Workflow).filter_by(uuid=self._data.workflow_uuid).first()
+ if workflow.is_invalid():
+ message = 'workflow is already invalidated by participant'
+ logging.error(f'[workflow state 2pc] commit: {message}, uuid: {self._data.workflow_uuid}')
+ raise ValueError(message)
+ try:
+ self._state_convert_map[WorkflowState[self._data.target_state]](workflow)
+ except RuntimeError as e:
+ logging.error(f'[workflow state 2pc] commit: {e}, uuid: {self._data.workflow_uuid}')
+ raise
+ return True, ''
+
+ def abort(self) -> Tuple[bool, str]:
+ logging.info('[workflow state 2pc] abort')
+ return True, ''
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/utils/BUILD.bazel
new file mode 100644
index 000000000..5db582165
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/BUILD.bazel
@@ -0,0 +1,610 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = [
+ "//web_console_v2/api:console_api_package",
+])
+
+py_library(
+ name = "app_version_lib",
+ srcs = ["app_version.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto"],
+)
+
+py_test(
+ name = "app_version_test",
+ srcs = ["app_version_test.py"],
+ imports = ["../.."],
+ main = "app_version_test.py",
+ deps = [
+ ":app_version_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_library(
+ name = "pp_base64_lib",
+ srcs = ["pp_base64.py"],
+ imports = ["../.."],
+)
+
+py_test(
+ name = "pp_base64_test",
+ srcs = ["pp_base64_test.py"],
+ imports = ["../.."],
+ main = "pp_base64_test.py",
+ deps = [
+ ":pp_base64_lib",
+ ],
+)
+
+py_library(
+ name = "const_lib",
+ srcs = ["const.py"],
+ imports = ["../.."],
+)
+
+py_library(
+ name = "pp_datetime_lib",
+ srcs = ["pp_datetime.py"],
+ imports = ["../.."],
+ deps = ["@common_python_dateutil//:pkg"],
+)
+
+py_test(
+ name = "pp_datetime_test",
+ srcs = ["pp_datetime_test.py"],
+ imports = ["../.."],
+ main = "pp_datetime_test.py",
+ deps = [
+ ":pp_datetime_lib",
+ ],
+)
+
+py_library(
+ name = "domain_name_lib",
+ srcs = ["domain_name.py"],
+ imports = ["../.."],
+)
+
+py_test(
+ name = "domain_name_test",
+ srcs = ["domain_name_test.py"],
+ imports = ["../.."],
+ main = "domain_name_test.py",
+ deps = [
+ ":domain_name_lib",
+ ],
+)
+
+py_library(
+ name = "es_lib",
+ srcs = [
+ "es.py",
+ "es_misc.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "@common_elasticsearch//:pkg",
+ ],
+)
+
+py_library(
+ name = "file_lib",
+ srcs = [
+ "file_manager.py",
+ "file_operator.py",
+ "file_tree.py",
+ "stream_tars.py",
+ "stream_untars.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_fsspec//:pkg",
+ # this is needed when using fsspec hdfs feature.
+ "@common_pyarrow//:pkg", # keep
+ "@common_tensorflow//:pkg",
+ ],
+)
+
+py_test(
+ name = "file_lib_test",
+ size = "small",
+ srcs = ["file_manager_test.py"],
+ imports = ["../.."],
+ main = "file_manager_test.py",
+ deps = [
+ ":file_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ "@common_tensorflow//:pkg",
+ ],
+)
+
+py_test(
+ name = "file_operator_test",
+ size = "small",
+ srcs = ["file_operator_test.py"],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ "//web_console_v2/api/testing/test_data/algorithm",
+ ],
+ imports = ["../.."],
+ main = "file_operator_test.py",
+ deps = [
+ ":file_lib",
+ "//web_console_v2/api:envs_lib",
+ ],
+)
+
+py_test(
+ name = "file_tree_test",
+ size = "small",
+ srcs = ["file_tree_test.py"],
+ imports = ["../.."],
+ main = "file_tree_test.py",
+ deps = [
+ ":file_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "//web_console_v2/api/testing:fake_lib",
+ "@common_tensorflow//:pkg",
+ ],
+)
+
+py_test(
+ name = "tars_test",
+ srcs = ["tars_test.py"],
+ imports = ["../.."],
+ main = "tars_test.py",
+ deps = [
+ ":file_lib",
+ ],
+)
+
+py_library(
+ name = "filtering_lib",
+ srcs = ["filtering.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_pyparsing//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "filtering_test",
+ size = "small",
+ srcs = ["filtering_test.py"],
+ imports = ["../.."],
+ main = "filtering_test.py",
+ deps = [
+ ":filtering_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "flask_utils_lib",
+ srcs = ["flask_utils.py"],
+ imports = ["../.."],
+ deps = [
+ ":filtering_lib",
+ ":proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "flask_utils_test",
+ size = "medium",
+ srcs = ["flask_utils_test.py"],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "flask_utils_test.py",
+ deps = [
+ ":flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/testing:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "hooks_lib",
+ srcs = ["hooks.py"],
+ imports = ["../.."],
+ deps = [
+ ":metrics_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:api_latency_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:middlewares_lib",
+ "//web_console_v2/api/fedlearner_webconsole/middleware:request_id_lib",
+ ],
+)
+
+py_test(
+ name = "hooks_test",
+ srcs = ["hooks_test.py"],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "hooks_test.py",
+ deps = [
+ ":hooks_lib",
+ ],
+)
+
+py_library(
+ name = "images_lib",
+ srcs = ["images.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api/fedlearner_webconsole/setting:service_lib"],
+)
+
+py_test(
+ name = "images_test",
+ srcs = ["images_test.py"],
+ imports = ["../.."],
+ main = "images_test.py",
+ deps = [
+ ":images_lib",
+ ],
+)
+
+py_library(
+ name = "job_metrics_lib",
+ srcs = ["job_metrics.py"],
+ imports = ["../.."],
+ deps = [
+ ":file_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_tensorflow//:pkg",
+ ],
+)
+
+py_library(
+ name = "kibana_lib",
+ srcs = ["kibana.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "@common_prison//:pkg",
+ "@common_pytz//:pkg",
+ "@common_requests//:pkg",
+ ],
+)
+
+py_test(
+ name = "kibana_test",
+ srcs = ["kibana_test.py"],
+ imports = ["../.."],
+ main = "kibana_test.py",
+ deps = [
+ ":kibana_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ ],
+)
+
+py_library(
+ name = "metrics_lib",
+ srcs = ["metrics.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api:envs_lib",
+ "@common_opentelemetry_exporter_otlp//:pkg",
+ "@common_opentelemetry_sdk//:pkg",
+ ],
+)
+
+py_test(
+ name = "metrics_test",
+ size = "small",
+ srcs = ["metrics_test.py"],
+ imports = ["../.."],
+ main = "metrics_test.py",
+ deps = [
+ ":metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "mixins_lib",
+ srcs = ["mixins.py"],
+ imports = ["../.."],
+ deps = [
+ ":pp_datetime_lib",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "mixins_test",
+ srcs = ["mixins_test.py"],
+ imports = ["../.."],
+ main = "mixins_test.py",
+ deps = [
+ ":mixins_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "paginate_lib",
+ srcs = ["paginate.py"],
+ imports = ["../.."],
+ deps = ["@common_sqlalchemy//:pkg"],
+)
+
+py_test(
+ name = "paginate_lib_test",
+ size = "small",
+ srcs = ["paginate_test.py"],
+ imports = ["../.."],
+ main = "paginate_test.py",
+ deps = [
+ ":paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "pp_flatten_dict_lib",
+ srcs = ["pp_flatten_dict.py"],
+ imports = ["../.."],
+ deps = [
+ "@common_flatten_dict//:pkg",
+ "@common_six//:pkg",
+ ],
+)
+
+py_test(
+ name = "pp_flatten_dict_test",
+ srcs = ["pp_flatten_dict_test.py"],
+ imports = ["../.."],
+ main = "pp_flatten_dict_test.py",
+ deps = [
+ ":pp_flatten_dict_lib",
+ ],
+)
+
+py_library(
+ name = "pp_yaml_lib",
+ srcs = ["pp_yaml.py"],
+ imports = ["../.."],
+ deps = [
+ ":const_lib",
+ ":pp_flatten_dict_lib",
+ ":system_envs_lib",
+ "//web_console_v2/api/fedlearner_webconsole/setting:service_lib",
+ "@common_simpleeval//:pkg",
+ ],
+)
+
+py_test(
+ name = "pp_yaml_test",
+ srcs = ["pp_yaml_test.py"],
+ imports = ["../.."],
+ main = "pp_yaml_test.py",
+ deps = [
+ ":pp_yaml_lib",
+ ],
+)
+
+py_library(
+ name = "proto_lib",
+ srcs = ["proto.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "proto_test",
+ srcs = ["proto_test.py"],
+ imports = ["../.."],
+ main = "proto_test.py",
+ deps = [
+ ":proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/testing:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "resource_name_lib",
+ srcs = ["resource_name.py"],
+ imports = ["../.."],
+)
+
+py_test(
+ name = "resource_name_test",
+ srcs = ["resource_name_test.py"],
+ imports = ["../.."],
+ main = "resource_name_test.py",
+ deps = [
+ ":resource_name_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto/testing:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "schema_lib",
+ srcs = ["schema.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api/fedlearner_webconsole:exceptions_lib"],
+)
+
+py_test(
+ name = "schema_test",
+ srcs = ["schema_test.py"],
+ imports = ["../.."],
+ main = "schema_test.py",
+ deps = [
+ ":schema_lib",
+ ],
+)
+
+py_library(
+ name = "sorting_lib",
+ srcs = ["sorting.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "sorting_test",
+ srcs = ["sorting_test.py"],
+ imports = ["../.."],
+ main = "sorting_test.py",
+ deps = [
+ ":sorting_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "swagger_lib",
+ srcs = ["swagger.py"],
+ imports = ["../.."],
+)
+
+py_test(
+ name = "swagger_test",
+ srcs = ["swagger_test.py"],
+ imports = ["../.."],
+ main = "swagger_test.py",
+ deps = [
+ ":swagger_lib",
+ ],
+)
+
+py_library(
+ name = "system_envs_lib",
+ srcs = ["system_envs.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api:envs_lib"],
+)
+
+py_test(
+ name = "system_envs_test",
+ srcs = ["system_envs_test.py"],
+ imports = ["../.."],
+ main = "system_envs_test.py",
+ deps = [
+ ":system_envs_lib",
+ ],
+)
+
+py_library(
+ name = "tfrecords_reader_lib",
+ srcs = ["tfrecords_reader.py"],
+ imports = ["../.."],
+ deps = ["@common_tensorflow//:pkg"],
+)
+
+py_test(
+ name = "tfrecord_test",
+ srcs = ["tfrecord_test.py"],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "tfrecord_test.py",
+ deps = [
+ ":tfrecords_reader_lib",
+ "//web_console_v2/api:envs_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_library(
+ name = "pp_time_lib",
+ srcs = ["pp_time.py"],
+ imports = ["../.."],
+)
+
+py_library(
+ name = "validator_lib",
+ srcs = ["validator.py"],
+ imports = ["../.."],
+)
+
+py_test(
+ name = "validator_test",
+ srcs = ["validator_test.py"],
+ imports = ["../.."],
+ main = "validator_test.py",
+ deps = [
+ ":validator_lib",
+ ],
+)
+
+py_library(
+ name = "workflow_lib",
+ srcs = ["workflow.py"],
+ imports = ["../.."],
+ deps = ["//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto"],
+)
+
+py_test(
+ name = "workflow_test",
+ srcs = ["workflow_test.py"],
+ imports = ["../.."],
+ main = "workflow_test.py",
+ deps = [
+ ":workflow_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "process_utils_lib",
+ srcs = ["process_utils.py"],
+ imports = ["../.."],
+ deps = [
+ ":hooks_lib",
+ ],
+)
+
+py_test(
+ name = "process_utils_test",
+ srcs = ["process_utils_test.py"],
+ imports = ["../.."],
+ main = "process_utils_test.py",
+ deps = [
+ ":process_utils_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/app_version.py b/web_console_v2/api/fedlearner_webconsole/utils/app_version.py
new file mode 100644
index 000000000..ee6b94f4d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/app_version.py
@@ -0,0 +1,121 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import re
+from typing import Optional
+
+from fedlearner_webconsole.proto import common_pb2
+
+_VERSION_RE = re.compile(r'^(\d+)\.(\d+)\.(\d+)')
+
+
+class Version(object):
+
+ def __init__(self, version: Optional[str] = None):
+ self._version = version
+ self._major = None
+ self._minor = None
+ self._patch = None
+ if version is not None:
+ self._parse_version(version)
+
+ def _parse_version(self, version: str):
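+        """Extracts major, minor and patch from a leading 'x.y.z' prefix; for
+        non-standard version strings all three parts stay None."""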
+ matches = _VERSION_RE.match(version)
+ if matches:
+ self._major = int(matches.group(1))
+ self._minor = int(matches.group(2))
+ self._patch = int(matches.group(3))
+
+ @property
+ def version(self):
+ return self._version
+
+ @property
+ def major(self):
+ return self._major
+
+ @property
+ def minor(self):
+ return self._minor
+
+ @property
+ def patch(self):
+ return self._patch
+
+ def is_standard(self):
+ return self.major is not None and self.minor is not None and self.patch is not None
+
+ def __eq__(self, other):
+ assert isinstance(other, Version)
+ if self.is_standard():
+ return self.major == other.major and self.minor == other.minor and self.patch == other.patch
+ return self.version == other.version
+
+ def __ne__(self, other):
+ assert isinstance(other, Version)
+ return not self.__eq__(other)
+
+ def __gt__(self, other):
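+        """Ordering is only defined between two standard versions; otherwise False is returned."""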
+ assert isinstance(other, Version)
+ if not self.is_standard() or not other.is_standard():
+ # Not compatible
+ return False
+ if self.major > other.major:
+ return True
+ if self.major < other.major:
+ return False
+ if self.minor > other.minor:
+ return True
+ if self.minor < other.minor:
+ return False
+ return self.patch > other.patch
+
+ def __lt__(self, other):
+ assert isinstance(other, Version)
+ if not self.is_standard() or not other.is_standard():
+ # Not compatible
+ return False
+ return not self.__ge__(other)
+
+ def __ge__(self, other):
+ assert isinstance(other, Version)
+ return self.__gt__(other) or self.__eq__(other)
+
+ def __le__(self, other):
+ assert isinstance(other, Version)
+ return self.__lt__(other) or self.__eq__(other)
+
+
+class ApplicationVersion(object):
+ """Version of the application.
+
+ Attributes:
+ revision: Commit id of the head
+ branch_name: Branch name of the image
+        pub_date: Date when the image was built
+ version: Semantic version
+ """
+
+ def __init__(self, revision: str, branch_name: str, pub_date: str, version: Optional[str] = None):
+ self.revision = revision
+ self.branch_name = branch_name
+ self.pub_date = pub_date
+ self.version = Version(version)
+
+ def to_proto(self) -> common_pb2.ApplicationVersion:
+ return common_pb2.ApplicationVersion(pub_date=self.pub_date,
+ revision=self.revision,
+ branch_name=self.branch_name,
+ version=self.version.version)
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/app_version_test.py b/web_console_v2/api/fedlearner_webconsole/utils/app_version_test.py
new file mode 100644
index 000000000..e0ea0298f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/app_version_test.py
@@ -0,0 +1,124 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.utils.app_version import Version, ApplicationVersion
+
+
+class VersionTest(unittest.TestCase):
+
+ def test_version_number(self):
+ v = Version()
+ self.assertIsNone(v.version)
+ v = Version('non')
+ self.assertEqual(v.version, 'non')
+ self.assertIsNone(v.major)
+ self.assertIsNone(v.minor)
+ self.assertIsNone(v.patch)
+ v = Version('2.1.33.3')
+ self.assertEqual(v.version, '2.1.33.3')
+ self.assertEqual(v.major, 2)
+ self.assertEqual(v.minor, 1)
+ self.assertEqual(v.patch, 33)
+
+ def test_is_standard(self):
+ self.assertTrue(Version('2.1.33').is_standard())
+ self.assertTrue(Version('2.1.33.1').is_standard())
+ self.assertFalse(Version('non').is_standard())
+
+ # Tests == and !=
+ def test_eq_and_ne(self):
+ v1 = Version('non')
+ v2 = Version('non')
+ v3 = Version('2.1.33.1')
+ v4 = Version('2.1.33')
+ v5 = Version('2.1.34')
+ self.assertTrue(v1 == v2)
+ self.assertFalse(v1 != v2)
+ self.assertTrue(v3 == v4)
+ self.assertFalse(v3 != v4)
+ self.assertFalse(v1 == v3)
+ self.assertTrue(v1 != v3)
+ self.assertFalse(v4 == v5)
+ self.assertTrue(v4 != v5)
+
+ # Tests >
+ def test_gt(self):
+ v1 = Version('nffff')
+ v2 = Version('2.1.33')
+ v3 = Version('2.1.34')
+ v4 = Version('2.2.33')
+ self.assertFalse(v2 > v1)
+ self.assertTrue(v3 > v2)
+ self.assertFalse(v2 > v3)
+ self.assertTrue(v4 > v3)
+ self.assertFalse(v3 > v4)
+ self.assertTrue(v4 > v2)
+ self.assertFalse(v2 > v4)
+
+ # Tests <
+ def test_lt(self):
+ v1 = Version()
+ v2 = Version('1.1.33')
+ v3 = Version('2.1.34')
+ v4 = Version('2.2.34')
+ self.assertFalse(v1 < v2)
+ self.assertTrue(v2 < v3)
+ self.assertFalse(v3 < v2)
+ self.assertTrue(v3 < v4)
+ self.assertFalse(v4 < v3)
+ self.assertTrue(v2 < v4)
+ self.assertFalse(v4 < v2)
+
+ # Tests >=
+ def test_ge(self):
+ v1 = Version('nffff')
+ v2 = Version('2.1.33')
+ v3 = Version('2.1.34')
+ self.assertFalse(v1 >= v2)
+ self.assertFalse(v2 >= v1)
+ self.assertTrue(v3 >= v2)
+ self.assertFalse(v2 >= v3)
+
+ # Tests <=
+ def test_le(self):
+ v1 = Version()
+ v2 = Version('2.1.33')
+ v3 = Version('2.1.34')
+ self.assertFalse(v1 <= v2)
+ self.assertFalse(v2 <= v1)
+ self.assertTrue(v2 <= v3)
+ self.assertFalse(v3 <= v2)
+
+
+class ApplicationVersionTest(unittest.TestCase):
+
+ def test_to_proto(self):
+ v = ApplicationVersion(revision='1234234234',
+ branch_name='dev',
+ version='non-standard',
+ pub_date='Fri Jul 16 12:23:19 CST 2021')
+ self.assertEqual(
+ v.to_proto(),
+ common_pb2.ApplicationVersion(revision='1234234234',
+ branch_name='dev',
+ version='non-standard',
+ pub_date='Fri Jul 16 12:23:19 CST 2021'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base64.py b/web_console_v2/api/fedlearner_webconsole/utils/base64.py
deleted file mode 100644
index 06272b638..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/base64.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-from base64 import b64encode, b64decode
-
-
-def base64encode(s: str) -> str:
- return b64encode(s.encode('UTF-8')).decode('UTF-8')
-
-
-def base64decode(s: str) -> str:
- return b64decode(s).decode('UTF-8')
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/utils/base_model/BUILD.bazel
new file mode 100644
index 000000000..18883a3c6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/BUILD.bazel
@@ -0,0 +1,64 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "base_model_lib",
+ srcs = [
+ "auth_model.py",
+ "review_ticket_and_auth_model.py",
+ "review_ticket_model.py",
+ "softdelete_model.py",
+ ],
+ imports = ["../../.."],
+ deps = [
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "auth_model_test",
+ size = "small",
+ srcs = [
+ "auth_model_test.py",
+ ],
+ imports = ["../../.."],
+ main = "auth_model_test.py",
+ deps = [
+ ":base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/testing:common_lib",
+ ],
+)
+
+py_test(
+ name = "review_ticket_and_auth_model_test",
+ size = "small",
+ srcs = [
+ "review_ticket_and_auth_model_test.py",
+ ],
+ imports = ["../../.."],
+ main = "review_ticket_and_auth_model_test.py",
+ deps = [
+ ":base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_test(
+ name = "review_ticket_model_test",
+ srcs = [
+ "review_ticket_model_test.py",
+ ],
+ imports = ["../../.."],
+ main = "review_ticket_model_test.py",
+ deps = [
+ ":base_model_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/auth_model.py b/web_console_v2/api/fedlearner_webconsole/utils/base_model/auth_model.py
new file mode 100644
index 000000000..1960fde35
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/auth_model.py
@@ -0,0 +1,30 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+from sqlalchemy import Column, Enum
+
+
+class AuthStatus(enum.Enum):
+ PENDING = 'PENDING'
+ AUTHORIZED = 'AUTHORIZED'
+ WITHDRAW = 'WITHDRAW'
+
+
+class AuthModel(object):
+
+ auth_status = Column(Enum(AuthStatus, length=32, native_enum=False, create_constraint=False),
+ default=AuthStatus.PENDING,
+ comment='auth status')
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/auth_model_test.py b/web_console_v2/api/fedlearner_webconsole/utils/base_model/auth_model_test.py
new file mode 100644
index 000000000..3c3048252
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/auth_model_test.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from testing.common import NoWebServerTestCase
+from fedlearner_webconsole.utils.base_model.auth_model import AuthModel, AuthStatus
+from fedlearner_webconsole.db import db, default_table_args
+
+
+class TestModel(db.Model, AuthModel):
+ __tablename__ = 'test_model'
+ __table_args__ = (default_table_args('This is webconsole test_model table'))
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+
+
+class AuthModelTest(NoWebServerTestCase):
+
+ def test_mixins(self):
+ with db.session_scope() as session:
+ model = TestModel(auth_status=AuthStatus.PENDING)
+ session.add(model)
+ session.commit()
+ with db.session_scope() as session:
+ models = session.query(TestModel).all()
+ self.assertEqual(len(models), 1)
+ self.assertEqual(models[0].auth_status, AuthStatus.PENDING)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_and_auth_model.py b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_and_auth_model.py
new file mode 100644
index 000000000..b11fc5f99
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_and_auth_model.py
@@ -0,0 +1,58 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+import sqlalchemy as sa
+from google.protobuf import text_format
+
+from fedlearner_webconsole.utils.base_model.auth_model import AuthModel, AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import ReviewTicketModel, TicketStatus
+from fedlearner_webconsole.proto import project_pb2
+
+
+class AuthFrontendState(enum.Enum):
+ TICKET_PENDING = 'TICKET_PENDING'
+ TICKET_DECLINED = 'TICKET_DECLINED'
+ AUTH_PENDING = 'AUTH_PENDING'
+ AUTH_APPROVED = 'AUTH_APPROVED'
+
+
+class ReviewTicketAndAuthModel(AuthModel, ReviewTicketModel):
+
+ participants_info = sa.Column(sa.Text(), comment='participants info')
+
+ @property
+ def auth_frontend_state(self) -> AuthFrontendState:
+ if self.ticket_status == TicketStatus.PENDING:
+ return AuthFrontendState.TICKET_PENDING
+ if self.ticket_status == TicketStatus.DECLINED:
+ return AuthFrontendState.TICKET_DECLINED
+ if self.ticket_status == TicketStatus.APPROVED and self.is_all_participants_authorized():
+ return AuthFrontendState.AUTH_APPROVED
+ return AuthFrontendState.AUTH_PENDING
+
+ def set_participants_info(self, participants_info: project_pb2.ParticipantsInfo):
+ self.participants_info = text_format.MessageToString(participants_info)
+
+ def get_participants_info(self) -> project_pb2.ParticipantsInfo:
+ participants_info = project_pb2.ParticipantsInfo()
+ if self.participants_info is not None:
+ participants_info = text_format.Parse(self.participants_info, project_pb2.ParticipantsInfo())
+ return participants_info
+
+ def is_all_participants_authorized(self) -> bool:
+ participants_info_list = self.get_participants_info().participants_map.values()
+ return all(
+ participant_info.auth_status == AuthStatus.AUTHORIZED.name for participant_info in participants_info_list)
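As an aside for reviewers, a minimal sketch (not part of the patch) of how the combined mixin is expected to behave; `demo` and its argument are illustrative, `model` standing for any mapped class that mixes in ReviewTicketAndAuthModel, mirroring the test below:

```python
from fedlearner_webconsole.proto import project_pb2
from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus


def demo(model):
    # `model` is an instance of a db.Model subclass mixing in ReviewTicketAndAuthModel.
    model.set_participants_info(
        project_pb2.ParticipantsInfo(
            participants_map={
                'ours': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
                'peer': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name),
            }))
    # One participant is still PENDING, so the aggregate check fails and
    # (with the default APPROVED ticket) auth_frontend_state stays AUTH_PENDING.
    assert not model.is_all_participants_authorized()
```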
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_and_auth_model_test.py b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_and_auth_model_test.py
new file mode 100644
index 000000000..0ed9c16b3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_and_auth_model_test.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.utils.base_model.review_ticket_and_auth_model import ReviewTicketAndAuthModel, \
+ AuthFrontendState
+from fedlearner_webconsole.utils.base_model.auth_model import AuthStatus
+from fedlearner_webconsole.utils.base_model.review_ticket_model import TicketStatus
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.proto import project_pb2
+
+
+class TestModel(db.Model, ReviewTicketAndAuthModel):
+ __tablename__ = 'test_model'
+ __table_args__ = (default_table_args('This is webconsole test_model table'))
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+
+
+class ReviewTicketAndAuthModelTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ model = TestModel(auth_status=AuthStatus.PENDING)
+ session.add(model)
+ session.commit()
+
+ @patch(
+ 'fedlearner_webconsole.utils.base_model.review_ticket_and_auth_model.ReviewTicketAndAuthModel.' \
+ 'is_all_participants_authorized'
+ )
+ def test_auth_frontend_state(self, mock_authorized: MagicMock):
+ with db.session_scope() as session:
+ test_model: TestModel = session.query(TestModel).get(1)
+ mock_authorized.return_value = True
+ self.assertEqual(test_model.auth_frontend_state, AuthFrontendState.AUTH_APPROVED)
+
+ mock_authorized.return_value = False
+ self.assertEqual(test_model.auth_frontend_state, AuthFrontendState.AUTH_PENDING)
+
+ test_model.ticket_status = TicketStatus.PENDING
+ self.assertEqual(test_model.auth_frontend_state, AuthFrontendState.TICKET_PENDING)
+
+ test_model.ticket_status = TicketStatus.DECLINED
+ self.assertEqual(test_model.auth_frontend_state, AuthFrontendState.TICKET_DECLINED)
+
+ def test_set_and_get_participants_info(self):
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ with db.session_scope() as session:
+ test_model = session.query(TestModel).get(1)
+ test_model.set_participants_info(participants_info)
+ session.commit()
+ with db.session_scope() as session:
+ test_model = session.query(TestModel).get(1)
+ self.assertEqual(test_model.get_participants_info(), participants_info)
+
+ def test_is_all_participants_authorized(self):
+ # test no participants_info
+ with db.session_scope() as session:
+ test_model = session.query(TestModel).get(1)
+ self.assertTrue(test_model.is_all_participants_authorized())
+
+ # test all authorized
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name),
+ 'test_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ with db.session_scope() as session:
+ test_model = session.query(TestModel).get(1)
+ test_model.set_participants_info(participants_info)
+ self.assertTrue(test_model.is_all_participants_authorized())
+
+ # test not all authorized
+ participants_info = project_pb2.ParticipantsInfo(
+ participants_map={
+ 'test_1': project_pb2.ParticipantInfo(auth_status=AuthStatus.PENDING.name),
+ 'test_2': project_pb2.ParticipantInfo(auth_status=AuthStatus.AUTHORIZED.name)
+ })
+ with db.session_scope() as session:
+ test_model = session.query(TestModel).get(1)
+ test_model.set_participants_info(participants_info)
+ self.assertFalse(test_model.is_all_participants_authorized())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_model.py b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_model.py
new file mode 100644
index 000000000..0b66d16f0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_model.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+import sqlalchemy as sa
+
+
+class TicketStatus(enum.Enum):
+ PENDING = 'PENDING'
+ APPROVED = 'APPROVED'
+ DECLINED = 'DECLINED'
+
+
+class ReviewTicketModel(object):
+
+ ticket_uuid = sa.Column(sa.String(255),
+ nullable=True,
+ comment='review ticket uuid, empty if review function is disabled')
+ ticket_status = sa.Column(sa.Enum(TicketStatus, length=32, native_enum=False, create_constraint=False),
+ default=TicketStatus.APPROVED,
+ server_default=TicketStatus.APPROVED.name,
+ comment='review ticket status')
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_model_test.py b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_model_test.py
new file mode 100644
index 000000000..54cfcc232
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/review_ticket_model_test.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.utils.base_model.review_ticket_model import ReviewTicketModel, TicketStatus
+from fedlearner_webconsole.db import db, default_table_args
+
+
+class TestModel(db.Model, ReviewTicketModel):
+ __tablename__ = 'test_model'
+ __table_args__ = (default_table_args('This is webconsole test_model table'))
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
+
+
+class ReviewTicketModelTest(NoWebServerTestCase):
+
+ def test_mixins(self):
+ with db.session_scope() as session:
+ model = TestModel(ticket_uuid='u1234', ticket_status=TicketStatus.APPROVED)
+ session.add(model)
+ session.commit()
+ with db.session_scope() as session:
+ models = session.query(TestModel).all()
+ self.assertEqual(len(models), 1)
+ self.assertEqual(models[0].ticket_status, TicketStatus.APPROVED)
+ self.assertEqual(models[0].ticket_uuid, 'u1234')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/base_model/softdelete_model.py b/web_console_v2/api/fedlearner_webconsole/utils/base_model/softdelete_model.py
new file mode 100644
index 000000000..c6563243e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/base_model/softdelete_model.py
@@ -0,0 +1,21 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from sqlalchemy import DateTime, Column
+
+
+class SoftDeleteModel(object):
+
+ deleted_at = Column(DateTime(timezone=True), comment='deleted time')
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/certificate.py b/web_console_v2/api/fedlearner_webconsole/utils/certificate.py
deleted file mode 100644
index 9d400e94f..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/certificate.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import os
-import json
-from base64 import b64encode
-
-from fedlearner_webconsole.utils.k8s_client import k8s_client
-
-
-def create_image_pull_secret():
- """Create certificate for image hub (Once for a system)"""
- image_hub_url = os.environ.get('IMAGE_HUB_URL')
- image_hub_username = os.environ.get('IMAGE_HUB_USERNAME')
- image_hub_password = os.environ.get('IMAGE_HUB_PASSWORD')
- if image_hub_url is None or image_hub_username is None or \
- image_hub_password is None:
- return
-
- # using base64 to encode authorization information
- encoded_username_password = str(b64encode(
- '{}:{}'.format(image_hub_username, image_hub_password)
- ))
- encoded_image_cert = str(b64encode(
- json.dumps({
- 'auths': {
- image_hub_url: {
- 'username': image_hub_username,
- 'password': image_hub_password,
- 'auth': encoded_username_password
- }
- }})), 'utf-8')
-
- k8s_client.create_or_update_secret(
- data={
- '.dockerconfigjson': encoded_image_cert
- },
- metadata={
- 'name': 'regcred',
- 'namespace': 'default'
- },
- secret_type='kubernetes.io/dockerconfigjson',
- name='regcred'
- )
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/const.py b/web_console_v2/api/fedlearner_webconsole/utils/const.py
new file mode 100644
index 000000000..994aa0d72
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/const.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+# API
+API_VERSION = '/api/v2'
+
+# Pagination
+DEFAULT_PAGE = 1
+DEFAULT_PAGE_SIZE = 50
+
+# names of preset data join workflow templates
+SYS_DATA_JOIN_TEMPLATE = [
+ # data join
+ 'sys-preset-data-join',
+ 'sys-preset-fe-data-join',
+ # psi data join
+ 'sys-preset-psi-data-join',
+ 'sys-preset-fe-psi-data-join',
+ # light client
+ 'sys-preset-light-psi-data-join',
+ # TODO(xiangyuxuan.prs): change psi job type from TRANSFORMER to PSI_DATA_JOIN when removing sys-preset-psi-data-join
+ # psi data join with analyzer
+ 'sys-preset-psi-data-join-analyzer',
+ 'sys-preset-converter-analyzer'
+]
+# names of preset model workflow templates
+SYS_PRESET_VERTICAL_NN_TEMPLATE = 'sys-preset-nn-model'
+SYS_PRESET_HORIZONTAL_NN_TEMPLATE = 'sys-preset-nn-horizontal-model'
+SYS_PRESET_HORIZONTAL_NN_EVAL_TEMPLATE = 'sys-preset-nn-horizontal-eval-model'
+SYS_PRESET_TREE_TEMPLATE = 'sys-preset-tree-model'
+
+SYS_PRESET_TEMPLATE = [
+ *SYS_DATA_JOIN_TEMPLATE, SYS_PRESET_VERTICAL_NN_TEMPLATE, SYS_PRESET_HORIZONTAL_NN_TEMPLATE,
+ SYS_PRESET_HORIZONTAL_NN_EVAL_TEMPLATE, SYS_PRESET_TREE_TEMPLATE
+]
+
+# dataset
+DATASET_PREVIEW_NUM = 20
+
+DEFAULT_OWNER = 'no___user'
+
+DEFAULT_OWNER_FOR_JOB_WITHOUT_WORKFLOW = 'no___workflow'
+
+# auth related
+SIGN_IN_INTERVAL_SECONDS = 1800
+MAX_SIGN_IN_ATTEMPTS = 3
+
+SYSTEM_WORKFLOW_CREATOR_USERNAME = 's_y_s_t_e_m'
+
+SSO_HEADER = 'x-pc-auth'
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators.py b/web_console_v2/api/fedlearner_webconsole/utils/decorators.py
deleted file mode 100644
index 3f0a0aee2..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/decorators.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-
-# coding=utf-8
-
-import logging
-from functools import wraps
-from traceback import format_exc
-import flask_jwt_extended
-from flask_jwt_extended.utils import get_current_user
-from fedlearner_webconsole.auth.models import Role
-from fedlearner_webconsole.exceptions import UnauthorizedException
-from envs import Envs
-
-
-def admin_required(f):
- @wraps(f)
- def wrapper_inside(*args, **kwargs):
- current_user = get_current_user()
- if current_user.role != Role.ADMIN:
- raise UnauthorizedException('only admin can operate this')
- return f(*args, **kwargs)
- return wrapper_inside
-
-
-def jwt_required(*jwt_args, **jwt_kwargs):
- def decorator(f):
- if Envs.DEBUG:
- @wraps(f)
- def wrapper(*args, **kwargs):
- return f(*args, **kwargs)
- else:
- wrapper = flask_jwt_extended.jwt_required(
- *jwt_args, **jwt_kwargs)(f)
- return wrapper
- return decorator
-
-
-def retry_fn(retry_times: int = 3, needed_exceptions=None):
- def decorator_retry_fn(f):
- # to resolve pylint warning
- # Dangerous default value [] as argument (dangerous-default-value)
- nonlocal needed_exceptions
- if needed_exceptions is None:
- needed_exceptions = [Exception]
-
- @wraps(f)
- def wrapper(*args, **kwargs):
- for i in range(retry_times):
- try:
- return f(*args, **kwargs)
- except tuple(needed_exceptions):
- logging.error('Call function failed, retrying %s times...',
- i + 1)
- logging.error('Exceptions:\n%s', format_exc())
- logging.error(
- 'function name is %s, args are %s, kwargs are %s',
- f.__name__, repr(args), repr(kwargs))
- if i == retry_times - 1:
- raise
- continue
-
- return wrapper
-
- return decorator_retry_fn
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/utils/decorators/BUILD.bazel
new file mode 100644
index 000000000..57d66a3b4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/decorators/BUILD.bazel
@@ -0,0 +1,62 @@
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "decorators_lib",
+ srcs = [
+ "lru_cache.py",
+ "pp_flask.py",
+ "retry.py",
+ ],
+ imports = ["../../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "@common_flask//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_webargs//:pkg",
+ ],
+)
+
+py_test(
+ name = "lru_cache_test",
+ srcs = [
+ "lru_cache_test.py",
+ ],
+ imports = ["../../.."],
+ main = "lru_cache_test.py",
+ deps = [
+ ":decorators_lib",
+ ],
+)
+
+py_test(
+ name = "pp_flask_test",
+ size = "medium",
+ srcs = [
+ "pp_flask_test.py",
+ ],
+ imports = ["../../.."],
+ main = "pp_flask_test.py",
+ deps = [
+ ":decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/testing:common_lib",
+ "@common_flask//:pkg",
+ "@common_marshmallow//:pkg",
+ ],
+)
+
+py_test(
+ name = "retry_test",
+ srcs = [
+ "retry_test.py",
+ ],
+ imports = ["../../.."],
+ main = "retry_test.py",
+ deps = [
+ ":decorators_lib",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators/lru_cache.py b/web_console_v2/api/fedlearner_webconsole/utils/decorators/lru_cache.py
new file mode 100644
index 000000000..bfd92ef03
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/decorators/lru_cache.py
@@ -0,0 +1,49 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+from datetime import datetime, timedelta
+
+
+# TODO(xiangyuxuan): use custom lru to implement cache_clear(key)
+def lru_cache(timeout: int = 600, maxsize: int = 10000):
+ """Extension of functools lru_cache with a timeout.
+
+ Note: Do not use this decorator on instance methods, or it will leak memory.
+ https://stackoverflow.com/questions/1227121/
+ compare-object-instances-for-equality-by-their-attributes
+
+ Args:
+ timeout (int): Timeout in seconds after which the WHOLE cache is cleared, default is 10 minutes.
+ maxsize (int): Maximum size of the cache.
+ """
+
+ def wrapper_cache(func):
+ func = functools.lru_cache(maxsize=maxsize)(func)
+ func.delta = timedelta(seconds=timeout)
+ func.expiration = datetime.utcnow() + func.delta
+
+ @functools.wraps(func)
+ def wrapped_func(*args, **kwargs):
+ if datetime.utcnow() >= func.expiration:
+ func.cache_clear()
+ func.expiration = datetime.utcnow() + func.delta
+
+ return func(*args, **kwargs)
+
+ wrapped_func.cache_clear = func.cache_clear
+ return wrapped_func
+
+ return wrapper_cache
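A hypothetical usage sketch of the decorator above (the function name and body are illustrative, not part of the patch):

```python
from fedlearner_webconsole.utils.decorators.lru_cache import lru_cache


@lru_cache(timeout=60, maxsize=128)
def read_setting(key: str) -> str:
    # Expensive lookup; repeated calls with the same key are served from the
    # cache until 60 seconds elapse, after which the whole cache is cleared.
    ...
```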
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators/lru_cache_test.py b/web_console_v2/api/fedlearner_webconsole/utils/decorators/lru_cache_test.py
new file mode 100644
index 000000000..0776b275c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/decorators/lru_cache_test.py
@@ -0,0 +1,57 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import unittest
+
+from fedlearner_webconsole.utils.decorators.lru_cache import lru_cache
+
+
+class LruCacheTest(unittest.TestCase):
+
+ def test_lru_cache(self):
+ count = 0
+ count2 = 0
+
+ @lru_cache(timeout=1)
+ def test(arg1):
+ nonlocal count
+ count += 1
+ return count
+
+ @lru_cache(timeout=10)
+ def test_another(arg2):
+ nonlocal count2
+ count2 += 1
+ return count2
+
+ self.assertEqual(test(1), 1)
+ self.assertEqual(test(1), 1)
+
+ self.assertEqual(test(-1), 2)
+ self.assertEqual(test(-1), 2)
+
+ self.assertEqual(test_another(1), 1)
+ self.assertEqual(test_another(1), 1)
+
+ # test cache expired
+ time.sleep(1)
+ self.assertEqual(test(1), 3)
+ self.assertEqual(test(-1), 4)
+ self.assertEqual(test_another(1), 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators/pp_flask.py b/web_console_v2/api/fedlearner_webconsole/utils/decorators/pp_flask.py
new file mode 100644
index 000000000..214d11bda
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/decorators/pp_flask.py
@@ -0,0 +1,91 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import re
+from functools import wraps
+from flask import request
+from marshmallow import EXCLUDE
+from webargs.flaskparser import FlaskParser
+
+from fedlearner_webconsole.utils.flask_utils import get_current_user
+from fedlearner_webconsole.auth.models import Role
+from fedlearner_webconsole.exceptions import InvalidArgumentException, \
+ UnauthorizedException
+
+# TODO(xiangyuxuan.prs): validate Kubernetes object names with their own regex
+# [DNS Subdomain Names](https://kubernetes.io/docs/concepts/overview/working-with-objects/names/)
+# Regex to match the pattern:
+# Start/end with English/Chinese characters or numbers
+# other characters may be English/Chinese characters, numbers, -, _ or .
+# Max length 64
+UNIVERSAL_NAME_PATTERN = r'^[a-zA-Z0-9\u4e00-\u9fa5]' \
+ r'[a-zA-Z0-9\u4e00-\u9fa5\-_\.]' \
+ r'{0,62}[a-zA-Z0-9\u4e00-\u9fa5]$'
+MAX_COMMENT_LENGTH = 200
+
+
+def admin_required(f):
+
+ @wraps(f)
+ def wrapper_inside(*args, **kwargs):
+ current_user = get_current_user()
+ if current_user.role != Role.ADMIN:
+ raise UnauthorizedException('only admin can operate this')
+ return f(*args, **kwargs)
+
+ return wrapper_inside
+
+
+def input_validator(f):
+
+ @wraps(f)
+ def wrapper_inside(*args, **kwargs):
+ if hasattr(request, 'content_type') and request.content_type.startswith('multipart/form-data'):
+ params = request.form
+ else:
+ params = request.get_json() or {}
+ name = params.get('name', None)
+ comment = params.get('comment', '')
+ if name is not None:
+ _validate_name(name)
+ if comment:
+ _validate_comment(comment)
+ return f(*args, **kwargs)
+
+ return wrapper_inside
+
+
+def _validate_name(name: str):
+ if re.match(UNIVERSAL_NAME_PATTERN, name) is None:
+ raise InvalidArgumentException(f'Invalid name {name}: must start/end'
+ f' with letters, numbers or Chinese'
+ f' characters, may contain -, _ or .'
+ f' in the middle, and max length is'
+ f' 64 characters.')
+
+
+def _validate_comment(comment: str):
+ if len(comment) > MAX_COMMENT_LENGTH:
+ raise InvalidArgumentException(f'Input comment too long, max length is {MAX_COMMENT_LENGTH}')
+
+
+# Ref: https://webargs.readthedocs.io/en/latest/advanced.html#default-unknown
+class _Parser(FlaskParser):
+ DEFAULT_UNKNOWN_BY_LOCATION = {'query': EXCLUDE, 'json': EXCLUDE, 'form': EXCLUDE}
+
+
+parser = _Parser()
+use_args = parser.use_args
+use_kwargs = parser.use_kwargs
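A hypothetical route showing how these decorators compose (the endpoint and schema are illustrative, not part of the patch):

```python
from flask import Flask
from marshmallow import fields

from fedlearner_webconsole.utils.decorators.pp_flask import input_validator, use_kwargs

app = Flask(__name__)


@app.route('/projects', methods=['POST'])
@input_validator  # rejects names violating UNIVERSAL_NAME_PATTERN and comments over 200 chars
@use_kwargs({'name': fields.String(required=True)}, location='json')
def create_project(name: str):
    # A payload like {"name": "my_project.v1"} passes; {"name": "???bad"}
    # raises InvalidArgumentException before this body runs.
    return {'name': name}
```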
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators/pp_flask_test.py b/web_console_v2/api/fedlearner_webconsole/utils/decorators/pp_flask_test.py
new file mode 100644
index 000000000..953548d40
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/decorators/pp_flask_test.py
@@ -0,0 +1,119 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from http import HTTPStatus
+
+import flask
+import unittest
+from unittest.mock import patch
+
+from marshmallow import fields
+
+from fedlearner_webconsole.auth.models import User, Role
+from fedlearner_webconsole.utils.decorators.pp_flask import admin_required, input_validator, use_args, use_kwargs
+from fedlearner_webconsole.exceptions import InvalidArgumentException, UnauthorizedException
+from fedlearner_webconsole.utils.flask_utils import make_flask_response
+from testing.common import BaseTestCase
+
+
+@admin_required
+def some_authorized_login():
+ return 1
+
+
+@input_validator
+def test_func():
+ return 1
+
+
+class FlaskTest(unittest.TestCase):
+
+ @staticmethod
+ def generator_helper(inject_res):
+ for r in inject_res:
+ yield r
+
+ @patch('fedlearner_webconsole.utils.decorators.pp_flask.get_current_user')
+ def test_admin_required(self, mock_get_current_user):
+ admin = User(id=0, username='admin', password='admin', role=Role.ADMIN)
+ user = User(id=1, username='ada', password='ada', role=Role.USER)
+ mock_get_current_user.return_value = admin
+ self.assertTrue(some_authorized_login() == 1)
+
+ mock_get_current_user.return_value = user
+ self.assertRaises(UnauthorizedException, some_authorized_login)
+
+ def test_input_validator(self):
+ app = flask.Flask(__name__)
+ with app.test_request_context('/', json={'name': 'valid_name', 'comment': 'valid comment'}):
+ self.assertTrue(test_func() == 1)
+ with app.test_request_context('/', json={'name': '', 'comment': 'valid comment'}):
+ self.assertRaises(InvalidArgumentException, test_func)
+ with app.test_request_context('/', json={'name': '???invalid_name', 'comment': 'valid comment'}):
+ self.assertRaises(InvalidArgumentException, test_func)
+ with app.test_request_context('/', json={'name': 'a' * 65, 'comment': 'valid comment'}):
+ self.assertRaises(InvalidArgumentException, test_func)
+ with app.test_request_context('/', json={'name': 'valid_name', 'comment': 'a' * 201}):
+ self.assertRaises(InvalidArgumentException, test_func)
+ with app.test_request_context('/', json={'name': 'valid_name'}):
+ self.assertTrue(test_func() == 1)
+ with app.test_request_context('/', json={'unrelated': '??'}):
+ self.assertTrue(test_func() == 1)
+ with app.test_request_context('/', json={'name': 'valid_name.test'}):
+ self.assertTrue(test_func() == 1)
+
+
+class ParserTest(BaseTestCase):
+
+ def test_unknown_query(self):
+
+ @self.app.route('/hello')
+ @use_args({'msg': fields.String(required=True)}, location='query')
+ def test_route(params):
+ return make_flask_response({'msg': params['msg']})
+
+ resp = self.get_helper('/hello?msg=123&unknown=fff', use_auth=False)
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'msg': '123'})
+
+ def test_unknown_body(self):
+
+ @self.app.route('/test_create', methods=['POST'])
+ @use_kwargs({'msg': fields.String(required=True)}, location='json')
+ def test_route(msg: str):
+ return make_flask_response({'msg': msg})
+
+ resp = self.post_helper('/test_create?ufj=4', use_auth=False, data={
+ 'msg': 'hello',
+ 'unknown': 'fasdf',
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ self.assertResponseDataEqual(resp, {'msg': 'hello'})
+
+ def test_invalid_parameter(self):
+
+ @self.app.route('/test', methods=['POST'])
+ @use_kwargs({'n': fields.Integer(required=True)}, location='json')
+ def test_route(n: int):
+ return make_flask_response({'n': n})
+
+ resp = self.post_helper('/test', use_auth=False, data={
+ 'n': 'hello',
+ })
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators/retry.py b/web_console_v2/api/fedlearner_webconsole/utils/decorators/retry.py
new file mode 100644
index 000000000..6252f013b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/decorators/retry.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Optional, Callable
+
+import time
+from functools import wraps
+
+
+def _default_need_retry(unused_exception: Exception) -> bool:
+ del unused_exception
+ # Retry for all exceptions
+ return True
+
+
+def retry_fn(retry_times: int = 3,
+ delay: int = 0,
+ backoff: float = 1.0,
+ need_retry: Optional[Callable[[Exception], bool]] = None):
+ """A function to generate a decorator for retry.
+
+ Args:
+ retry_times: Times to try.
+ delay: Intervals in milliseconds between attempts, default is 0 (no delay).
+ backoff: Multiplier on the delay between attempts, default is 1. For example, if delay is set to 1000,
+ and backoff is 2, then the first retry will be delayed 1000ms, and second one will be 2000ms,
+ third one will be 4000ms.
+ need_retry: A callable function to check if the raised exception will trigger retry or not.
+ """
+
+ def decorator_retry_fn(f):
+ nonlocal need_retry
+ if need_retry is None:
+ # By default retry for all exceptions
+ need_retry = _default_need_retry
+
+ @wraps(f)
+ def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
+ nonlocal delay
+ # NOTE: a local variable for delay is a MUST-HAVE, if you reuse the delay in the parameter,
+ # it will be accumulated for all function executions.
+ local_delay = delay
+ for i in range(retry_times):
+ try:
+ return f(*args, **kwargs)
+ except Exception as e: # pylint: disable=broad-except
+ # Re-raise if there is no need for retry
+ if not need_retry(e):
+ raise
+ logging.exception(
+ f'Call function failed, retrying {i + 1} times...\n'
+ f'function name is {f.__name__}, args are {repr(args)}, kwargs are {repr(kwargs)}')
+ if i == retry_times - 1:
+ raise
+ if local_delay > 0:
+ time.sleep(local_delay / 1000.0)
+ local_delay = local_delay * backoff
+
+ return wrapper
+
+ return decorator_retry_fn
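A hypothetical example of the delay/backoff behaviour described in the docstring (the function and exception type are illustrative):

```python
from fedlearner_webconsole.utils.decorators.retry import retry_fn


@retry_fn(retry_times=4, delay=500, backoff=2,
          need_retry=lambda e: isinstance(e, TimeoutError))
def pull_status():
    # Up to 4 attempts in total; failed attempts sleep 500ms, then 1000ms,
    # then 2000ms before retrying, and only TimeoutError triggers a retry.
    ...
```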
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/decorators/retry_test.py b/web_console_v2/api/fedlearner_webconsole/utils/decorators/retry_test.py
new file mode 100644
index 000000000..5aa751f19
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/decorators/retry_test.py
@@ -0,0 +1,105 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch, Mock
+
+from fedlearner_webconsole.utils.decorators.retry import retry_fn
+
+
+class RpcError(Exception):
+
+ def __init__(self, status: int = 0):
+ super().__init__()
+ self.status = status
+
+
+def some_unstable_connect(grpc_call):
+ res = grpc_call()
+ if res['status'] != 0:
+ raise RpcError(res['status'])
+ return res['data']
+
+
+class RetryTest(unittest.TestCase):
+
+ def test_retry_fn(self):
+
+ @retry_fn(retry_times=2, need_retry=lambda e: isinstance(e, RpcError))
+ def retry_twice(grpc_call):
+ return some_unstable_connect(grpc_call)
+
+ grpc_call = MagicMock()
+ grpc_call.side_effect = [{'status': -1, 'data': 'hhhhhh'}, {'status': -1, 'data': 'hhhh'}]
+ with self.assertRaises(RpcError):
+ retry_twice(grpc_call=grpc_call)
+
+ grpc_call = MagicMock()
+ grpc_call.side_effect = [{'status': -1, 'data': 'hhhhhh'}, {'status': 0, 'data': 'hhhh'}]
+ self.assertEqual(retry_twice(grpc_call=grpc_call), 'hhhh')
+
+ @patch('fedlearner_webconsole.utils.decorators.retry.time.sleep')
+ def test_retry_fn_with_delay(self, mock_sleep: Mock):
+ sleep_time = 0
+
+ def fake_sleep(s):
+ nonlocal sleep_time
+ sleep_time = sleep_time + s
+
+ mock_sleep.side_effect = fake_sleep
+
+ @retry_fn(retry_times=5, delay=1000, backoff=2)
+ def retry_with_delay(grpc_call):
+ return some_unstable_connect(grpc_call)
+
+ grpc_call = MagicMock()
+ grpc_call.return_value = {'status': 0, 'data': '123'}
+ self.assertEqual(retry_with_delay(grpc_call), '123')
+ mock_sleep.assert_not_called()
+
+ grpc_call = MagicMock()
+ grpc_call.side_effect = [{'status': 255}, {'status': -1}, {'status': 2}, {'status': 0, 'data': '123'}]
+ self.assertEqual(retry_with_delay(grpc_call), '123')
+ self.assertEqual(mock_sleep.call_count, 3)
+ # 1 + 2 + 4
+ self.assertEqual(sleep_time, 7)
+
+ # Failed case
+ sleep_time = 0
+ mock_sleep.reset_mock()
+ grpc_call = MagicMock()
+ grpc_call.side_effect = RuntimeError()
+ with self.assertRaises(RuntimeError):
+ retry_with_delay(grpc_call=grpc_call)
+ self.assertEqual(mock_sleep.call_count, 4)
+ # 1 + 2 + 4 + 8
+ self.assertEqual(sleep_time, 15)
+
+ def test_retry_fn_with_need_retry(self):
+
+ @retry_fn(retry_times=10, need_retry=lambda e: e.status == 3)
+ def custom_retry(grpc_call):
+ return some_unstable_connect(grpc_call)
+
+ grpc_call = MagicMock()
+ grpc_call.side_effect = [{'status': 3}, {'status': 3}, {'status': 5}, {'status': 6}]
+ with self.assertRaises(RpcError):
+ custom_retry(grpc_call)
+ # When status is 5, it will not retry again.
+ self.assertEqual(grpc_call.call_count, 3)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/domain_name.py b/web_console_v2/api/fedlearner_webconsole/utils/domain_name.py
new file mode 100644
index 000000000..c572c72b1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/domain_name.py
@@ -0,0 +1,33 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import re
+from typing import Optional
+
+
+def get_pure_domain_name(common_name: str) -> Optional[str]:
+ """Get domain name from common name filed in x.509
+
+ Args:
+ common_name (str): common name parsed from x.509
+
+ Returns:
+ str: pure domain name, e.g. 'bytedance' or 'bytedance-test'
+ """
+ for regex in [r'.*fl-([^\.]+)(\.com)?', r'(.+)\.fedlearner\.net']:
+ matched = re.match(regex, common_name)
+ if matched:
+ return matched.group(1)
+ return None
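Expected behaviour of the helper, matching the test that follows:

```python
from fedlearner_webconsole.utils.domain_name import get_pure_domain_name

assert get_pure_domain_name('fl-bytedance.com') == 'bytedance'
assert get_pure_domain_name('bytedance.fedlearner.net') == 'bytedance'
assert get_pure_domain_name('bytedancefedlearner.net') is None
```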
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/domain_name_test.py b/web_console_v2/api/fedlearner_webconsole/utils/domain_name_test.py
new file mode 100644
index 000000000..6e652b54c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/domain_name_test.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.utils.domain_name import get_pure_domain_name
+
+
+class DomainNameTest(unittest.TestCase):
+
+ def test_get_pure_domain_name(self):
+ self.assertEqual(get_pure_domain_name('*.fl-bytedance.com'), 'bytedance')
+ self.assertEqual(get_pure_domain_name('fl-bytedance.com'), 'bytedance')
+ self.assertEqual(get_pure_domain_name('fl-bytedance-test'), 'bytedance-test')
+ self.assertEqual(get_pure_domain_name('bytedance.fedlearner.net'), 'bytedance')
+ self.assertIsNone(get_pure_domain_name('bytedancefedlearner.net'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/es.py b/web_console_v2/api/fedlearner_webconsole/utils/es.py
index 058d7fbf6..3e90ea39c 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/es.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/es.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
# coding: utf-8
# pylint: disable=invalid-string-quote
import json
+import time
+from typing import Dict, List, Optional
from elasticsearch import Elasticsearch
@@ -26,26 +28,45 @@ class ElasticSearchClient(object):
def __init__(self):
self._es_client = None
self._es_client = Elasticsearch([{
- 'host': Envs.ES_READ_HOST or Envs.ES_HOST,
+ 'host': Envs.ES_HOST,
'port': Envs.ES_PORT
}],
- http_auth=(Envs.ES_USERNAME,
- Envs.ES_PASSWORD))
+ http_auth=(Envs.ES_USERNAME, Envs.ES_PASSWORD),
+ timeout=10000)
def search(self, *args, **kwargs):
return self._es_client.search(*args, **kwargs)
def query_log(self,
- index,
- keyword,
- pod_name,
- start_time,
- end_time,
- match_phrase=None):
+ index: str,
+ keyword: str,
+ pod_name: str,
+ start_time: int = 0,
+ end_time: Optional[int] = None,
+ match_phrase: Optional[Dict[str, str]] = None) -> List[str]:
+ """query log from es
+
+ Args:
+ index (str): the ES index that you want to search in
+ keyword (str): keyword to filter log messages by
+ pod_name (str): the pod that you want to query
+ start_time (int, optional): start time of the search range in milliseconds. Defaults to 0.
+ end_time (int, optional): end time of the search range in milliseconds. Defaults to None (now).
+ match_phrase (Dict[str, str], optional): match phrase. Defaults to None.
+
+ Returns:
+ List[str]: list of log messages, one per line
+ """
+ end_time = end_time or int(time.time() * 1000)
query_body = {
'version': True,
'size': 8000,
'sort': [{
+ 'log.nanostimestamp': {
+ 'order': 'desc',
+ 'unmapped_type': 'long'
+ }
+ }, {
'@timestamp': 'desc'
}, {
'log.offset': {
@@ -66,7 +87,7 @@ def query_log(self,
'query': keyword,
'analyze_wildcard': True,
'default_operator': 'AND',
- 'default_field': '*'
+ 'default_field': 'message'
}
}] if keyword else []
match_phrase_list = [
@@ -88,17 +109,16 @@ def query_log(self,
response = self._es_client.search(index=index, body=query_body)
return [item['_source']['message'] for item in response['hits']['hits']]
- def query_events(self,
- index,
- keyword,
- pod_name,
- start_time,
- end_time,
- match_phrase=None):
+ def query_events(self, index, keyword, pod_name, start_time, end_time, match_phrase=None):
query_body = {
'version': True,
'size': 8000,
'sort': [{
+ 'log.nanostimestamp': {
+ 'order': 'desc',
+ 'unmapped_type': 'long'
+ }
+ }, {
'@timestamp': 'desc'
}, {
'log.offset': {
@@ -119,7 +139,7 @@ def query_events(self,
'query': f'{keyword} AND Event',
'analyze_wildcard': True,
'default_operator': 'AND',
- 'default_field': '*'
+ 'default_field': 'message'
}
}] if keyword else []
match_phrase_list = [
@@ -141,11 +161,7 @@ def query_events(self,
response = self._es_client.search(index=index, body=query_body)
return [item['_source']['message'] for item in response['hits']['hits']]
- def put_ilm(self,
- ilm_name,
- hot_size='50gb',
- hot_age='10d',
- delete_age='30d'):
+ def put_ilm(self, ilm_name, hot_size='50gb', hot_age='10d', delete_age='30d'):
if self._es_client is None:
raise RuntimeError('ES client not yet initialized.')
ilm_body = {
@@ -264,43 +280,50 @@ def query_data_join_metrics(self, job_name, num_buckets):
}
}
}
-
return es.search(index='data_join*', body=query)
- def query_nn_metrics(self, job_name, num_buckets):
+ def query_nn_metrics(self, job_name: str, metric_list: List[str], num_buckets: int = 30):
query = {
- "size": 0,
- "query": {
- "bool": {
- "must": [{
- "term": {
- "tags.application_id": job_name
+ 'size': 0,
+ 'query': {
+ 'bool': {
+ 'must': [{
+ 'term': {
+ 'tags.application_id': job_name
}
}, {
- "term": {
- "name": "auc"
+ 'terms': {
+ 'name': metric_list
}
}]
}
},
- "aggs": {
- "PROCESS_TIME": {
- "auto_date_histogram": {
- "field": "tags.process_time",
- "format": "strict_date_optional_time",
- "buckets": num_buckets
+ 'aggs': {
+ metric: {
+ 'filter': {
+ 'term': {
+ 'name': metric
+ }
},
- "aggs": {
- "AUC": {
- "avg": {
- "field": "value"
+ 'aggs': {
+ 'PROCESS_TIME': {
+ 'auto_date_histogram': {
+ 'field': 'tags.process_time',
+ 'format': 'strict_date_optional_time',
+ 'buckets': num_buckets
+ },
+ 'aggs': {
+ 'VALUE': {
+ 'avg': {
+ 'field': 'value'
+ }
+ }
}
}
}
- }
+ } for metric in metric_list
}
}
-
return es.search(index='metrics*', body=query)
def query_tree_metrics(self, job_name, metric_list):
@@ -340,7 +363,9 @@ def query_tree_metrics(self, job_name, metric_list):
"TOP": {
"top_hits": {
"size": 100,
- "sort": [{"tags.process_time": "asc"}],
+ "sort": [{
+ "tags.process_time": "asc"
+ }],
"_source": ["value", "tags.iteration"]
}
}
@@ -350,8 +375,7 @@ def query_tree_metrics(self, job_name, metric_list):
} for metric in metric_list
}
}
- response = es.search(index='metrics*', body=query)
- return response['aggregations']
+ return es.search(index='metrics*', body=query)
def query_time_metrics(self, job_name, num_buckets, index='raw_data*'):
query = {
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/es_misc.py b/web_console_v2/api/fedlearner_webconsole/utils/es_misc.py
index 87861097a..2915c308f 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/es_misc.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/es_misc.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
@@ -19,16 +19,14 @@
_es_datetime_format = 'strict_date_optional_time'
RAW_DATA_MAPPINGS = {
'dynamic': True,
- 'dynamic_templates': [
- {
- 'strings': {
- 'match_mapping_type': 'string',
- 'mapping': {
- 'type': 'keyword'
- }
+ 'dynamic_templates': [{
+ 'strings': {
+ 'match_mapping_type': 'string',
+ 'mapping': {
+ 'type': 'keyword'
}
}
- ],
+ }],
'properties': {
'tags': {
'properties': {
@@ -54,16 +52,14 @@
DATA_JOIN_MAPPINGS = {
'dynamic': True,
# for dynamically adding string fields, use keyword to reduce space
- 'dynamic_templates': [
- {
- 'strings': {
- 'match_mapping_type': 'string',
- 'mapping': {
- 'type': 'keyword'
- }
+ 'dynamic_templates': [{
+ 'strings': {
+ 'match_mapping_type': 'string',
+ 'mapping': {
+ 'type': 'keyword'
}
}
- ],
+ }],
'properties': {
'tags': {
'properties': {
@@ -105,16 +101,14 @@
}
METRICS_MAPPINGS = {
'dynamic': True,
- 'dynamic_templates': [
- {
- 'strings': {
- 'match_mapping_type': 'string',
- 'mapping': {
- 'type': 'keyword'
- }
+ 'dynamic_templates': [{
+ 'strings': {
+ 'match_mapping_type': 'string',
+ 'mapping': {
+ 'type': 'keyword'
}
}
- ],
+ }],
'properties': {
'name': {
'type': 'keyword'
@@ -155,33 +149,33 @@
}
}
}
-ALIAS_NAME = {'metrics': 'metrics_v2',
- 'raw_data': 'raw_data',
- 'data_join': 'data_join'}
-INDEX_MAP = {'metrics': METRICS_MAPPINGS,
- 'raw_data': RAW_DATA_MAPPINGS,
- 'data_join': DATA_JOIN_MAPPINGS}
+ALIAS_NAME = {'metrics': 'metrics_v2', 'raw_data': 'raw_data', 'data_join': 'data_join'}
+INDEX_MAP = {'metrics': METRICS_MAPPINGS, 'raw_data': RAW_DATA_MAPPINGS, 'data_join': DATA_JOIN_MAPPINGS}
def get_es_template(index_type, shards):
assert index_type in ALIAS_NAME
alias_name = ALIAS_NAME[index_type]
- template = {'index_patterns': ['{}-*'.format(alias_name)],
- 'settings': {
- 'index': {
- 'lifecycle': {
- 'name': 'fedlearner_{}_ilm'.format(index_type),
- 'rollover_alias': alias_name
- },
- 'codec': 'best_compression',
- 'routing': {
- 'allocation': {
- 'total_shards_per_node': '1'
- }
- },
- 'number_of_shards': str(shards),
- 'number_of_replicas': '1',
+ # pylint: disable=consider-using-f-string
+ template = {
+ 'index_patterns': ['{}-*'.format(alias_name)],
+ 'settings': {
+ 'index': {
+ 'lifecycle': {
+ 'name': 'fedlearner_{}_ilm'.format(index_type),
+ 'rollover_alias': alias_name
+ },
+ 'codec': 'best_compression',
+ 'routing': {
+ 'allocation': {
+ 'total_shards_per_node': '1'
}
},
- 'mappings': INDEX_MAP[index_type]}
+ 'number_of_shards': str(shards),
+ 'number_of_replicas': '1',
+ }
+ },
+ 'mappings': INDEX_MAP[index_type]
+ }
+ # pylint: enable=consider-using-f-string
return template
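For context, a hypothetical call to the reformatted helper above (the index type must be one of ALIAS_NAME's keys):

```python
from fedlearner_webconsole.utils.es_misc import get_es_template

template = get_es_template('metrics', shards=2)
# The template targets indices named 'metrics_v2-*', rolls them over via the
# 'fedlearner_metrics_ilm' lifecycle policy and applies METRICS_MAPPINGS.
assert template['index_patterns'] == ['metrics_v2-*']
```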
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/fake_k8s_client.py b/web_console_v2/api/fedlearner_webconsole/utils/fake_k8s_client.py
deleted file mode 100644
index 1c24824be..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/fake_k8s_client.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-# pylint: disable=logging-format-interpolation
-import logging
-import datetime
-from kubernetes import client
-
-_RAISE_EXCEPTION_KEY = 'raise_exception'
-
-
-class FakeK8sClient(object):
- """A fake k8s client for development.
-
- With this client we can decouple the dependency of k8s cluster.
- """
- def close(self):
- pass
-
- def create_or_update_secret(self,
- data,
- metadata,
- secret_type,
- name,
- namespace='default'):
- # User may pass two type of data:
- # 1. dictionary
- # 2. K8s Object
- # They are both accepted by real K8s client,
- # but K8s Object is not iterable.
- if isinstance(data, dict) and _RAISE_EXCEPTION_KEY in data:
- raise RuntimeError('[500] Fake exception for save_secret')
- # Otherwise succeeds
- logging.info('======================')
- logging.info('Saved a secret with: data: {}, '
- 'metadata: {}, type: {}'.format(data, metadata,
- secret_type))
-
- def delete_secret(self, name, namespace='default'):
- logging.info('======================')
- logging.info('Deleted a secret with: name: {}'.format(name))
-
- def get_secret(self, name, namespace='default'):
- return client.V1Secret(api_version='v1',
- data={'test': 'test'},
- kind='Secret',
- metadata={
- 'name': name,
- 'namespace': namespace
- },
- type='Opaque')
-
- def create_or_update_service(self,
- metadata,
- spec,
- name,
- namespace='default'):
- logging.info('======================')
- logging.info('Saved a service with: spec: {}, metadata: {}'.format(
- spec, metadata))
-
- def delete_service(self, name, namespace='default'):
- logging.info('======================')
- logging.info('Deleted a service with: name: {}'.format(name))
-
- def get_service(self, name, namespace='default'):
- return client.V1Service(
- api_version='v1',
- kind='Service',
- metadata=client.V1ObjectMeta(name=name, namespace=namespace),
- spec=client.V1ServiceSpec(selector={'app': 'nginx'}))
-
- def create_or_update_ingress(self,
- metadata,
- spec,
- name,
- namespace='default'):
- logging.info('======================')
- logging.info('Saved a ingress with: spec: {}, metadata: {}'.format(
- spec, metadata))
-
- def delete_ingress(self, name, namespace='default'):
- logging.info('======================')
- logging.info('Deleted a ingress with: name: {}'.format(name))
-
- def get_ingress(self, name, namespace='default'):
- return client.NetworkingV1beta1Ingress(
- api_version='networking.k8s.io/v1beta1',
- kind='Ingress',
- metadata=client.V1ObjectMeta(name=name, namespace=namespace),
- spec=client.NetworkingV1beta1IngressSpec())
-
- def create_or_update_deployment(self,
- metadata,
- spec,
- name,
- namespace='default'):
- logging.info('======================')
- logging.info('Saved a deployment with: spec: {}, metadata: {}'.format(
- spec, metadata))
-
- def delete_deployment(self, name, namespace='default'):
- logging.info('======================')
- logging.info('Deleted a deployment with: name: {}'.format(name))
-
- def get_deployment(self, name, namespace='default'):
- return client.V1Deployment(
- api_version='apps/v1',
- kind='Deployment',
- metadata=client.V1ObjectMeta(name=name, namespace=namespace),
- spec=client.V1DeploymentSpec(
- selector={'matchLabels': {
- 'app': 'fedlearner-operator'
- }},
- template=client.V1PodTemplateSpec(spec=client.V1PodSpec(
- containers=[
- client.V1Container(name='fedlearner-operator',
- args=['test'])
- ]))))
-
- def delete_flapp(self, flapp_name):
- pass
-
- def create_flapp(self, flapp_yaml):
- pass
-
- def get_flapp(self, flapp_name):
- pods = {
- 'pods': {
- 'metadata': {
- 'selfLink': '/api/v1/namespaces/default/pods',
- 'resourceVersion': '780480990'
- }
- },
- 'items': [{
- 'metadata': {
- 'name': '{}-0'.format(flapp_name)
- }
- }, {
- 'metadata': {
- 'name': '{}-1'.format(flapp_name)
- }
- }]
- }
- flapp = {
- 'kind': 'FLAPP',
- 'metadata': {
- 'name': flapp_name,
- 'namesapce': 'default'
- },
- 'status': {
- 'appState': 'FLStateRunning',
- 'flReplicaStatus': {
- 'Master': {
- 'active': {
- 'laomiao-raw-data-1223-v1-follower'
- '-master-0-717b53c4-'
- 'fef7-4d65-a309-63cf62494286': {}
- }
- },
- 'Worker': {
- 'active': {
- 'laomiao-raw-data-1223-v1-follower'
- '-worker-0-61e49961-'
- 'e6dd-4015-a246-b6d25e69a61c': {},
- 'laomiao-raw-data-1223-v1-follower'
- '-worker-1-accef16a-'
- '317f-440f-8f3f-7dd5b3552d25': {}
- }
- }
- }
- }
- }
- return {'flapp': flapp, 'pods': pods}
-
- def get_webshell_session(self,
- flapp_name,
- container_name: str,
- namespace='default'):
- return {'id': 1}
-
- def get_sparkapplication(self,
- name: str,
- namespace: str = 'default') -> dict:
- logging.info('======================')
- logging.info(
- f'get spark application, name: {name}, namespace: {namespace}')
- return {
- 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
- 'kind': 'SparkApplication',
- 'metadata': {
- 'creationTimestamp': '2021-04-15T10:43:15Z',
- 'generation': 1,
- 'name': name,
- 'namespace': namespace,
- },
- 'status': {
- 'applicationState': {
- 'state': 'COMPLETED'
- },
- }
- }
-
- def create_sparkapplication(self,
- json_object: dict,
- namespace: str = 'default') -> dict:
- logging.info('======================')
- logging.info(f'create spark application, namespace: {namespace}, '
- f'json: {json_object}')
- return {
- 'apiVersion': 'sparkoperator.k8s.io/v1beta2',
- 'kind': 'SparkApplication',
- 'metadata': {
- 'creationTimestamp': '2021-04-15T10:43:15Z',
- 'generation': 1,
- 'name': 'fl-transformer-yaml',
- 'namespace': 'fedlearner',
- 'resourceVersion': '348817823',
- },
- 'spec': {
- 'arguments': [
- 'hdfs://user/feature/data.csv',
- 'hdfs://user/feature/data_tfrecords/'
- ],
- }
- }
-
- def delete_sparkapplication(self,
- name: str,
- namespace: str = 'default') -> dict:
- logging.info('======================')
- logging.info(
- f'delete spark application, name: {name}, namespace: {namespace}')
- return {
- 'kind': 'Status',
- 'apiVersion': 'v1',
- 'metadata': {},
- 'status': 'Success',
- 'details': {
- 'name': name,
- 'group': 'sparkoperator.k8s.io',
- 'kind': 'sparkapplications',
- 'uid': '790603b6-9dd6-11eb-9282-b8599fb51ea8'
- }
- }
-
- def get_pod_log(self, name: str, namespace: str, tail_lines: int):
- return [str(datetime.datetime.now())]
-
- def get_pods(self, namespace, label_selector):
- return ['fake_fedlearner_web_console_v2']
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/file_manager.py b/web_console_v2/api/fedlearner_webconsole/utils/file_manager.py
index be9e51110..18420df2d 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/file_manager.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/file_manager.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,29 +17,38 @@
import logging
import os
import re
+import fsspec
from collections import namedtuple
-from typing import List
+from typing import List, Dict, Union, Optional
-from tensorflow.io import gfile
+from tensorflow.io import gfile # pylint: disable=import-error
+
+from envs import Envs
# path: absolute path of the file
# size: file size in bytes
# mtime: time of last modification, unix timestamp in seconds.
-File = namedtuple('File', ['path', 'size', 'mtime'])
-# Currently the supported format '/' or 'hdfs://'
+File = namedtuple('File', ['path', 'size', 'mtime', 'is_directory'])
+# Currently the supported format '/', 'hdfs://' or 'file://'
# TODO(chenyikan): Add oss format when verified.
-SUPPORTED_FILE_PREFIXES = r'\.+\/|^\/|^hdfs:\/\/'
+SUPPORTED_FILE_PREFIXES = r'\.+\/|^\/|^hdfs:\/\/|^file:\/\/'
+FILE_PREFIX = 'file://'
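+# A quick illustrative sketch (assumed, not part of the original change) of how the
+# prefix regex above is expected to behave with re.match:
+# >>> import re
+# >>> bool(re.match(SUPPORTED_FILE_PREFIXES, '/data/abc'))         # True
+# >>> bool(re.match(SUPPORTED_FILE_PREFIXES, 'hdfs://user/test'))  # True
+# >>> bool(re.match(SUPPORTED_FILE_PREFIXES, 'file:///data'))      # True
+# >>> bool(re.match(SUPPORTED_FILE_PREFIXES, 'oss://bucket/x'))    # False, oss is not supported yet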
class FileManagerBase(object):
"""A base interface for file manager, please implement this interface
if you have specific logic to handle files, for example, HDFS with ACL."""
+
def can_handle(self, path: str) -> bool:
"""If the manager can handle such file."""
raise NotImplementedError()
- def ls(self, path: str, recursive=False) -> List[str]:
+ def info(self) -> Dict:
+ """Give details of entry at path."""
+ raise NotImplementedError()
+
+ def ls(self, path: str, include_directory=False) -> List[File]:
"""Lists files under a path.
Raises:
ValueError: When the path does not exist.
@@ -53,7 +62,7 @@ def move(self, source: str, destination: str) -> bool:
raise NotImplementedError()
def remove(self, path: str) -> bool:
- """Removes files under a path."""
+ """Removes files under a path. Raises exception when path is not exists"""
raise NotImplementedError()
def copy(self, source: str, destination: str) -> bool:
@@ -67,42 +76,101 @@ def mkdir(self, path: str) -> bool:
raise NotImplementedError()
def read(self, path: str) -> str:
+ """Read from a file path."""
+ raise NotImplementedError()
+
+ def read_bytes(self, path: str) -> bytes:
+ """Read from a file path by Bytes"""
+ raise NotImplementedError()
+
+ def write(self, path: str, payload: str, mode: str = 'w') -> bool:
+ """Write payload to a file path. Will override original content."""
+ raise NotImplementedError()
+
+ def exists(self, path: str) -> bool:
+ """Determine whether a path exists or not"""
+ raise NotImplementedError()
+
+ def isdir(self, path: str) -> bool:
+ """Return whether the path is a directory or not"""
+ raise NotImplementedError()
+
+ def listdir(self, path: str) -> List[str]:
+ """Return all file/directory names in this path, not recursive"""
+ raise NotImplementedError()
+
+ def rename(self, source: str, dest: str):
+ """Rename or move a file / directory"""
raise NotImplementedError()
class GFileFileManager(FileManagerBase):
"""Gfile file manager for all FS supported by TF,
currently it covers all file types we have."""
+
+ # TODO(gezhengqiang): change the class name
+ def __init__(self):
+ self._fs_dict = {}
+
+ def get_customized_fs(self, path: str) -> fsspec.spec.AbstractFileSystem:
+ """
+ Ref: https://filesystem-spec.readthedocs.io/en/latest/_modules/fsspec/core.html?highlight=split_protocol#
+ # >>> from fsspec.core import split_protocol
+ # >>> split_protocol('hdfs:///user/test')
+ # >>> ('hdfs', '/user/test')
+ """
+ protocol = self._get_protocol_from_path(path) or 'file'
+ if protocol not in self._fs_dict:
+ self._fs_dict[protocol] = fsspec.get_mapper(path).fs
+ return self._fs_dict[protocol]
+
def can_handle(self, path):
if path.startswith('fake://'):
return False
return re.match(SUPPORTED_FILE_PREFIXES, path)
- def ls(self, path: str, recursive=False) -> List[File]:
- def _get_file_stats(path: str):
- stat = gfile.stat(path)
- return File(path=path,
- size=stat.length,
- mtime=int(stat.mtime_nsec / 1e9))
-
- if not gfile.exists(path):
- raise ValueError(
- f'cannot access {path}: No such file or directory')
+ @staticmethod
+ def _get_protocol_from_path(path: str) -> Optional[str]:
+ """If path is '/data', then return None. If path is 'file:///data', then return 'file'."""
+ return fsspec.core.split_protocol(path)[0]
+
+ @staticmethod
+ def _get_file_stats_from_dict(file: Dict) -> File:
+ return File(path=file['path'],
+ size=file['size'],
+ mtime=int(file['mtime'] if 'mtime' in file else file['last_modified_time']),
+ is_directory=(file['type'] == 'directory'))
+
+ def info(self, path: str) -> str:
+ fs = self.get_customized_fs(path)
+ info = fs.info(path)
+ if 'last_modified' in info:
+ info['last_modified_time'] = info['last_modified']
+ return info
+
+ def ls(self, path: str, include_directory=False) -> List[File]:
+ fs = self.get_customized_fs(path)
+ if not fs.exists(path):
+ raise ValueError(f'cannot access {path}: No such file or directory')
# If it is a file
- if not gfile.isdir(path):
- return [_get_file_stats(path)]
+ info = self.info(path)
+ if info['type'] != 'directory':
+ info['path'] = path
+ return [self._get_file_stats_from_dict(info)]
files = []
- if recursive:
- for root, _, res in gfile.walk(path):
- for file in res:
- if not gfile.isdir(os.path.join(root, file)):
- files.append(_get_file_stats(os.path.join(root, file)))
- else:
- for file in gfile.listdir(path):
- if not gfile.isdir(os.path.join(path, file)):
- files.append(_get_file_stats(os.path.join(path, file)))
- # Files only
+ for file in fs.ls(path, detail=True):
+ # file['name'] returned by fs.ls has the protocol stripped from the path,
+ # so join the relative part back onto the original path to rebuild file['path'] with the protocol
+ base_path = self.info(path)['name']  # base_path does not have the protocol
+ rel_path = os.path.relpath(file['name'], base_path)  # file['name'] does not have the protocol either
+ file['path'] = os.path.join(path, rel_path)  # file['path'] keeps both the protocol and the path
+ if file['type'] == 'directory':
+ if include_directory:
+ files.append(self._get_file_stats_from_dict(file))
+ else:
+ files.append(self._get_file_stats_from_dict(file))
+
return files
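+ # Illustrative result shape (assumed paths): ls('file:///tmp/demo') would return entries
+ # such as File(path='file:///tmp/demo/a.txt', size=3, mtime=1, is_directory=False);
+ # directories are only included when include_directory=True.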
def move(self, source: str, destination: str) -> bool:
@@ -112,16 +180,13 @@ def move(self, source: str, destination: str) -> bool:
def remove(self, path: str) -> bool:
if not gfile.isdir(path):
- return os.remove(path)
+ return gfile.remove(path)
return gfile.rmtree(path)
def copy(self, source: str, destination: str) -> bool:
if gfile.isdir(destination):
# gfile requires a file name for copy destination.
- return gfile.copy(source,
- os.path.join(destination,
- os.path.basename(source)),
- overwrite=True)
+ return gfile.copy(source, os.path.join(destination, os.path.basename(source)), overwrite=True)
return gfile.copy(source, destination, overwrite=True)
def mkdir(self, path: str) -> bool:
@@ -130,6 +195,33 @@ def mkdir(self, path: str) -> bool:
def read(self, path: str) -> str:
return gfile.GFile(path).read()
+ def read_bytes(self, path: str) -> bytes:
+ return gfile.GFile(path, 'rb').read()
+
+ def write(self, path: str, payload: str, mode: str = 'w') -> bool:
+ if gfile.isdir(path):
+ raise ValueError(f'{path} is a directory: Must provide a filename')
+ if gfile.exists(path):
+ self.remove(path)
+ if not gfile.exists(os.path.dirname(path)):
+ self.mkdir(os.path.dirname(path))
+ return gfile.GFile(path, mode).write(payload)
+
+ def exists(self, path: str) -> bool:
+ return gfile.exists(path)
+
+ def isdir(self, path: str) -> bool:
+ return gfile.isdir(path)
+
+ def listdir(self, path: str) -> List[str]:
+ """Return all file/directory names in this path, not recursive"""
+ if not gfile.isdir(path):
+ raise ValueError(f'{path} must be a directory!')
+ return gfile.listdir(path)
+
+ def rename(self, source: str, dest: str):
+ gfile.rename(source, dest)
+
class FileManager(FileManagerBase):
"""A centralized manager to handle files.
@@ -138,9 +230,10 @@ class FileManager(FileManagerBase):
`CUSTOMIZED_FILE_MANAGER`. For example,
'fedlearner_webconsole.utils.file_manager:HdfsFileManager'
"""
+
def __init__(self):
self._file_managers = []
- cfm_path = os.environ.get('CUSTOMIZED_FILE_MANAGER')
+ cfm_path = Envs.CUSTOMIZED_FILE_MANAGER
if cfm_path:
module_path, class_name = cfm_path.split(':')
module = importlib.import_module(module_path)
@@ -149,16 +242,22 @@ def __init__(self):
self._file_managers.append(customized_file_manager())
self._file_managers.append(GFileFileManager())
- def can_handle(self, path):
+ def can_handle(self, path) -> bool:
for fm in self._file_managers:
if fm.can_handle(path):
return True
return False
- def ls(self, path: str, recursive=False) -> List[File]:
+ def info(self, path: str) -> Dict:
+ for fm in self._file_managers:
+ if fm.can_handle(path):
+ return fm.info(path)
+ raise RuntimeError(f'info is not supported for {path}')
+
+ def ls(self, path: str, include_directory=False) -> List[File]:
for fm in self._file_managers:
if fm.can_handle(path):
- return fm.ls(path, recursive=recursive)
+ return fm.ls(path, include_directory=include_directory)
raise RuntimeError(f'ls is not supported for {path}')
def move(self, source: str, destination: str) -> bool:
@@ -167,8 +266,7 @@ def move(self, source: str, destination: str) -> bool:
if fm.can_handle(source) and fm.can_handle(destination):
return fm.move(source, destination)
# TODO(chenyikan): Support cross FileManager move by using buffers.
- raise RuntimeError(
- f'move is not supported for {source} and {destination}')
+ raise RuntimeError(f'move is not supported for {source} and {destination}')
def remove(self, path: str) -> bool:
logging.info('Removing file [%s]', path)
@@ -183,8 +281,7 @@ def copy(self, source: str, destination: str) -> bool:
if fm.can_handle(source) and fm.can_handle(destination):
return fm.copy(source, destination)
# TODO(chenyikan): Support cross FileManager move by using buffers.
- raise RuntimeError(
- f'copy is not supported for {source} and {destination}')
+ raise RuntimeError(f'copy is not supported for {source} and {destination}')
def mkdir(self, path: str) -> bool:
logging.info('Create directory [%s]', path)
@@ -199,3 +296,49 @@ def read(self, path: str) -> str:
if fm.can_handle(path):
return fm.read(path)
raise RuntimeError(f'read is not supported for {path}')
+
+ def read_bytes(self, path: str) -> bytes:
+ logging.info(f'Read file from [{path}]')
+ for fm in self._file_managers:
+ if fm.can_handle(path):
+ return fm.read_bytes(path)
+ raise RuntimeError(f'read_bytes is not supported for {path}')
+
+ def write(self, path: str, payload: Union[str, bytes], mode: str = 'w') -> bool:
+ logging.info(f'Write file to [{path}]')
+ for fm in self._file_managers:
+ if fm.can_handle(path):
+ return fm.write(path, payload, mode)
+ raise RuntimeError(f'write is not supported for {path}')
+
+ def exists(self, path: str) -> bool:
+ logging.info(f'Check [{path}] existence')
+ for fm in self._file_managers:
+ if fm.can_handle(path):
+ return fm.exists(path)
+ raise RuntimeError(f'check existence is not supported for {path}')
+
+ def isdir(self, path: str) -> bool:
+ logging.info(f'Determine whether [{path}] is a directory')
+ for fm in self._file_managers:
+ if fm.can_handle(path):
+ return fm.isdir(path)
+ raise RuntimeError(f'check isdir is not supported for {path}')
+
+ def listdir(self, path: str) -> List[str]:
+ logging.info(f'Get file/directory names from [{path}]')
+ for fm in self._file_managers:
+ if fm.can_handle(path):
+ return fm.listdir(path)
+ raise RuntimeError(f'listdir is not supported for {path}')
+
+ def rename(self, source: str, dest: str):
+ logging.info(f'Rename [{source}] to [{dest}]')
+ for fm in self._file_managers:
+ if fm.can_handle(source):
+ fm.rename(source, dest)
+ return
+ raise RuntimeError(f'rename is not supported for {source}')
+
+
+file_manager = FileManager()
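+# Illustrative usage sketch (assumed, not part of the original change): the module-level
+# singleton dispatches each call to the first registered manager whose can_handle()
+# accepts the path, e.g.:
+# >>> from fedlearner_webconsole.utils.file_manager import file_manager
+# >>> file_manager.write('/tmp/demo/hello.txt', 'hi')
+# >>> file_manager.read('/tmp/demo/hello.txt')          # 'hi'
+# >>> [f.path for f in file_manager.ls('/tmp/demo')]    # ['/tmp/demo/hello.txt']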
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/file_manager_test.py b/web_console_v2/api/fedlearner_webconsole/utils/file_manager_test.py
new file mode 100644
index 000000000..60d28d3ed
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/file_manager_test.py
@@ -0,0 +1,232 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import os
+import shutil
+import tempfile
+import unittest
+
+from pathlib import Path
+from unittest.mock import patch
+from tensorflow.python.framework.errors_impl import NotFoundError, InvalidArgumentError
+
+from fedlearner_webconsole.utils.file_manager import GFileFileManager, FileManager, File
+
+
+class GFileFileManagerTest(unittest.TestCase):
+ _F1_SIZE = 3
+ _F2_SIZE = 3
+ _S1_SIZE = 3
+ _SUB_SIZE = 4096
+
+ def setUp(self):
+ # Create a temporary directory
+ self._test_dir = tempfile.mkdtemp()
+ subdir = Path(self._test_dir).joinpath('subdir')
+ subdir.mkdir(exist_ok=True)
+ Path(self._test_dir).joinpath('f1.txt').write_text('xxx', encoding='utf-8')
+ Path(self._test_dir).joinpath('f2.txt').write_text('xxx', encoding='utf-8')
+ subdir.joinpath('s1.txt').write_text('xxx', encoding='utf-8')
+
+ self._fm = GFileFileManager()
+
+ def tearDown(self):
+ # Remove the directory after the test
+ shutil.rmtree(self._test_dir)
+
+ def _assert_file(self, file: File, path: str, size: int, is_directory: bool):
+ self.assertEqual(file.path, path)
+ self.assertEqual(file.size, size)
+ self.assertEqual(file.is_directory, is_directory)
+
+ def _get_temp_path(self, file_path: str = None) -> str:
+ return str(Path(self._test_dir, file_path or '').absolute())
+
+ def test_can_handle(self):
+ self.assertTrue(self._fm.can_handle('/data/abc'))
+ self.assertFalse(self._fm.can_handle('data'))
+
+ def test_info(self):
+ info = self._fm.info(self._get_temp_path('f1.txt'))
+ self.assertEqual(info['name'], self._get_temp_path('f1.txt'))
+ self.assertEqual(info['type'], 'file')
+ with patch('fsspec.implementations.local.LocalFileSystem.info') as mock_info:
+ mock_info.return_value = {'last_modified': 1}
+ info = self._fm.info(self._get_temp_path('f1.txt'))
+ self.assertEqual(info, {'last_modified': 1, 'last_modified_time': 1})
+
+ def test_ls(self):
+ # List file
+ files = self._fm.ls(self._get_temp_path('f1.txt'))
+ self.assertEqual(len(files), 1)
+ self._assert_file(files[0], self._get_temp_path('f1.txt'), self._F1_SIZE, False)
+ with patch('fsspec.implementations.local.LocalFileSystem.info') as mock_info:
+ mock_info.return_value = {
+ 'name': self._get_temp_path('f1.txt'),
+ 'size': 3,
+ 'type': 'file',
+ 'last_modified': 1
+ }
+ files = self._fm.ls(self._get_temp_path('f1.txt'))
+ self._assert_file(files[0], self._get_temp_path('f1.txt'), self._F1_SIZE, False)
+ self.assertEqual(files[0].mtime, 1)
+ # List folder
+ files = sorted(self._fm.ls(self._get_temp_path()), key=lambda file: file.path)
+ self.assertEqual(len(files), 2)
+ self._assert_file(files[0], self._get_temp_path('f1.txt'), self._F1_SIZE, False)
+ self._assert_file(files[1], self._get_temp_path('f2.txt'), self._F2_SIZE, False)
+ # List directories
+ files = sorted(self._fm.ls(self._get_temp_path(), include_directory=True), key=lambda file: file.path)
+ self.assertEqual(len(files), 3)
+ self._assert_file(files[0], self._get_temp_path('f1.txt'), self._F1_SIZE, False)
+ self._assert_file(files[1], self._get_temp_path('f2.txt'), self._F2_SIZE, False)
+ self._assert_file(files[2], self._get_temp_path('subdir'), self._SUB_SIZE, True)
+
+ def test_ls_when_path_has_protocol(self):
+ path1 = 'file://' + self._get_temp_path('f1.txt')
+ files = self._fm.ls(path1)
+ self.assertEqual(len(files), 1)
+ self.assertEqual(files[0].path, path1)
+ path2 = 'file://' + self._get_temp_path()
+ files = sorted(self._fm.ls(path2), key=lambda file: file.path)
+ self.assertEqual(len(files), 2)
+ self.assertEqual(files[0].path, 'file://' + self._get_temp_path('f1.txt'))
+ self.assertEqual(files[1].path, 'file://' + self._get_temp_path('f2.txt'))
+ files = sorted(self._fm.ls(path2, include_directory=True), key=lambda file: file.path)
+ self.assertEqual(len(files), 3)
+ self.assertEqual(files[0].path, 'file://' + self._get_temp_path('f1.txt'))
+ self.assertEqual(files[1].path, 'file://' + self._get_temp_path('f2.txt'))
+ self.assertEqual(files[2].path, 'file://' + self._get_temp_path('subdir'))
+
+ def test_move(self):
+ # Moves to another folder
+ self._fm.move(self._get_temp_path('f1.txt'), self._get_temp_path('subdir/'))
+ files = sorted(self._fm.ls(self._get_temp_path('subdir')), key=lambda file: file.path)
+ self.assertEqual(len(files), 2)
+ self._assert_file(files[0], self._get_temp_path('subdir/f1.txt'), self._F1_SIZE, False)
+ self._assert_file(files[1], self._get_temp_path('subdir/s1.txt'), self._S1_SIZE, False)
+ # Renames
+ self._fm.move(self._get_temp_path('f2.txt'), self._get_temp_path('f3.txt'))
+ with self.assertRaises(ValueError):
+ self._fm.ls(self._get_temp_path('f2.txt'))
+ files = self._fm.ls(self._get_temp_path('f3.txt'))
+ self.assertEqual(len(files), 1)
+ self._assert_file(files[0], self._get_temp_path('f3.txt'), self._F2_SIZE, False)
+
+ def test_remove(self):
+ self._fm.remove(self._get_temp_path('f1.txt'))
+ self._fm.remove(self._get_temp_path('subdir'))
+ files = self._fm.ls(self._get_temp_path(), include_directory=True)
+ self.assertEqual(len(files), 1)
+ self._assert_file(files[0], self._get_temp_path('f2.txt'), self._F2_SIZE, False)
+
+ def test_copy(self):
+ self._fm.copy(self._get_temp_path('f1.txt'), self._get_temp_path('subdir'))
+ files = self._fm.ls(self._get_temp_path('f1.txt'))
+ self.assertEqual(len(files), 1)
+ self._assert_file(files[0], self._get_temp_path('f1.txt'), self._F1_SIZE, False)
+ files = self._fm.ls(self._get_temp_path('subdir/f1.txt'))
+ self.assertEqual(len(files), 1)
+ self._assert_file(files[0], self._get_temp_path('subdir/f1.txt'), self._F1_SIZE, False)
+
+ def test_mkdir(self):
+ self._fm.mkdir(os.path.join(self._get_temp_path(), 'subdir2'))
+ self.assertTrue(os.path.isdir(self._get_temp_path('subdir2')))
+
+ def test_read(self):
+ content = self._fm.read(self._get_temp_path('f1.txt'))
+ self.assertEqual('xxx', content)
+
+ def test_write(self):
+ self.assertRaises(ValueError, lambda: self._fm.write(self._get_temp_path(), 'aaa'))
+
+ first_write_content = 'aaaa'
+ second_write_content = 'bbbb'
+ self._fm.write(self._get_temp_path('abc/write.txt'), first_write_content)
+ self.assertEqual(first_write_content, self._fm.read(self._get_temp_path('abc/write.txt')))
+ self._fm.write(self._get_temp_path('abc/write.txt'), second_write_content)
+ self.assertEqual(second_write_content, self._fm.read(self._get_temp_path('abc/write.txt')))
+
+ def test_listdir(self):
+ names = self._fm.listdir(self._get_temp_path())
+ self.assertCountEqual(names, ['f1.txt', 'f2.txt', 'subdir'])
+ with self.assertRaises(ValueError):
+ self._fm.listdir(self._get_temp_path('not_exist_path'))
+
+ def test_rename(self):
+ first_write_content = 'aaaa'
+ self._fm.write(self._get_temp_path('abc/write.txt'), first_write_content)
+ self.assertRaises(
+ NotFoundError,
+ lambda: self._fm.rename(self._get_temp_path('abc/write.txt'), self._get_temp_path('abcd/write.txt')))
+ self._fm.rename(self._get_temp_path('abc/write.txt'), self._get_temp_path('read.txt'))
+ self.assertEqual(first_write_content, self._fm.read(self._get_temp_path('read.txt')))
+ self.assertRaises(InvalidArgumentError,
+ lambda: self._fm.rename(self._get_temp_path('abc'), self._get_temp_path('abc/abc')))
+ self.assertRaises(NotFoundError,
+ lambda: self._fm.rename(self._get_temp_path('abc'), self._get_temp_path('abcd/abc')))
+ self._fm.mkdir(self._get_temp_path('abcd'))
+ self._fm.rename(self._get_temp_path('abc'), self._get_temp_path('abcd/abcd'))
+ self.assertTrue(os.path.isdir(self._get_temp_path('abcd/abcd')))
+
+
+class FileManagerTest(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ fake_fm = 'testing.fake_file_manager:FakeFileManager'
+ self._patcher = patch('fedlearner_webconsole.utils.file_manager.Envs.CUSTOMIZED_FILE_MANAGER', fake_fm)
+ self._patcher.start()
+ self._fm = FileManager()
+
+ def tearDown(self):
+ self._patcher.stop()
+
+ def test_can_handle(self):
+ self.assertTrue(self._fm.can_handle('fake://123'))
+ # Falls back to default manager
+ self.assertTrue(self._fm.can_handle('/data/123'))
+ self.assertFalse(self._fm.can_handle('unsupported:///123'))
+
+ def test_ls(self):
+ self.assertEqual(self._fm.ls('fake://data'), [{'path': 'fake://data/f1.txt', 'size': 0}])
+
+ def test_move(self):
+ self.assertTrue(self._fm.move('fake://move/123', 'fake://move/234'))
+ self.assertFalse(self._fm.move('fake://do_not_move/123', 'fake://move/234'))
+ # No file manager can handle this
+ self.assertRaises(RuntimeError, lambda: self._fm.move('hdfs://123', 'fake://abc'))
+
+ def test_remove(self):
+ self.assertTrue(self._fm.remove('fake://remove/123'))
+ self.assertFalse(self._fm.remove('fake://do_not_remove/123'))
+ # No file manager can handle this
+ self.assertRaises(RuntimeError, lambda: self._fm.remove('unsupported://123'))
+
+ def test_copy(self):
+ self.assertTrue(self._fm.copy('fake://copy/123', 'fake://copy/234'))
+ self.assertFalse(self._fm.copy('fake://do_not_copy/123', 'fake://copy/234'))
+ # No file manager can handle this
+ self.assertRaises(RuntimeError, lambda: self._fm.copy('hdfs://123', 'fake://abc'))
+
+ def test_mkdir(self):
+ self.assertTrue(self._fm.mkdir('fake://mkdir/123'))
+ self.assertFalse(self._fm.mkdir('fake://do_not_mkdir/123'))
+ # No file manager can handle this
+ self.assertRaises(RuntimeError, lambda: self._fm.mkdir('unsupported:///123'))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/file_operator.py b/web_console_v2/api/fedlearner_webconsole/utils/file_operator.py
new file mode 100644
index 000000000..5af5f5b82
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/file_operator.py
@@ -0,0 +1,230 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+import os
+import tempfile
+
+import fsspec
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.utils.stream_untars import StreamingUntar
+from fedlearner_webconsole.utils.stream_tars import StreamingTar
+from tensorflow.io import gfile # pylint: disable=import-error
+from typing import Union
+
+HDFS_PREFIX = 'hdfs://'
+TAR_SUFFIX = ('.tar',)
+GZIP_SUFFIX = ('.gz', '.tgz')
+
+
+class FileOperator(object):
+
+ def __init__(self):
+ self._fm = FileManager()
+ self._streaming_untar = StreamingUntar(self._fm)
+ self._streaming_tar = StreamingTar(self._fm)
+
+ def clear_and_make_an_empty_dir(self, dir_name: str):
+ try:
+ self._fm.remove(dir_name)
+ except Exception as err: # pylint: disable=broad-except
+ logging.debug('failed to remove %s with exception %s', dir_name, err)
+ finally:
+ self._fm.mkdir(dir_name)
+
+ def getsize(self, path: str) -> float:
+ """Return all files size under path and dont skip the sybolic link
+ Args:
+ path (str): file/directory
+
+ Returns:
+ total_size (float): total size(B)
+ """
+ fs: fsspec.AbstractFileSystem = fsspec.get_mapper(path).fs
+
+ def get_dsize(dpath: str) -> int:
+ """Gets size for directory."""
+ total = 0
+ for sub_path in fs.ls(dpath, detail=True):
+ if sub_path.get('type') == 'directory':
+ total += get_dsize(sub_path.get('name'))
+ else:
+ total += sub_path.get('size', 0)
+ return total
+
+ if not fs.exists(path):
+ return 0
+ if fs.isdir(path):
+ return get_dsize(path)
+ # File
+ return fs.size(path)
+
+ def archive_to(self, source: Union[str, list], destination: str, gzip_compress: bool = False, move: bool = False):
+ """compress the file/directory to the destination tarfile/gzip file.
+ src and dst should be path-like objects or strings.
+ eg:
+
+ Args:
+ source (str): source file/directory
+ destination (str): tarfile/gzip file
+ gzip_compress (bool): if true, compress to a gzip-compressed archive
+ move (bool): if true, delete the source after archiving
+ Raises:
+ ValueError: if the destination does not end with .tar (or .gz/.tgz when gzip_compress is true)
+ Exception: if an IO operation failed
+ """
+ logging.info(f'File Operator: will archive {source} to {destination}')
+ # check destination suffix
+ if not gzip_compress and not destination.endswith(TAR_SUFFIX):
+ logging.error(f'Error in archive_to: destination:{destination} does not end with TAR_SUFFIX')
+ raise ValueError(f'destination:{destination} does not end with TAR_SUFFIX')
+ if gzip_compress and not destination.endswith(GZIP_SUFFIX):
+ logging.error(f'Error in archive_to: destination:{destination} does not end with GZIP_SUFFIX')
+ raise ValueError(f'destination:{destination} does not end with GZIP_SUFFIX')
+ src_paths = source
+ if isinstance(source, str):
+ src_paths = [source]
+ # check the source list is on the same platform or not.
+ is_from_hdfs = src_paths[0].startswith(HDFS_PREFIX)
+ for src_path in src_paths:
+ if src_path.startswith(HDFS_PREFIX) != is_from_hdfs:
+ logging.error(f'Error in archive_to: source list:{source} is not on the same platform.')
+ raise ValueError(f'source list:{source} is not on the same platform.')
+ is_to_hdfs = destination.startswith(HDFS_PREFIX)
+ is_hdfs = is_from_hdfs or is_to_hdfs
+ if is_hdfs:
+ # src_parent_dir/src_basename/xx -> tmp_dir/src_basename/xx -> tmp_dir/dest_basename -> dest
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_archive_path = os.path.join(tmp_dir, os.path.basename(destination))
+ tmp_src_paths = []
+ for src_path in src_paths:
+ tmp_src_path = os.path.join(tmp_dir, os.path.basename(src_path))
+ # if src_path is a dir, _copytree only copies the sub-items inside src_path
+ if self._fm.isdir(src_path):
+ self._fm.mkdir(tmp_src_path)
+ self._copytree(src_path, tmp_src_path)
+ tmp_src_paths.append(tmp_src_path)
+ self._streaming_tar.archive(tmp_src_paths, tmp_archive_path, gzip_compress=gzip_compress)
+ self._fm.copy(tmp_archive_path, destination)
+ else:
+ self._streaming_tar.archive(source, destination, gzip_compress=gzip_compress)
+ if move:
+ self._fm.remove(source)
+
+ def extract_to(self, source: str, destination: str, create_dir: bool = False):
+ """extract the file to the directory dst. src and dst should be path-like objects or strings.
+
+ Args:
+ source (str): source file/directory/tarfile
+ destination (str): directory
+ create_dir (bool): if create_dir is true, will create the destination dir
+ Raises:
+ ValueError: if the tarfile does not end with .tar/.tar.gz
+ Exception: if an IO operation failed
+ """
+ self.copy_to(source, destination, extract=True, move=False, create_dir=create_dir)
+
+ def copy_to(self,
+ source: str,
+ destination: str,
+ extract: bool = False,
+ move: bool = False,
+ create_dir: bool = False):
+ """Copies the file src to the directory dst. src and dst should be path-like objects or strings,
+ the file will be copied into dst using the base filename from src.
+
+ Args:
+ source (str): source file/directory/tarfile
+ destination (str): directory
+ extract (bool): extract source file if it is tarfile
+ move (bool): if move is true, will delete source after copy
+ create_dir (bool): if create_dir is true, will create the destination dir
+
+ Raises:
+ ValueError: if the tarfile does not end with .tar/.tar.gz
+ Exception: if an IO operation failed
+ """
+ # create the destination dir
+ if create_dir and not self._fm.exists(destination):
+ self._fm.mkdir(destination)
+ if not self._fm.isdir(destination):
+ logging.error(f'Error in copy_to: destination:{destination} is not an existing directory')
+ raise ValueError(f'destination:{destination} is not an existing directory')
+ if not extract:
+ self._copytree(source, destination)
+ if move:
+ self._fm.remove(source)
+ return
+ is_hdfs = source.startswith(HDFS_PREFIX) or destination.startswith(HDFS_PREFIX)
+ if is_hdfs:
+ self._unpack_hdfs_tarfile(source, destination, is_move=move)
+ else:
+ self._unpack_tarfile(source, destination, is_move=move)
+
+ def _unpack_tarfile(self, filename: str, extract_dir: str, is_move: bool = False):
+ """Unpack tar/tar.gz/ `filename` to `extract_dir`
+ """
+ self._streaming_untar.untar(filename, extract_dir)
+ if is_move:
+ self._fm.remove(filename)
+
+ def _unpack_hdfs_tarfile(self, filename: str, extract_dir: str, is_move: bool = False):
+ """Unpack tar/tar.gz/ `filename` to `extract_dir`
+ will copy the tarfile locally to unzip it and then upload
+ """
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ try:
+ self._fm.copy(filename, tmp_dir)
+ tmp_tarfile = os.path.join(tmp_dir, os.path.basename(filename))
+ tmp_sub_dir = os.path.join(tmp_dir, 'tmp_sub_dir')
+ self._fm.mkdir(tmp_sub_dir)
+ self._streaming_untar.untar(tmp_tarfile, tmp_sub_dir)
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'failed to untar file {filename}, exception: {e}')
+ return
+ self._copytree(tmp_sub_dir, extract_dir)
+ if is_move:
+ self._fm.remove(filename)
+
+ def _copytree(self, source: str, dest: str):
+ """Recursively copy an entire directory tree rooted at src to a directory named dest
+
+ Args:
+ source (str): source file/directory/tarfile
+ dest (str): directory
+
+ Raises:
+ Exception: if io operation failed
+ """
+ # file
+ if self._fm.exists(source) and not self._fm.isdir(source):
+ self._fm.copy(source, dest)
+ # directory
+ # TODO(wangzeju): use file manager instead of gfile
+ for root, dirs, files in gfile.walk(source):
+ relative_path = os.path.relpath(root, source)
+ for f in files:
+ file_path = os.path.join(root, f)
+ dest_file = os.path.join(dest, relative_path, f)
+ try:
+ self._fm.copy(file_path, dest_file)
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'failed to copy file, from {file_path} to {dest_file}, ex: {e}')
+ for d in dirs:
+ dest_dir = os.path.join(dest, relative_path, d)
+ try:
+ self._fm.mkdir(dest_dir)
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'failed to mkdir {dest_dir}, ex: {e}')
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/file_operator_test.py b/web_console_v2/api/fedlearner_webconsole/utils/file_operator_test.py
new file mode 100644
index 000000000..12d8d52ae
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/file_operator_test.py
@@ -0,0 +1,89 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from envs import Envs
+
+from fedlearner_webconsole.utils.file_operator import FileOperator
+from fedlearner_webconsole.utils.file_manager import FILE_PREFIX
+
+
+class FileOperatorTest(unittest.TestCase):
+
+ def test_copy(self):
+ fo = FileOperator()
+ source = os.path.join(Envs.BASE_DIR, 'testing/test_data/sparkapp.tar')
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ fo.copy_to(source, tmp_dir)
+ dest = os.path.join(tmp_dir, os.path.basename(source))
+ self.assertTrue(os.path.exists(dest), 'sparkapp.tar not found')
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ fo.copy_to(source, tmp_dir, extract=True)
+ self.assertTrue(len(os.listdir(tmp_dir)) > 0, 'sparkapp/ not found')
+
+ def test_getsize(self):
+ temp_dir = tempfile.mkdtemp()
+ # 1 byte
+ Path(temp_dir).joinpath('f1.txt').write_text('1', encoding='utf-8')
+ # 2 bytes
+ Path(temp_dir).joinpath('f2.txt').write_text('22', encoding='utf-8')
+ subdir = Path(temp_dir).joinpath('subdir')
+ subdir.mkdir(exist_ok=True)
+ # 3 bytes
+ Path(subdir).joinpath('f3.txt').write_text('333', encoding='utf-8')
+ fo = FileOperator()
+ # Folder
+ self.assertEqual(fo.getsize(str(Path(temp_dir).resolve())), 6)
+ # File
+ self.assertEqual(fo.getsize(str(Path(temp_dir).joinpath('f2.txt').resolve())), 2)
+ # Invalid path
+ self.assertEqual(fo.getsize('/invalidfolder/notexist'), 0)
+
+ def test_archive_to(self):
+ fo = FileOperator()
+ source = os.path.join(Envs.BASE_DIR, 'testing/test_data/algorithm/e2e_test')
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dest = os.path.join(tmp_dir, os.path.basename(source))
+ dest = dest + '.tar'
+ fo.archive_to(source, dest)
+ self.assertTrue(os.path.exists(dest), 'dest tar file not found')
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dest = os.path.join(tmp_dir, os.path.basename(source))
+ dest = dest + '.tar'
+ fo.archive_to(FILE_PREFIX + source, FILE_PREFIX + dest)
+ self.assertTrue(os.path.exists(dest), 'dest tar file not found')
+
+ def test_extract_to(self):
+ fo = FileOperator()
+ source = os.path.join(Envs.BASE_DIR, 'testing/test_data/sparkapp.tar')
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ fo.extract_to(source, tmp_dir)
+ dest = os.path.join(tmp_dir, 'class.csv')
+ self.assertTrue(os.path.exists(dest), 'dest tar file not found')
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ fo.extract_to(FILE_PREFIX + source, FILE_PREFIX + tmp_dir)
+ dest = os.path.join(tmp_dir, 'class.csv')
+ self.assertTrue(os.path.exists(dest), 'dest tar file not found')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/file_tree.py b/web_console_v2/api/fedlearner_webconsole/utils/file_tree.py
new file mode 100644
index 000000000..46542de22
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/file_tree.py
@@ -0,0 +1,71 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from typing import List
+from fedlearner_webconsole.utils.file_manager import FileManager
+from fedlearner_webconsole.proto.algorithm_pb2 import FileTreeNode
+
+
+# TODO(hangweiqiang): make it object oriented
+class FileTreeBuilder:
+
+ def __init__(self, path: str, relpath: bool = False):
+ self.file_manager = FileManager()
+ self.path = path
+ self.relpath = relpath
+
+ def _recursive_build(self, path: str) -> List[FileTreeNode]:
+ files = self.file_manager.ls(path, include_directory=True)
+ file_nodes = []
+ for file in files:
+ filename = os.path.split(file.path)[-1]
+ filepath = file.path # filepath has protocol
+ relpath = os.path.relpath(filepath, self.path) # relative path does not have protocol
+ tree_node = FileTreeNode(filename=filename,
+ path=relpath if self.relpath else filepath,
+ mtime=file.mtime,
+ size=file.size,
+ is_directory=file.is_directory)
+ if file.is_directory:
+ dir_path = os.path.join(self.path, relpath) # dir_path has protocol
+ files = self._recursive_build(path=dir_path) # path needs protocol
+ tree_node.files.extend(files)
+ file_nodes.append(tree_node)
+ return file_nodes
+
+ def build(self) -> List[FileTreeNode]:
+ return self._recursive_build(self.path)
+
+ def build_with_root(self) -> FileTreeNode:
+ info = self.file_manager.info(self.path)
+ filename = os.path.split(self.path)[-1]
+ root = FileTreeNode(filename=filename,
+ mtime=int(info['mtime'] if 'mtime' in info else info['last_modified_time']),
+ size=info['size'],
+ is_directory=(info['type'] == 'directory'))
+ root.files.extend(self._recursive_build(path=self.path))
+ return root
+
+ def _get_size(self, tree_node: FileTreeNode):
+ file_size = tree_node.size
+ if tree_node.is_directory:
+ for file in tree_node.files:
+ file_size += self._get_size(file)
+ return file_size
+
+ def size(self):
+ tree_nodes = self.build()
+ return sum([self._get_size(node) for node in tree_nodes])
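+# Illustrative usage sketch (assumed): build a relative-path tree for a local directory
+# and compute its total size.
+# >>> builder = FileTreeBuilder('/tmp/demo_dir', relpath=True)
+# >>> nodes = builder.build()           # List[FileTreeNode], one per top-level entry
+# >>> root = builder.build_with_root()  # a single FileTreeNode wrapping the directory itself
+# >>> builder.size()                    # sum of file sizes across the whole tree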
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/file_tree_test.py b/web_console_v2/api/fedlearner_webconsole/utils/file_tree_test.py
new file mode 100644
index 000000000..726222b9d
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/file_tree_test.py
@@ -0,0 +1,163 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import tempfile
+from pathlib import Path
+from unittest.mock import patch
+from testing.common import BaseTestCase
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.utils.file_tree import FileTreeBuilder
+from fedlearner_webconsole.utils.file_manager import File
+
+
+class FileTreeTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ path = tempfile.mkdtemp()
+ path = Path(path, 'e2e_test').resolve()
+ self._base_path = str(path)
+ path.mkdir()
+ path.joinpath('follower').mkdir()
+ path.joinpath('leader').mkdir()
+ file_path = path.joinpath('leader').joinpath('main.py')
+ file_path.touch()
+ file_path.write_text('import tensorflow') # pylint: disable=unspecified-encoding
+
+ def test_build(self):
+ file_trees = FileTreeBuilder(self._base_path, relpath=True).build()
+ data = [to_dict(file_tree) for file_tree in file_trees]
+ data = sorted(data, key=lambda d: d['filename'])
+ self.assertPartiallyEqual(data[0], {
+ 'filename': 'follower',
+ 'path': 'follower',
+ 'is_directory': True,
+ 'files': []
+ },
+ ignore_fields=['mtime', 'size'])
+ self.assertPartiallyEqual(data[1], {
+ 'filename': 'leader',
+ 'path': 'leader',
+ 'is_directory': True
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+ self.assertPartiallyEqual(data[1]['files'][0], {
+ 'filename': 'main.py',
+ 'path': 'leader/main.py',
+ 'is_directory': False
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+
+ def test_build_with_root(self):
+ root = FileTreeBuilder(self._base_path, relpath=True).build_with_root()
+ data = to_dict(root)
+ self.assertPartiallyEqual(data, {
+ 'filename': 'e2e_test',
+ 'path': '',
+ 'is_directory': True
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+ files = data['files']
+ files = sorted(files, key=lambda f: f['filename'])
+ self.assertPartiallyEqual(files[0], {
+ 'filename': 'follower',
+ 'path': 'follower',
+ 'is_directory': True,
+ 'files': []
+ },
+ ignore_fields=['mtime', 'size'])
+ self.assertPartiallyEqual(files[1], {
+ 'filename': 'leader',
+ 'path': 'leader',
+ 'is_directory': True
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+ self.assertPartiallyEqual(files[1]['files'][0], {
+ 'filename': 'main.py',
+ 'path': 'leader/main.py',
+ 'is_directory': False
+ },
+ ignore_fields=['size', 'mtime', 'files'])
+
+ @patch('fedlearner_webconsole.utils.file_manager.GFileFileManager.info')
+ @patch('fedlearner_webconsole.utils.file_manager.GFileFileManager.ls')
+ def test_build_when_ls_corner_case(self, mock_ls, mock_info):
+ mock_ls.side_effect = [
+ [
+ File(path='hdfs://browser-hdfs/business/content-cloud/fedlearner/20221113/leader',
+ size=1,
+ is_directory=True,
+ mtime=1),
+ File(path='hdfs://browser-hdfs/business/content-cloud/fedlearner/20221113/leader.py',
+ size=1,
+ is_directory=False,
+ mtime=1),
+ File(path='hdfs://browser-hdfs/business/content-cloud/fedlearner/20221113/follower.py',
+ size=1,
+ is_directory=False,
+ mtime=1)
+ ],
+ [
+ File(path='hdfs://browser-hdfs/business/content-cloud/fedlearner/20221113/leader/main.py',
+ size=1,
+ is_directory=False,
+ mtime=1)
+ ]
+ ]
+ mock_info.side_effect = [{
+ 'name': '/business/content-cloud/fedlearner/20221113'
+ }, {
+ 'name': '/business/content-cloud/fedlearner/20221113'
+ }]
+ path = 'hdfs://browser-hdfs/business/content-cloud/fedlearner/20221113'
+ file_trees = FileTreeBuilder(path=path, relpath=True).build()
+ data = [to_dict(file_tree) for file_tree in file_trees]
+ data = sorted(data, key=lambda d: d['filename'])
+ self.assertPartiallyEqual(data[0], {
+ 'filename': 'follower.py',
+ 'path': 'follower.py',
+ 'is_directory': False,
+ 'files': []
+ },
+ ignore_fields=['mtime', 'size'])
+ self.assertPartiallyEqual(data[1], {
+ 'filename':
+ 'leader',
+ 'path':
+ 'leader',
+ 'is_directory':
+ True,
+ 'files': [{
+ 'filename': 'main.py',
+ 'path': 'leader/main.py',
+ 'size': 1,
+ 'mtime': 1,
+ 'is_directory': False,
+ 'files': []
+ }]
+ },
+ ignore_fields=['mtime', 'size'])
+ self.assertPartiallyEqual(data[2], {
+ 'filename': 'leader.py',
+ 'path': 'leader.py',
+ 'is_directory': False,
+ 'files': []
+ },
+ ignore_fields=['mtime', 'size'])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/filtering.py b/web_console_v2/api/fedlearner_webconsole/utils/filtering.py
new file mode 100644
index 000000000..efb2314e6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/filtering.py
@@ -0,0 +1,230 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+import logging
+from typing import Type, NamedTuple, Dict, Optional, Callable, Any
+
+from pyparsing import Keyword, replace_with, common, dbl_quoted_string, remove_quotes, Suppress, Group, Word, \
+ alphas, Literal, Forward, delimited_list, alphanums, Opt, ParseResults, ParseException
+from sqlalchemy import Column, and_
+from sqlalchemy.orm import Query
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, SimpleExpression, FilterOp, FilterExpressionKind
+
+# Using pyparsing to construct a DSL for filtering syntax.
+# Why not parse the expression directly with a regex? That approach has many corner cases around
+# string literals: for example, we use brackets to split sub-expressions, but a value may itself
+# contain ')(', so a regex would have to try different splits by brute force, which is inefficient
+# and buggy. A DSL is a more elegant way. Ref: https://pypi.org/project/pyparsing/
+# --------------------Grammar---------------------------
+# Names
+_SIMPLE_EXP_RNAME = 'simple_exp'
+_EXP_COMBINER_RNAME = 'exp_combiner'
+_SUB_EXPS_RNAME = 'sub_exps'
+_EXP_RNAME = 'exp'
+# Those values follow json standards,
+# ref: https://github.com/pyparsing/pyparsing/blob/master/examples/jsonParser.py
+_LEFT_SQUARE_BRACKET = Suppress('[')
+_RIGHT_SQUARE_BRACKET = Suppress(']')
+_TRUE = Literal('true').set_parse_action(replace_with(True))
+_FALSE = Literal('false').set_parse_action(replace_with(False))
+_BOOL_VALUE = _TRUE | _FALSE
+_STRING_VALUE = dbl_quoted_string().set_parse_action(remove_quotes)
+_NUMBER_VALUE = common.number()
+_NUMBER_LIST = Group(_LEFT_SQUARE_BRACKET + Opt(delimited_list(_NUMBER_VALUE, delim=',')) + _RIGHT_SQUARE_BRACKET,
+ aslist=True)
+_STRING_LIST = Group(_LEFT_SQUARE_BRACKET + Opt(delimited_list(_STRING_VALUE, delim=',')) + _RIGHT_SQUARE_BRACKET,
+ aslist=True)
+PRIMITIVE_VALUE = _BOOL_VALUE | _STRING_VALUE | _NUMBER_VALUE
+LIST_VALUE = _NUMBER_LIST | _STRING_LIST
+VALUE = PRIMITIVE_VALUE | LIST_VALUE
+
+_LEFT_BRACKET = Suppress('(')
+_RIGHT_BRACKET = Suppress(')')
+FIELD = Word(init_chars=alphas, body_chars=alphanums + '_' + '.', min=1)
+# The IN op only accepts list values (number or string lists)
+_IN_EXP_MEMBER = FIELD + Literal(':') + LIST_VALUE
+_EQUAL_EXP_MEMBER = FIELD + Literal('=') + PRIMITIVE_VALUE
+_GREATER_THAN_EXP_MEMBER = FIELD + Literal('>') + _NUMBER_VALUE
+_LESS_THAN_EXP_MEMBER = FIELD + Literal('<') + _NUMBER_VALUE
+_CONTAIN_EXP_MEMBER = FIELD + Literal('~=') + _STRING_VALUE
+_EXP_MEMBER = _IN_EXP_MEMBER | _EQUAL_EXP_MEMBER | _CONTAIN_EXP_MEMBER | \
+ _GREATER_THAN_EXP_MEMBER | _LESS_THAN_EXP_MEMBER
+SIMPLE_EXP = Group(_LEFT_BRACKET + _EXP_MEMBER + _RIGHT_BRACKET).set_results_name(_SIMPLE_EXP_RNAME)
+
+EXP_COMBINER = Keyword('and').set_results_name(_EXP_COMBINER_RNAME)
+EXP = Forward()
+EXP <<= Group(SIMPLE_EXP | (_LEFT_BRACKET + EXP_COMBINER + Group(EXP[2, ...]).set_results_name(_SUB_EXPS_RNAME) +
+ _RIGHT_BRACKET)).set_results_name(_EXP_RNAME)
+# --------------------End of grammar--------------------
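+# Illustrative expressions the grammar above is meant to accept (assumed examples):
+#   (name="test")                    -> a single equality expression
+#   (id:[1, 2, 3])                   -> an IN expression with a list value
+#   (and(name~="demo")(amount>1.5))  -> two sub-expressions combined with 'and'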
+
+
+def _build_simple_expression(parse_results: ParseResults) -> SimpleExpression:
+ """Builds simple expression by parsed result, ref to `SIMPLE_EXP`."""
+ field, op_str, typed_value = parse_results.as_list()
+
+ op = FilterOp.EQUAL
+ if op_str == ':':
+ op = FilterOp.IN
+ elif op_str == '~=':
+ op = FilterOp.CONTAIN
+ elif op_str == '>':
+ op = FilterOp.GREATER_THAN
+ elif op_str == '<':
+ op = FilterOp.LESS_THAN
+ exp = SimpleExpression(
+ field=field,
+ op=op,
+ )
+ if isinstance(typed_value, bool):
+ exp.bool_value = typed_value
+ elif isinstance(typed_value, str):
+ exp.string_value = typed_value
+ elif isinstance(typed_value, (int, float)):
+ exp.number_value = typed_value
+ elif isinstance(typed_value, list):
+ if len(typed_value) > 0 and isinstance(typed_value[0], str):
+ exp.list_value.string_list.extend(typed_value)
+ else:
+ exp.list_value.number_list.extend(typed_value)
+ else:
+ logging.warning('[FilterExpression] Unsupported value: %s', typed_value)
+ raise ValueError(f'Unsupported value: {typed_value}')
+ return exp
+
+
+def _build_expression(exp: ParseResults) -> FilterExpression:
+ """Builds expression by parsed results, ref to `EXP`."""
+ if _SIMPLE_EXP_RNAME in exp:
+ return FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=_build_simple_expression(exp[_SIMPLE_EXP_RNAME]))
+ combiner = exp.get(_EXP_COMBINER_RNAME)
+ assert combiner == 'and', 'Combiner must be and as of now'
+ exp_pb = FilterExpression(kind=FilterExpressionKind.AND)
+ for sub_exp in exp.get(_SUB_EXPS_RNAME, []):
+ exp_pb.exps.append(_build_expression(sub_exp))
+ return exp_pb
+
+
+def parse_expression(exp: str) -> FilterExpression:
+ try:
+ parse_results = EXP.parse_string(exp, parse_all=True)
+ return _build_expression(parse_results[0])
+ except ParseException as e:
+ error_message = f'[FilterExpression] unsupported expression {exp}'
+ logging.exception(error_message)
+ raise ValueError(error_message) from e
+
+
+class FieldType(enum.Enum):
+ STRING = 'STRING'
+ NUMBER = 'NUMBER'
+ BOOL = 'BOOL'
+
+
+class SupportedField(NamedTuple):
+ # Field type
+ type: FieldType
+ # Supported ops, key is the op, value is the custom criterion builder.
+ ops: Dict['FilterOp', Optional[Callable[[SimpleExpression], Any]]]
+
+
+def validate_expression(exp: FilterExpression, supported_fields: Dict[str, SupportedField]):
+ """Validates if the expression is supported.
+
+ Raises:
+ ValueError: if the expression is unsupported.
+ """
+ if exp.kind == FilterExpressionKind.SIMPLE:
+ simple_exp = exp.simple_exp
+ if simple_exp.field not in supported_fields:
+ raise ValueError(f'Unsupported field {simple_exp.field}')
+ supported_field = supported_fields[simple_exp.field]
+ if simple_exp.op not in supported_field.ops:
+ raise ValueError(f'Unsupported op {FilterOp.Name(simple_exp.op)} on {simple_exp.field}')
+ pb_value_field = simple_exp.WhichOneof('value')
+ value_type = FieldType.STRING
+ if pb_value_field == 'bool_value':
+ value_type = FieldType.BOOL
+ elif pb_value_field == 'number_value':
+ value_type = FieldType.NUMBER
+ elif pb_value_field == 'list_value':
+ if len(simple_exp.list_value.number_list) > 0:
+ value_type = FieldType.NUMBER
+ else:
+ value_type = FieldType.STRING
+ if value_type != supported_field.type:
+ raise ValueError(
+ f'Type of {simple_exp.field} is invalid, expected {supported_field.type}, actual {value_type}')
+ return
+ for sub_exp in exp.exps:
+ validate_expression(sub_exp, supported_fields)
+
+
+class FilterBuilder(object):
+
+ def __init__(self, model_class: Type[db.Model], supported_fields: Dict[str, SupportedField]):
+ self.model_class = model_class
+ self.supported_fields = supported_fields
+
+ def _build_criterions(self, exp: FilterExpression):
+ """Builds sqlalchemy criterions for the filter expression."""
+ if exp.kind == FilterExpressionKind.SIMPLE:
+ simple_exp = exp.simple_exp
+ supported_field = self.supported_fields.get(simple_exp.field)
+ custom_builder = None
+ if supported_field:
+ custom_builder = supported_field.ops.get(simple_exp.op)
+ # Calls custom builder if it is specified
+ if callable(custom_builder):
+ return custom_builder(simple_exp)
+
+ column: Optional[Column] = getattr(self.model_class, simple_exp.field, None)
+ assert column is not None, f'{simple_exp.field} is not a column key'
+ if simple_exp.op == FilterOp.EQUAL:
+ pb_value_field = simple_exp.WhichOneof('value')
+ return column == getattr(simple_exp, pb_value_field)
+ if simple_exp.op == FilterOp.IN:
+ number_list = simple_exp.list_value.number_list
+ string_list = simple_exp.list_value.string_list
+ list_value = number_list
+ if len(string_list) > 0:
+ list_value = string_list
+ return column.in_(list_value)
+ if simple_exp.op == FilterOp.CONTAIN:
+ return column.ilike(f'%{simple_exp.string_value}%')
+ if simple_exp.op == FilterOp.GREATER_THAN:
+ return column > simple_exp.number_value
+ if simple_exp.op == FilterOp.LESS_THAN:
+ return column < simple_exp.number_value
+ raise ValueError(f'Unsupported filter op: {simple_exp.op}')
+ # AND-combined sub expressions
+ assert exp.kind == FilterExpressionKind.AND
+ criterions = [self._build_criterions(sub_exp) for sub_exp in exp.exps]
+ return and_(*criterions)
+
+ def build_query(self, query: Query, exp: FilterExpression) -> Query:
+ """Build query by expression.
+
+ Raises:
+ ValueError: if the expression is unsupported.
+ """
+ # A special case that the expression is empty
+ if exp.ByteSize() == 0:
+ return query
+ validate_expression(exp, self.supported_fields)
+ return query.filter(self._build_criterions(exp))
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/filtering_test.py b/web_console_v2/api/fedlearner_webconsole/utils/filtering_test.py
new file mode 100644
index 000000000..188c73d87
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/filtering_test.py
@@ -0,0 +1,453 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import Any
+
+from pyparsing import ParseException, ParserElement
+from sqlalchemy import and_
+
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, SimpleExpression, FilterOp, FilterExpressionKind
+from fedlearner_webconsole.utils.filtering import VALUE, SIMPLE_EXP, EXP, parse_expression, FilterBuilder, \
+ SupportedField, FieldType, validate_expression
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class TestModel(db.Model):
+ __tablename__ = 'test_table'
+ __table_args__ = (default_table_args('Test table'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True)
+ name = db.Column(db.String(255))
+ disabled = db.Column(db.Boolean, default=False)
+ amount = db.Column(db.Float, default=0)
+
+
+class DslTest(unittest.TestCase):
+
+ def _parse_single(self, e: ParserElement, s: str) -> Any:
+ results = e.parse_string(s, parse_all=True).as_list()
+ self.assertEqual(len(results), 1)
+ return results[0]
+
+ def test_bool_value(self):
+ self.assertEqual(self._parse_single(VALUE, 'true'), True)
+ self.assertEqual(self._parse_single(VALUE, ' false '), False)
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, 'True')
+
+ def test_string_value(self):
+ self.assertEqual(self._parse_single(VALUE, '"u你好🤩 ok"'), 'u你好🤩 ok')
+ self.assertEqual(self._parse_single(VALUE, '"hey"'), 'hey')
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, '\'single quote\'')
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, '"no quote pair')
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, 'no quote')
+
+ def test_number_value(self):
+ self.assertEqual(self._parse_single(VALUE, '01234'), 1234)
+ self.assertEqual(self._parse_single(VALUE, '-56.877'), -56.877)
+ self.assertEqual(self._parse_single(VALUE, '1e4'), 10000)
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, '2^20')
+
+ def test_number_list(self):
+ self.assertEqual(self._parse_single(VALUE, '[]'), [])
+ self.assertEqual(self._parse_single(VALUE, '[-2e2]'), [-200])
+ self.assertEqual(self._parse_single(VALUE, '[-1, +2.06, 3]'), [-1, 2.06, 3])
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, '-2 ]')
+
+ def test_string_list(self):
+ self.assertEqual(self._parse_single(VALUE, '["hello world"]'), ['hello world'])
+ self.assertEqual(self._parse_single(VALUE, '["🐷", "行\\\"卫\'qiang", "🤩"]'), ['🐷', '行\\"卫\'qiang', '🤩'])
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, '[\'hello\']')
+ with self.assertRaises(ParseException):
+ self._parse_single(VALUE, '["hello]')
+
+ def test_simple_exp_with_equal(self):
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(abc=123)'), ['abc', '=', 123])
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(a_b_c=false)'), ['a_b_c', '=', False])
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(x123="test值")'), ['x123', '=', 'test值'])
+ with self.assertRaises(ParseException):
+ # Without brackets
+ self._parse_single(SIMPLE_EXP, 'abc=123')
+ with self.assertRaises(ParseException):
+ # Invalid value
+ self._parse_single(SIMPLE_EXP, 'abc=abc')
+ with self.assertRaises(ParseException):
+ # List value is not supported for equal
+ self._parse_single(SIMPLE_EXP, '(f=[1,-2])')
+
+ def test_simple_exp_with_in(self):
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(abc:[1,-2])'), ['abc', ':', [1, -2]])
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(x12_3:[2.3333])'), ['x12_3', ':', [2.3333]])
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(s1:["h","w"])'), ['s1', ':', ['h', 'w']])
+ with self.assertRaises(ParseException):
+ # Without brackets
+ self._parse_single(SIMPLE_EXP, 'abc:[-1]')
+ with self.assertRaises(ParseException):
+ # Primitive value is not supported
+ self._parse_single(SIMPLE_EXP, '(f:"hello")')
+
+ def test_simple_exp_with_greater_than(self):
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(start_at>123)'), ['start_at', '>', 123])
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(amount>-1.2)'), ['amount', '>', -1.2])
+ with self.assertRaises(ParseException):
+ # String value is not supported
+ self._parse_single(SIMPLE_EXP, '(s>"hello")')
+
+ def test_simple_exp_with_less_than(self):
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(end_at<187777)'), ['end_at', '<', 187777])
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(amount < -1.23)'), ['amount', '<', -1.23])
+ with self.assertRaises(ParseException):
+ # String value is not supported
+ self._parse_single(SIMPLE_EXP, '(amount<"hello")')
+
+ def test_simple_exp_with_contain(self):
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(abc~="keyword")'), ['abc', '~=', 'keyword'])
+ self.assertEqual(self._parse_single(SIMPLE_EXP, '(a_b_c~="~=你好")'), ['a_b_c', '~=', '~=你好'])
+ with self.assertRaises(ParseException):
+ # Without brackets
+ self._parse_single(SIMPLE_EXP, 'abc~="keyword"')
+ with self.assertRaises(ParseException):
+ # Invalid value
+ self._parse_single(SIMPLE_EXP, 'abc~=abc')
+ with self.assertRaises(ParseException):
+ # List value is not supported
+ self._parse_single(SIMPLE_EXP, '(f~=["fff"])')
+
+ def test_exp_simple(self):
+ self.assertEqual(self._parse_single(EXP, '(a.b:[1,2,3])'), [['a.b', ':', [1, 2, 3]]])
+ self.assertEqual(self._parse_single(EXP, '(x123="h h")'), [['x123', '=', 'h h']])
+ self.assertEqual(self._parse_single(EXP, '(s1~="ooo")'), [['s1', '~=', 'ooo']])
+
+ def test_exp_and_combined(self):
+ result = self._parse_single(EXP, '(and(a:[1])(b=true)(c=")(")(d~="like"))')
+ self.assertEqual(result,
+ ['and', [[['a', ':', [1]]], [['b', '=', True]], [['c', '=', ')(']], [['d', '~=', 'like']]]])
+ with self.assertRaises(ParseException):
+ # No brackets
+ self._parse_single(EXP, 'and(f=false)(x=true)')
+ with self.assertRaises(ParseException):
+ # Only one sub-exp
+ self._parse_single(EXP, '(and(f=false))')
+ with self.assertRaises(ParseException):
+ # Invalid value
+ self._parse_single(EXP, '(and(f=false)(x=)())')
+
+ def test_exp_nested(self):
+ result = self._parse_single(EXP, '(and(a:[1,2])(and(x1=true)(x2=false)(x3~="ss"))(and(y1="and()")(y2="中文")))')
+ self.assertEqual(result, [
+ 'and',
+ [[['a', ':', [1, 2]]], ['and', [[['x1', '=', True]], [['x2', '=', False]], [['x3', '~=', 'ss']]]],
+ ['and', [[['y1', '=', 'and()']], [['y2', '=', '中文']]]]]
+ ])
+
+
+class ParseExpressionTest(unittest.TestCase):
+
+ def test_simple_expression(self):
+ self.assertEqual(
+ parse_expression('(test_field="hey 🐷")'),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='test_field', op=FilterOp.EQUAL, string_value='hey 🐷')))
+ self.assertEqual(
+ parse_expression('(test_field:[ -2, 3 ])'),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='test_field',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(number_list=[-2, 3]))))
+ self.assertEqual(
+ parse_expression('(test_field:["你 好", "🐷"])'),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='test_field',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(string_list=['你 好', '🐷']))))
+ self.assertEqual(
+ parse_expression('(test_field~="like")'),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='test_field', op=FilterOp.CONTAIN, string_value='like')))
+ self.assertEqual(
+ parse_expression('(start_at > 123)'),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='start_at', op=FilterOp.GREATER_THAN, number_value=123)))
+ self.assertEqual(
+ parse_expression('(test_field<-12.3)'),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='test_field', op=FilterOp.LESS_THAN,
+ number_value=-12.3)))
+
+ def test_and_expression(self):
+ self.assertEqual(
+ parse_expression('(and(x1="床前明月光")(x2:["o","y"])(x3=true))'),
+ FilterExpression(kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='x1',
+ op=FilterOp.EQUAL,
+ string_value='床前明月光')),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='x2',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(string_list=['o', 'y']))),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='x3',
+ op=FilterOp.EQUAL,
+ bool_value=True))
+ ]))
+
+ def test_nested_expression(self):
+ self.assertEqual(
+ parse_expression('(and(and(x1="(and(x1=true)(x2=false))")(x2:[1]))(and(x3=false)(x4=1.1e3))(x5~="x"))'),
+ FilterExpression(kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='x1',
+ op=FilterOp.EQUAL,
+ string_value='(and(x1=true)(x2=false))')),
+ FilterExpression(
+ kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='x2',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(number_list=[1]))),
+ ]),
+ FilterExpression(kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='x3',
+ op=FilterOp.EQUAL,
+ bool_value=False)),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='x4',
+ op=FilterOp.EQUAL,
+ number_value=1100)),
+ ]),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='x5',
+ op=FilterOp.CONTAIN,
+ string_value='x')),
+ ]))
+
+ def test_invalid_expression(self):
+ with self.assertRaises(ValueError):
+ # No brackets
+ parse_expression('x1=true')
+ with self.assertRaises(ValueError):
+ # No brackets
+ parse_expression('and(x1=true)(x2=false)')
+ with self.assertRaises(ValueError):
+ # Only one sub expression
+ parse_expression('(and(x1=true))')
+
+
+class ValidateExpressionTest(unittest.TestCase):
+ SUPPORTED_FIELDS = {
+ 'f1': SupportedField(type=FieldType.NUMBER, ops={FilterOp.IN: None}),
+ 'f2': SupportedField(type=FieldType.STRING, ops={
+ FilterOp.EQUAL: lambda exp: True,
+ FilterOp.IN: None
+ }),
+ 'f3': SupportedField(type=FieldType.BOOL, ops={FilterOp.EQUAL: None}),
+ 'f4': SupportedField(type=FieldType.STRING, ops={FilterOp.CONTAIN: None}),
+ }
+
+ def test_valid(self):
+ exp = FilterExpression(kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='f1',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(number_list=[1, 2]))),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='f2',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(string_list=['hello']))),
+ FilterExpression(kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='f2',
+ op=FilterOp.EQUAL,
+ string_value='hello')),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='f3',
+ op=FilterOp.EQUAL,
+ bool_value=True)),
+ ]),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='f4',
+ op=FilterOp.CONTAIN,
+ string_value='lifjfasdf asdf')),
+ ])
+ validate_expression(exp, self.SUPPORTED_FIELDS)
+
+ def test_unsupported_field(self):
+ exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='f123123',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(number_list=[1, 2])))
+ with self.assertRaisesRegex(ValueError, 'Unsupported field f123123'):
+ validate_expression(exp, self.SUPPORTED_FIELDS)
+
+ def test_unsupported_op(self):
+ exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='f3',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(number_list=[1, 2])))
+ with self.assertRaisesRegex(ValueError, 'Unsupported op IN on f3'):
+ validate_expression(exp, self.SUPPORTED_FIELDS)
+
+ def test_invalid_type(self):
+ exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='f1', op=FilterOp.IN, string_value='invalid'))
+ with self.assertRaisesRegex(ValueError,
+ 'Type of f1 is invalid, expected FieldType.NUMBER, actual FieldType.STRING'):
+ validate_expression(exp, self.SUPPORTED_FIELDS)
+ exp = FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='f2',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(number_list=[1])))
+ with self.assertRaisesRegex(ValueError,
+ 'Type of f2 is invalid, expected FieldType.STRING, actual FieldType.NUMBER'):
+ validate_expression(exp, self.SUPPORTED_FIELDS)
+
+
+class FilterBuilderTest(NoWebServerTestCase):
+
+ def test_build_query(self):
+ exp = FilterExpression(kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(
+ field='id',
+ op=FilterOp.IN,
+ list_value=SimpleExpression.ListValue(number_list=[1, 2]))),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='name',
+ op=FilterOp.EQUAL,
+ string_value='test name')),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='name',
+ op=FilterOp.CONTAIN,
+ string_value='test')),
+ FilterExpression(
+ kind=FilterExpressionKind.AND,
+ exps=[
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='disabled',
+ op=FilterOp.EQUAL,
+ bool_value=True)),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='amount',
+ op=FilterOp.EQUAL,
+ number_value=666.6)),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='id',
+ op=FilterOp.GREATER_THAN,
+ number_value=1)),
+ FilterExpression(kind=FilterExpressionKind.SIMPLE,
+ simple_exp=SimpleExpression(field='id',
+ op=FilterOp.LESS_THAN,
+ number_value=1999)),
+ ]),
+ ])
+ with db.session_scope() as session:
+ models = [
+ TestModel(
+ id=1,
+ name='test name',
+ disabled=False,
+ amount=666.6,
+ ),
+ TestModel(
+ id=2,
+ name='test name',
+ disabled=True,
+ amount=666.6,
+ ),
+ TestModel(
+ id=3,
+ name='test name',
+ disabled=True,
+ amount=666.6,
+ )
+ ]
+ session.add_all(models)
+ session.commit()
+
+ def amount_filter(exp: FilterExpression):
+ return and_(TestModel.amount > 600, TestModel.amount < 700)
+
+ builder = FilterBuilder(TestModel,
+ supported_fields={
+ 'id':
+ SupportedField(type=FieldType.NUMBER,
+ ops={
+ FilterOp.EQUAL: None,
+ FilterOp.IN: None,
+ FilterOp.GREATER_THAN: None,
+ FilterOp.LESS_THAN: None,
+ }),
+ 'name':
+ SupportedField(type=FieldType.STRING,
+ ops={
+ FilterOp.EQUAL: None,
+ FilterOp.CONTAIN: None
+ }),
+ 'disabled':
+ SupportedField(type=FieldType.BOOL, ops={FilterOp.EQUAL: None}),
+ 'amount':
+ SupportedField(type=FieldType.NUMBER, ops={FilterOp.EQUAL: amount_filter}),
+ })
+ with db.session_scope() as session:
+ query = session.query(TestModel)
+ query = builder.build_query(query, exp)
+ self.assertEqual(
+ self.generate_mysql_statement(query),
+ 'SELECT test_table.id, test_table.name, test_table.disabled, test_table.amount \n'
+ 'FROM test_table \n'
+ # lower() is generated because ilike performs a case-insensitive match
+ # pylint: disable-next=line-too-long
+ 'WHERE test_table.id IN (1.0, 2.0) AND test_table.name = \'test name\' AND lower(test_table.name) LIKE lower(\'%%test%%\') AND test_table.disabled = true AND test_table.amount > 600 AND test_table.amount < 700 AND test_table.id > 1.0 AND test_table.id < 1999.0'
+ )
+ model_ids = [m.id for m in query.all()]
+ self.assertCountEqual(model_ids, [2])
+
+ def test_build_query_for_empty_exp(self):
+ exp = FilterExpression()
+ builder = FilterBuilder(TestModel, supported_fields={})
+ with db.session_scope() as session:
+ query = session.query(TestModel)
+ query = builder.build_query(query, exp)
+ self.assertEqual(
+ self.generate_mysql_statement(query),
+ 'SELECT test_table.id, test_table.name, test_table.disabled, test_table.amount \n'
+ 'FROM test_table')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/flask_utils.py b/web_console_v2/api/fedlearner_webconsole/utils/flask_utils.py
new file mode 100644
index 000000000..9349e8a2e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/flask_utils.py
@@ -0,0 +1,105 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import io
+import json
+import typing
+import urllib.parse
+from http import HTTPStatus
+from typing import Optional, Tuple, Union
+from flask import send_file, g, has_request_context, request
+from google.protobuf.message import Message
+from marshmallow import ValidationError
+from webargs import fields
+
+from envs import Envs
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression
+from fedlearner_webconsole.utils.const import SSO_HEADER
+from fedlearner_webconsole.utils.filtering import parse_expression
+from fedlearner_webconsole.utils.proto import to_dict
+
+
+def download_json(content: dict, filename: str):
+ in_memory_file = io.BytesIO()
+ # `ensure_ascii=False` keeps non-ASCII characters readable in the exported JSON
+ in_memory_file.write(json.dumps(content, ensure_ascii=False).encode('utf-8'))
+ in_memory_file.seek(0)
+ return send_file(in_memory_file,
+ as_attachment=True,
+ attachment_filename=f'{filename}.json',
+ mimetype='application/json; charset=UTF-8',
+ cache_timeout=0)
+
+
+def get_current_sso() -> Optional[str]:
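+ # The SSO header value is space-separated; its first token names the SSO provider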
+ sso_headers = request.headers.get(SSO_HEADER, None)
+ if sso_headers:
+ return sso_headers.split()[0]
+ return None
+
+
+def get_current_user() -> Optional[User]:
+ if has_request_context() and hasattr(g, 'current_user'):
+ return g.current_user
+ return None
+
+
+def set_current_user(current_user: User):
+ g.current_user = current_user
+
+
+def _normalize_data(data: Union[Message, dict, list]) -> Union[dict, list]:
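+ """Recursively converts protobuf Messages (and lists/dicts containing them) into plain Python structures."""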
+ if isinstance(data, Message):
+ return to_dict(data)
+ if isinstance(data, list):
+ return [_normalize_data(d) for d in data]
+ if isinstance(data, dict):
+ return {k: _normalize_data(v) for k, v in data.items()}
+ return data
+
+
+def make_flask_response(data: Optional[Union[Message, dict, list]] = None,
+ page_meta: Optional[dict] = None,
+ status: int = HTTPStatus.OK) -> Tuple[dict, int]:
+ if data is None:
+ data = {}
+ data = _normalize_data(data)
+
+ if page_meta is None:
+ page_meta = {}
+ return {
+ 'data': data,
+ 'page_meta': page_meta,
+ }, status
+
+
+def get_link(path: str) -> str:
+ host_url = None
+ if has_request_context():
+ host_url = request.host_url
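+ # Outside of a request context, fall back to the server host configured in Envs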
+ if not host_url:
+ host_url = Envs.SERVER_HOST
+ return urllib.parse.urljoin(host_url, path)
+
+
+class FilterExpField(fields.Field):
+ """A marshmallow field represents the filtering expression. See details in filtering.py."""
+
+ def _deserialize(self, value: str, attr: str, data: typing.Any, **kwargs) -> FilterExpression:
+ try:
+ return parse_expression(value)
+ except ValueError as e:
+ raise ValidationError(f'Failed to parse filter {value}') from e
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/flask_utils_test.py b/web_console_v2/api/fedlearner_webconsole/utils/flask_utils_test.py
new file mode 100644
index 000000000..7c5793cec
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/flask_utils_test.py
@@ -0,0 +1,157 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+import json
+import unittest
+
+from http import HTTPStatus
+from unittest.mock import patch
+from google.protobuf import struct_pb2
+
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression
+from fedlearner_webconsole.proto.testing.testing_pb2 import Tdata
+from fedlearner_webconsole.utils.decorators.pp_flask import use_kwargs
+from fedlearner_webconsole.utils.flask_utils import download_json, get_current_user, set_current_user, \
+ make_flask_response, _normalize_data, get_link, FilterExpField, get_current_sso
+from testing.common import BaseTestCase
+
+
+class FlaskUtilsTest(BaseTestCase):
+
+ def test_download_json(self):
+ test_content = {'haha': {'hh': [0]}, 'abc': 123, 'unicode': ['boss行味镪', '🤑']}
+
+ @self.app.route('/test', methods=['POST'])
+ def test_route():
+ return download_json(test_content, 'test_file')
+
+ response = self.client.post('/test')
+ self.assertEqual(
+ response.data, b'{"haha": {"hh": [0]}, "abc": 123,'
+ b' "unicode": ["boss\xe8\xa1\x8c\xe5\x91\xb3\xe9\x95\xaa", "\xf0\x9f\xa4\x91"]}')
+ self.assertEqual(response.data.decode('utf-8'),
+ '{"haha": {"hh": [0]}, "abc": 123, "unicode": ["boss行味镪", "🤑"]}')
+ self.assertEqual(json.loads(response.data.decode('utf-8')), test_content)
+ self.assertEqual(response.headers['Content-Disposition'], 'attachment; filename=test_file.json')
+ self.assertEqual(response.headers['Content-Type'], 'application/json; charset=UTF-8')
+
+ def test_get_current_user(self):
+ test_user = User(id=1, username='test')
+
+ @self.app.route('/test', methods=['POST'])
+ def test_route():
+ set_current_user(test_user)
+ return {}, HTTPStatus.OK
+
+ self.client.post('/test')
+ self.assertEqual(test_user, get_current_user())
+
+ def test_normalize_data(self):
+ # Dict
+ d = {'a': 123}
+ self.assertEqual(_normalize_data(d), d)
+ # Proto
+ self.assertEqual(_normalize_data(Tdata(id=134)), {
+ 'id': 134,
+ 'mappers': {},
+ 'projects': [],
+ 'tt': 'UNSPECIFIED',
+ })
+ # Array of proto
+ self.assertEqual(_normalize_data([Tdata(id=1), Tdata(id=2)]), [{
+ 'id': 1,
+ 'mappers': {},
+ 'projects': [],
+ 'tt': 'UNSPECIFIED',
+ }, {
+ 'id': 2,
+ 'mappers': {},
+ 'projects': [],
+ 'tt': 'UNSPECIFIED',
+ }])
+ # Array
+ l = [{'a': 44}, {'b': '123'}]
+ self.assertEqual(_normalize_data(l), l)
+ # Dict with nested Protobuf Message and map structure.
+ self.assertEqual(_normalize_data({'a': Tdata(id=1, mappers={0: struct_pb2.Value(string_value='test')})}),
+ {'a': {
+ 'id': 1,
+ 'mappers': {
+ '0': 'test',
+ },
+ 'projects': [],
+ 'tt': 'UNSPECIFIED',
+ }})
+
+ def test_make_flask_response(self):
+ resp, status = make_flask_response()
+ self.assertDictEqual(resp, {'data': {}, 'page_meta': {}})
+ self.assertEqual(HTTPStatus.OK, status)
+
+ data = [{'name': 'kiyoshi'} for _ in range(5)]
+ page_meta = {'page': 1, 'page_size': 0, 'total_items': 5, 'total_pages': 1}
+ resp, status = make_flask_response(data, page_meta)
+ self.assertDictEqual(data[0], resp.get('data')[0])
+ self.assertDictEqual(page_meta, resp.get('page_meta'))
+
+ def test_get_link_in_flask(self):
+
+ @self.app.route('/test')
+ def test_route():
+ return get_link('/v2/workflow-center/workflows/123')
+
+ resp = self.get_helper('/test', use_auth=False)
+ self.assertEqual(resp.data.decode('utf-8'), 'http://localhost/v2/workflow-center/workflows/123')
+
+ @patch('fedlearner_webconsole.utils.flask_utils.request.headers.get')
+ def test_get_current_sso(self, mock_headers):
+ mock_headers.return_value = 'test oauth access_token'
+ sso_name = get_current_sso()
+ self.assertEqual(sso_name, 'test')
+
+
+class NonFlaskTest(unittest.TestCase):
+
+ def test_get_link_not_in_flask(self):
+ self.assertEqual(get_link('/v2/test'), 'http://localhost:666/v2/test')
+
+
+class FilterExpFieldTest(BaseTestCase):
+
+ def test_custom_field(self):
+
+ @self.app.route('/test')
+ @use_kwargs({
+ 'filter_exp': FilterExpField(required=False, load_default=None),
+ }, location='query')
+ def test_route(filter_exp: FilterExpression):
+ return make_flask_response(data=filter_exp)
+
+ resp = self.get_helper('/test', use_auth=False)
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+
+ resp = self.get_helper('/test?filter_exp=invalid', use_auth=False)
+ self.assertEqual(resp.status_code, HTTPStatus.BAD_REQUEST)
+
+ resp = self.get_helper('/test?filter_exp=(x%3D123)', use_auth=False)
+ self.assertEqual(resp.status_code, HTTPStatus.OK)
+ data = self.get_response_data(resp)
+ self.assertEqual(data['simple_exp']['field'], 'x')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/hooks.py b/web_console_v2/api/fedlearner_webconsole/utils/hooks.py
index 25d0b9a6b..b9327d3a7 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/hooks.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/hooks.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,18 +14,37 @@
# coding: utf-8
import importlib
+from typing import Any
from envs import Envs
-from fedlearner_webconsole.db import db_handler as db, get_database_uri
+from fedlearner_webconsole.db import db, get_database_uri
+from fedlearner_webconsole.middleware.middlewares import flask_middlewares
+from fedlearner_webconsole.middleware.request_id import FlaskRequestId
+from fedlearner_webconsole.middleware.api_latency import api_latency_middleware
+
+
+def parse_and_get_fn(module_fn_path: str) -> Any:
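+ """Parses a hook path of the form 'module.path:function_name' and returns the referenced callable."""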
+ if module_fn_path.find(':') == -1:
+ raise RuntimeError(f'Invalid module_fn_path: {module_fn_path}')
+
+ module_path, func_name = module_fn_path.split(':')
+ try:
+ module = importlib.import_module(module_path)
+ fn = getattr(module, func_name)
+ except (ModuleNotFoundError, AttributeError) as e:
+ raise RuntimeError(f'No function found for hook {module_fn_path}') from e
+ # Returns the function object; the caller decides when to invoke it
+ return fn
def pre_start_hook():
before_hook_path = Envs.PRE_START_HOOK
if before_hook_path:
- module_path, func_name = before_hook_path.split(':')
- module = importlib.import_module(module_path)
- # Dynamically run the function
- getattr(module, func_name)()
+ parse_and_get_fn(before_hook_path)()
# explicit rebind db engine to make hook work
db.rebind(get_database_uri())
+
+ # Applies middlewares
+ flask_middlewares.register(FlaskRequestId())
+ flask_middlewares.register(api_latency_middleware)
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/hooks_test.py b/web_console_v2/api/fedlearner_webconsole/utils/hooks_test.py
new file mode 100644
index 000000000..84f5e3743
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/hooks_test.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import unittest
+
+from fedlearner_webconsole.utils.hooks import parse_and_get_fn
+
+
+class HookTest(unittest.TestCase):
+
+ def test_parse_and_get_fn(self):
+ # valid hook path that resolves to a function
+ right_hook = 'testing.test_data.hello:hello'
+ self.assertEqual(parse_and_get_fn(right_hook)(), 1)
+
+ # nonexistent module raises RuntimeError
+ nonexistent_hook = 'hello:hello'
+ with self.assertRaises(RuntimeError):
+ parse_and_get_fn(nonexistent_hook)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/images.py b/web_console_v2/api/fedlearner_webconsole/utils/images.py
new file mode 100644
index 000000000..0464ca4ac
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/images.py
@@ -0,0 +1,23 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from fedlearner_webconsole.setting.service import SettingService
+
+
+def generate_unified_version_image(image_prefix: str) -> str:
+ # Strip an existing tag from image_prefix before appending the unified version tag
+ if image_prefix.find(':') != -1:
+ image_prefix = image_prefix.rsplit(':', 1)[0]
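+ # The current application version becomes the tag, so generated images share one unified version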
+ return f'{image_prefix}:{SettingService.get_application_version().version.version}'
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/images_test.py b/web_console_v2/api/fedlearner_webconsole/utils/images_test.py
new file mode 100644
index 000000000..e8ebcc5fb
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/images_test.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fedlearner_webconsole.utils.images import generate_unified_version_image
+
+
+class ImageUtilsTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.utils.images.SettingService.get_application_version')
+ def test_generate_unified_version_image(self, mock_get_application_version: MagicMock):
+ mock_version = MagicMock()
+ mock_version.version.version = '2.2.2.2'
+ mock_get_application_version.return_value = mock_version
+ self.assertEqual(generate_unified_version_image('artifact.bytedance.com/fedlearner/pp_data_inspection'),
+ 'artifact.bytedance.com/fedlearner/pp_data_inspection:2.2.2.2')
+ self.assertEqual(generate_unified_version_image('artifact.bytedance.com/fedlearner/pp_data_inspection:test'),
+ 'artifact.bytedance.com/fedlearner/pp_data_inspection:2.2.2.2')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/job_metrics.py b/web_console_v2/api/fedlearner_webconsole/utils/job_metrics.py
new file mode 100644
index 000000000..7ad8f9976
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/job_metrics.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import os
+import logging
+import tensorflow.compat.v1 as tf
+from typing import Dict
+from google.protobuf import text_format
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.proto.tree_model_pb2 import BoostingTreeEnsambleProto
+from fedlearner_webconsole.utils.file_manager import file_manager
+
+
+def get_feature_importance(job: Job) -> Dict[str, float]:
+ storage_root_dir = job.project.get_storage_root_path(None)
+ if storage_root_dir is None:
+ return {}
+ job_name = job.name
+ path = os.path.join(storage_root_dir, 'job_output', job_name, 'exported_models')
+ if not file_manager.exists(path):
+ return {}
+ fin = tf.io.gfile.GFile(path, 'r')
+ model = BoostingTreeEnsambleProto()
+ try:
+ text_format.Parse(fin.read(), model, allow_unknown_field=True)
+ except Exception as e: # pylint: disable=broad-except
+ logging.warning('parsing tree proto with error %s', str(e))
+ return {}
+ fscore = model.feature_importance
+ feature_names = list(model.feature_names)
+ cat_feature_names = list(model.cat_feature_names)
+ feature_names.extend(cat_feature_names)
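+ # Fall back to positional names (f0, f1, ...) when the model carries no feature names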
+ if len(feature_names) == 0:
+ feature_names = [f'f{i}' for i in range(len(fscore))]
+ feature_importance = {}
+ for i, name in enumerate(feature_names):
+ feature_importance[name] = fscore[i]
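+ # Scores beyond the locally known feature names belong to the peer's features, hence the peer_f prefix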
+ for j in range(len(feature_names), len(fscore)):
+ feature_importance[f'peer_f{j}'] = fscore[j]
+ return feature_importance
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/k8s_cache.py b/web_console_v2/api/fedlearner_webconsole/utils/k8s_cache.py
deleted file mode 100644
index 44783205b..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/k8s_cache.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import threading
-from enum import Enum
-
-
-class EventType(Enum):
- ADDED = 'ADDED'
- MODIFIED = 'MODIFIED'
- DELETED = 'DELETED'
-
-
-class ObjectType(Enum):
- POD = 'POD'
- FLAPP = 'FLAPP'
-
-
-class Event(object):
- def __init__(self, flapp_name, event_type, obj_type, obj_dict):
- self.flapp_name = flapp_name
- self.event_type = event_type
- self.obj_type = obj_type
- # {'status': {}, 'metadata': {}}
- self.obj_dict = obj_dict
-
- @staticmethod
- def from_json(event, obj_type):
- # TODO(xiangyuxuan): move this to k8s/models.py
- event_type = event['type']
- obj = event['object']
- if obj_type == ObjectType.POD:
- obj = obj.to_dict()
- metadata = obj.get('metadata')
- status = obj.get('status')
- flapp_name = metadata['labels']['app-name']
- return Event(flapp_name,
- EventType(event_type),
- obj_type,
- obj_dict={'status': status,
- 'metadata': metadata})
- metadata = obj.get('metadata')
- status = obj.get('status')
- # put event to queue
- return Event(metadata['name'],
- EventType(event_type),
- obj_type,
- obj_dict={'status': status})
-
-
-class K8sCache(object):
-
- def __init__(self):
- self._lock = threading.Lock()
- # key: flapp_name, value: a dict
- # {'flapp': flapp cache, 'pods': pods cache,
- # 'deleted': is flapp deleted}
- self._cache = {}
-
- # TODO(xiangyuxuan): use class instead of json to manage cache and queue
- def update_cache(self, event: Event):
- with self._lock:
- flapp_name = event.flapp_name
- if flapp_name not in self._cache:
- self._cache[flapp_name] = {'pods': {'items': []},
- 'deleted': False}
- # if not flapp's then pod's event
- if event.obj_type == ObjectType.FLAPP:
- if event.event_type == EventType.DELETED:
- self._cache[flapp_name] = {'pods': {'items': []},
- 'deleted': True}
- else:
- self._cache[flapp_name]['deleted'] = False
- self._cache[flapp_name]['flapp'] = event.obj_dict
- else:
- if self._cache[flapp_name]['deleted']:
- return
- existed = False
- for index, pod in enumerate(
- self._cache[flapp_name]['pods']['items']):
- if pod['metadata']['name'] == \
- event.obj_dict['metadata']['name']:
- existed = True
- self._cache[flapp_name]['pods']['items'][index] \
- = event.obj_dict
- break
- if not existed:
- self._cache[flapp_name]['pods'][
- 'items'].append(event.obj_dict)
-
- def get_cache(self, flapp_name):
- # use read-write lock to fast
- with self._lock:
- if flapp_name in self._cache:
- return self._cache[flapp_name]
- return {'flapp': None, 'pods': {'items': []}}
-
-
-k8s_cache = K8sCache()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/k8s_client.py b/web_console_v2/api/fedlearner_webconsole/utils/k8s_client.py
deleted file mode 100644
index 3106bdecd..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/k8s_client.py
+++ /dev/null
@@ -1,389 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import enum
-import logging
-import os
-from http import HTTPStatus
-from typing import Optional
-
-import kubernetes
-import requests
-from kubernetes import client
-from kubernetes.client.exceptions import ApiException
-
-from envs import Envs
-from fedlearner_webconsole.exceptions import (InvalidArgumentException,
- NotFoundException,
- ResourceConflictException,
- InternalException)
-from fedlearner_webconsole.utils.decorators import retry_fn
-from fedlearner_webconsole.utils.fake_k8s_client import FakeK8sClient
-from fedlearner_webconsole.utils.k8s_cache import k8s_cache
-
-
-class CrdKind(enum.Enum):
- FLAPP = 'flapps'
- SPARK_APPLICATION = 'sparkapplications'
-
-
-FEDLEARNER_CUSTOM_GROUP = 'fedlearner.k8s.io'
-FEDLEARNER_CUSTOM_VERSION = 'v1alpha1'
-
-SPARKOPERATOR_CUSTOM_GROUP = 'sparkoperator.k8s.io'
-SPARKOPERATOR_CUSTOM_VERSION = 'v1beta2'
-SPARKOPERATOR_NAMESPACE = Envs.K8S_NAMESPACE
-
-
-class K8sClient(object):
- def __init__(self):
- self.core = None
- self.crds = None
- self._networking = None
- self._app = None
- self._api_server_url = 'http://{}:{}'.format(
- os.environ.get('FL_API_SERVER_HOST', 'fedlearner-apiserver'),
- os.environ.get('FL_API_SERVER_PORT', 8101))
-
- def init(self, config_path: Optional[str] = None):
- # Sets config
- if config_path is None:
- kubernetes.config.load_incluster_config()
- else:
- kubernetes.config.load_kube_config(config_path)
- # Inits API clients
- self.core = client.CoreV1Api()
- self.crds = client.CustomObjectsApi()
- self._networking = client.NetworkingV1beta1Api()
- self._app = client.AppsV1Api()
-
- def close(self):
- self.core.api_client.close()
- self._networking.api_client.close()
-
- def _raise_runtime_error(self, exception: ApiException):
- raise RuntimeError('[{}] {}'.format(exception.status,
- exception.reason))
-
- def create_or_update_secret(self,
- data,
- metadata,
- secret_type,
- name,
- namespace='default'):
- """Create secret. If existed, then replace"""
- request = client.V1Secret(api_version='v1',
- data=data,
- kind='Secret',
- metadata=metadata,
- type=secret_type)
- try:
- self.core.read_namespaced_secret(name, namespace)
- # If the secret already exists, then we use patch to replace it.
- # We don't use replace method because it requires `resourceVersion`.
- self.core.patch_namespaced_secret(name, namespace, request)
- return
- except ApiException as e:
- # 404 is expected if the secret does not exist
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
- try:
- self.core.create_namespaced_secret(namespace, request)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def delete_secret(self, name, namespace='default'):
- try:
- self.core.delete_namespaced_secret(name, namespace)
- except ApiException as e:
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
-
- def get_secret(self, name, namespace='default'):
- try:
- return self.core.read_namespaced_secret(name, namespace)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def create_or_update_service(self,
- metadata,
- spec,
- name,
- namespace='default'):
- """Create secret. If existed, then replace"""
- request = client.V1Service(api_version='v1',
- kind='Service',
- metadata=metadata,
- spec=spec)
- try:
- self.core.read_namespaced_service(name, namespace)
- # If the service already exists, then we use patch to replace it.
- # We don't use replace method because it requires `resourceVersion`.
- self.core.patch_namespaced_service(name, namespace, request)
- return
- except ApiException as e:
- # 404 is expected if the service does not exist
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
- try:
- self.core.create_namespaced_service(namespace, request)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def delete_service(self, name, namespace='default'):
- try:
- self.core.delete_namespaced_service(name, namespace)
- except ApiException as e:
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
-
- def get_service(self, name, namespace='default'):
- try:
- return self.core.read_namespaced_service(name, namespace)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def create_or_update_ingress(self,
- metadata,
- spec,
- name,
- namespace='default'):
- request = client.NetworkingV1beta1Ingress(
- api_version='networking.k8s.io/v1beta1',
- kind='Ingress',
- metadata=metadata,
- spec=spec)
- try:
- self._networking.read_namespaced_ingress(name, namespace)
- # If the ingress already exists, then we use patch to replace it.
- # We don't use replace method because it requires `resourceVersion`.
- self._networking.patch_namespaced_ingress(name, namespace, request)
- return
- except ApiException as e:
- # 404 is expected if the ingress does not exist
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
- try:
- self._networking.create_namespaced_ingress(namespace, request)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def delete_ingress(self, name, namespace='default'):
- try:
- self._networking.delete_namespaced_ingress(name, namespace)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def get_ingress(self, name, namespace='default'):
- try:
- return self._networking.read_namespaced_ingress(name, namespace)
- except ApiException as e:
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
-
- def create_or_update_deployment(self,
- metadata,
- spec,
- name,
- namespace='default'):
- request = client.V1Deployment(api_version='apps/v1',
- kind='Deployment',
- metadata=metadata,
- spec=spec)
- try:
- self._app.read_namespaced_deployment(name, namespace)
- # If the deployment already exists, then we use patch to replace it.
- # We don't use replace method because it requires `resourceVersion`.
- self._app.patch_namespaced_deployment(name, namespace, request)
- return
- except ApiException as e:
- # 404 is expected if the deployment does not exist
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
- try:
- self._app.create_namespaced_deployment(namespace, request)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def delete_deployment(self, name, namespace='default'):
- try:
- self._app.delete_namespaced_deployment(name, namespace)
- except ApiException as e:
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
-
- def get_deployment(self, name, namespace='default'):
- try:
- return self._app.read_namespaced_deployment(name, namespace)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- @retry_fn(retry_times=3)
- def delete_flapp(self, flapp_name):
- try:
- self.crds.delete_namespaced_custom_object(
- group=FEDLEARNER_CUSTOM_GROUP,
- version=FEDLEARNER_CUSTOM_VERSION,
- namespace=Envs.K8S_NAMESPACE,
- plural=CrdKind.FLAPP.value,
- name=flapp_name)
- except ApiException as e:
- # If the flapp has been deleted then the exception gets ignored
- if e.status != HTTPStatus.NOT_FOUND:
- self._raise_runtime_error(e)
-
- @retry_fn(retry_times=3)
- def create_flapp(self, flapp_yaml):
- try:
- self.crds.create_namespaced_custom_object(
- group=FEDLEARNER_CUSTOM_GROUP,
- version=FEDLEARNER_CUSTOM_VERSION,
- namespace=Envs.K8S_NAMESPACE,
- plural=CrdKind.FLAPP.value,
- body=flapp_yaml)
- except ApiException as e:
- # If the flapp exists then we delete it
- if e.status == HTTPStatus.CONFLICT:
- self.delete_flapp(flapp_yaml['metadata']['name'])
- # Raise to make it retry
- raise
-
- def get_flapp(self, flapp_name):
- return k8s_cache.get_cache(flapp_name)
-
- def get_webshell_session(self,
- flapp_name: str,
- container_name: str,
- namespace='default'):
- response = requests.get(
- '{api_server_url}/namespaces/{namespace}/pods/{custom_object_name}/'
- 'shell/${container_name}'.format(
- api_server_url=self._api_server_url,
- namespace=namespace,
- custom_object_name=flapp_name,
- container_name=container_name))
- if response.status_code != HTTPStatus.OK:
- raise RuntimeError('{}:{}'.format(response.status_code,
- response.content))
- return response.json()
-
- def get_sparkapplication(self,
- name: str,
- namespace: str = SPARKOPERATOR_NAMESPACE) -> dict:
- """get sparkapp
-
- Args:
- name (str): sparkapp name
- namespace (str, optional): namespace to submit.
-
- Raises:
- ApiException
-
- Returns:
- dict: resp of k8s
- """
- try:
- return self.crds.get_namespaced_custom_object(
- group=SPARKOPERATOR_CUSTOM_GROUP,
- version=SPARKOPERATOR_CUSTOM_VERSION,
- namespace=namespace,
- plural=CrdKind.SPARK_APPLICATION.value,
- name=name)
- except ApiException as err:
- if err.status == 404:
- raise NotFoundException()
- raise InternalException(details=err.body)
-
- def create_sparkapplication(
- self,
- json_object: dict,
- namespace: str = SPARKOPERATOR_NAMESPACE) -> dict:
- """ create sparkapp
-
- Args:
- json_object (dict): json object of config
- namespace (str, optional): namespace to submit.
-
- Raises:
- ApiException
-
- Returns:
- dict: resp of k8s
- """
- try:
- logging.debug('create sparkapp json is %s', json_object)
- return self.crds.create_namespaced_custom_object(
- group=SPARKOPERATOR_CUSTOM_GROUP,
- version=SPARKOPERATOR_CUSTOM_VERSION,
- namespace=namespace,
- plural=CrdKind.SPARK_APPLICATION.value,
- body=json_object)
- except ApiException as err:
- if err.status == 409:
- raise ResourceConflictException(message=err.reason)
- if err.status == 400:
- raise InvalidArgumentException(details=err.reason)
- raise InternalException(details=err.body)
-
- def delete_sparkapplication(self,
- name: str,
- namespace: str = SPARKOPERATOR_NAMESPACE
- ) -> dict:
- """ delete sparkapp
-
- Args:
- name (str): sparkapp name
- namespace (str, optional): namespace to delete.
-
- Raises:
- ApiException
-
- Returns:
- dict: resp of k8s
- """
- try:
- return self.crds.delete_namespaced_custom_object(
- group=SPARKOPERATOR_CUSTOM_GROUP,
- version=SPARKOPERATOR_CUSTOM_VERSION,
- namespace=namespace,
- plural=CrdKind.SPARK_APPLICATION.value,
- name=name,
- body=client.V1DeleteOptions())
- except ApiException as err:
- if err.status == 404:
- raise NotFoundException()
- raise InternalException(details=err.body)
-
- def get_pod_log(self, name: str, namespace: str, tail_lines: int):
- try:
- return self.core.read_namespaced_pod_log(name=name,
- namespace=namespace,
- tail_lines=tail_lines)
- except ApiException as e:
- self._raise_runtime_error(e)
-
- def get_pods(self, namespace, label_selector):
- try:
- return self.core.list_namespaced_pod(namespace=namespace,
- label_selector=label_selector)
- except ApiException as e:
- self._raise_runtime_error(e)
-
-
-k8s_client = FakeK8sClient()
-if Envs.FLASK_ENV == 'production' or \
- Envs.K8S_CONFIG_PATH is not None:
- k8s_client = K8sClient()
- k8s_client.init(Envs.K8S_CONFIG_PATH)
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/k8s_watcher.py b/web_console_v2/api/fedlearner_webconsole/utils/k8s_watcher.py
deleted file mode 100644
index 22372a1ab..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/k8s_watcher.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import logging
-import threading
-import queue
-import traceback
-from http import HTTPStatus
-from kubernetes import client, watch
-from envs import Envs, Features
-from fedlearner_webconsole.utils.k8s_cache import k8s_cache, \
- Event, ObjectType
-from fedlearner_webconsole.utils.k8s_client import (
- k8s_client, FEDLEARNER_CUSTOM_GROUP,
- FEDLEARNER_CUSTOM_VERSION)
-from fedlearner_webconsole.mmgr.service import ModelService
-from fedlearner_webconsole.db import make_session_context
-from fedlearner_webconsole.job.service import JobService
-
-
-session_context = make_session_context()
-
-class K8sWatcher(object):
- def __init__(self):
- self._lock = threading.Lock()
- self._running = False
- self._flapp_watch_thread = None
- self._pods_watch_thread = None
- self._event_consumer_thread = None
-
- # https://stackoverflow.com/questions/62223424/
- # simplequeue-vs-queue-in-python-what-is-the-
- # advantage-of-using-simplequeue
- # if use simplequeue, put opt never block.
- # TODO(xiangyuxuan): change to simplequeue
- self._queue = queue.Queue()
- self._cache = {}
- self._cache_lock = threading.Lock()
-
- def start(self):
- with self._lock:
- if self._running:
- logging.warning('K8s watcher has already started')
- return
- self._running = True
- self._flapp_watch_thread = threading.Thread(
- target=self._k8s_flapp_watcher,
- name='flapp_watcher',
- daemon=True)
- self._pods_watch_thread = threading.Thread(
- target=self._k8s_pods_watch,
- name='pods_watcher',
- daemon=True)
- self._event_consumer_thread = threading.Thread(
- target=self._event_consumer,
- name='cache_consumer',
- daemon=True)
- self._pods_watch_thread.start()
- self._flapp_watch_thread.start()
- self._event_consumer_thread.start()
- logging.info('K8s watcher started')
-
- def _event_consumer(self):
- # TODO(xiangyuxuan): do more business level operations
- while True:
- try:
- event = self._queue.get()
- k8s_cache.update_cache(event)
- # job state must be updated before model service
- self._update_hook(event)
- if Features.FEATURE_MODEL_K8S_HOOK:
- with session_context() as session:
- ModelService(session).k8s_watcher_hook(event)
- session.commit()
- except Exception as e: # pylint: disable=broad-except
- logging.error(f'K8s event_consumer : {str(e)}. '
- f'traceback:{traceback.format_exc()}')
-
- def _update_hook(self, event: Event):
- if event.obj_type == ObjectType.FLAPP:
- logging.debug('[k8s_watcher][_update_hook]receive event %s',
- event.flapp_name)
- with session_context() as session:
- JobService(session).update_running_state(event.flapp_name)
- session.commit()
-
- def _k8s_flapp_watcher(self):
- resource_version = '0'
- watcher = watch.Watch()
- while True:
- logging.info(f'new stream of flapps watch rv:{resource_version}')
- if not self._running:
- watcher.stop()
- break
- # resource_version '0' means getting a recent resource without
- # consistency guarantee, this is to reduce the load of etcd.
- # Ref: https://kubernetes.io/docs/reference/using-api
- # /api-concepts/ #the-resourceversion-parameter
- stream = watcher.stream(
- k8s_client.crds.list_namespaced_custom_object,
- group=FEDLEARNER_CUSTOM_GROUP,
- version=FEDLEARNER_CUSTOM_VERSION,
- namespace=Envs.K8S_NAMESPACE,
- plural='flapps',
- resource_version=resource_version,
- _request_timeout=900, # Sometimes watch gets stuck
- )
- try:
- for event in stream:
-
- self._produce_event(event, ObjectType.FLAPP)
-
- metadata = event['object'].get('metadata')
- if metadata['resourceVersion'] is not None:
- resource_version = max(metadata['resourceVersion'],
- resource_version)
- logging.debug(
- f'resource_version now: {resource_version}')
- except client.exceptions.ApiException as e:
- logging.error(f'watcher:{str(e)}')
- if e.status == HTTPStatus.GONE:
- # It has been too old, resources should be relisted
- resource_version = '0'
- except Exception as e: # pylint: disable=broad-except
- logging.error(f'K8s watcher gets event error: {str(e)}',
- exc_info=True)
-
- def _produce_event(self, event, obj_type):
- self._queue.put(Event.from_json(event, obj_type))
-
- def _k8s_pods_watch(self):
- resource_version = '0'
- watcher = watch.Watch()
- while True:
- logging.info(f'new stream of pods watch rv: {resource_version}')
- if not self._running:
- watcher.stop()
- break
- # resource_version '0' means getting a recent resource without
- # consistency guarantee, this is to reduce the load of etcd.
- # Ref: https://kubernetes.io/docs/reference/using-api
- # /api-concepts/ #the-resourceversion-parameter
- stream = watcher.stream(
- k8s_client.core.list_namespaced_pod,
- namespace=Envs.K8S_NAMESPACE,
- label_selector='app-name',
- resource_version=resource_version,
- _request_timeout=900, # Sometimes watch gets stuck
- )
-
- try:
- for event in stream:
- self._produce_event(event, ObjectType.POD)
- metadata = event['object'].metadata
- if metadata.resource_version is not None:
- resource_version = max(metadata.resource_version,
- resource_version)
- logging.debug(
- f'resource_version now: {resource_version}')
- except client.exceptions.ApiException as e:
- logging.error(f'watcher:{str(e)}')
- if e.status == HTTPStatus.GONE:
- # It has been too old, resources should be relisted
- resource_version = '0'
- except Exception as e: # pylint: disable=broad-except
- logging.error(f'K8s watcher gets event error: {str(e)}',
- exc_info=True)
-
-
-k8s_watcher = K8sWatcher()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/kibana.py b/web_console_v2/api/fedlearner_webconsole/utils/kibana.py
index d9271adcf..1a7824cd3 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/kibana.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/kibana.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.
# coding: utf-8
-# pylint: disable=invalid-string-quote
+# pylint: disable=invalid-string-quote,missing-type-doc,missing-return-type-doc,consider-using-f-string
import hashlib
import os
import re
@@ -37,34 +37,22 @@ class Kibana(object):
"""
TSVB = ('Rate', 'Ratio', 'Numeric')
TIMELION = ('Time', 'Timer')
- RISON_REPLACEMENT = {' ': '%20',
- '"': '%22',
- '#': '%23',
- '%': '%25',
- '&': '%26',
- '+': '%2B',
- '/': '%2F',
- '=': '%3D'}
- TIMELION_QUERY_REPLACEMENT = {' and ': ' AND ',
- ' or ': ' OR '}
+ RISON_REPLACEMENT = {' ': '%20', '"': '%22', '#': '%23', '%': '%25', '&': '%26', '+': '%2B', '/': '%2F', '=': '%3D'}
+ TIMELION_QUERY_REPLACEMENT = {' and ': ' AND ', ' or ': ' OR '}
LOGICAL_PATTERN = re.compile(' and | or ', re.IGNORECASE)
- TSVB_AGG_TYPE = {'Average': 'avg',
- 'Sum': 'sum',
- 'Max': 'max',
- 'Min': 'min',
- 'Variance': 'variance',
- 'Std. Deviation': 'std_deviation',
- 'Sum of Squares': 'sum_of_squares'}
- TIMELION_AGG_TYPE = {'Average': 'avg',
- 'Sum': 'sum',
- 'Max': 'max',
- 'Min': 'min'}
- COLORS = ['#DA6E6E', '#FA8080', '#789DFF',
- '#66D4FF', '#6EB518', '#9AF02E']
+ TSVB_AGG_TYPE = {
+ 'Average': 'avg',
+ 'Sum': 'sum',
+ 'Max': 'max',
+ 'Min': 'min',
+ 'Variance': 'variance',
+ 'Std. Deviation': 'std_deviation',
+ 'Sum of Squares': 'sum_of_squares'
+ }
+ TIMELION_AGG_TYPE = {'Average': 'avg', 'Sum': 'sum', 'Max': 'max', 'Min': 'min'}
+ COLORS = ['#DA6E6E', '#FA8080', '#789DFF', '#66D4FF', '#6EB518', '#9AF02E']
# metrics* for all other job types
- JOB_INDEX = {JobType.RAW_DATA: 'raw_data',
- JobType.DATA_JOIN: 'data_join',
- JobType.PSI_DATA_JOIN: 'data_join'}
+ JOB_INDEX = {JobType.RAW_DATA: 'raw_data', JobType.DATA_JOIN: 'data_join', JobType.PSI_DATA_JOIN: 'data_join'}
BASIC_QUERY = "app/kibana#/visualize/create" \
"?type={type}&_g=(refreshInterval:(pause:!t,value:0)," \
"time:(from:'{start_time}',to:'{end_time}'))&" \
@@ -83,17 +71,10 @@ def remote_query(job, args):
if 'query' in args and args['query']:
panel['filter']['query'] += ' and ({})'.format(args['query'])
st, et = Kibana._parse_start_end_time(args, use_now=False)
- req = {
- 'timerange': {
- 'timezone': Envs.TZ.zone,
- 'min': st,
- 'max': et
- },
- 'panels': [panel]
- }
- res = requests.post(
- os.path.join(Envs.KIBANA_SERVICE_ADDRESS, 'api/metrics/vis/data'),
- json=req, headers={'kbn-xsrf': 'true'})
+ req = {'timerange': {'timezone': Envs.TZ.zone, 'min': st, 'max': et}, 'panels': [panel]}
+ res = requests.post(os.path.join(Envs.KIBANA_SERVICE_ADDRESS, 'api/metrics/vis/data'),
+ json=req,
+ headers={'kbn-xsrf': 'true'})
try:
res.raise_for_status()
@@ -102,18 +83,15 @@ def remote_query(job, args):
data = list(map(lambda x: [x[0], x[1] or 0], data))
return data
except Exception as e: # pylint: disable=broad-except
- raise InternalException(repr(e))
+ raise InternalException(repr(e)) from e
@staticmethod
def _check_remote_args(args):
- for arg in ('type', 'interval', 'x_axis_field',
- 'start_time', 'end_time'):
+ for arg in ('type', 'interval', 'x_axis_field', 'start_time', 'end_time'):
Kibana._check_present(args, arg)
Kibana._check_authorization(args.get('query'))
- Kibana._check_authorization(args['x_axis_field'],
- extra_allowed={'tags.event_time',
- 'tags.process_time'})
+ Kibana._check_authorization(args['x_axis_field'], extra_allowed={'tags.event_time', 'tags.process_time'})
if args['type'] == 'Ratio':
for arg in ('numerator', 'denominator'):
@@ -128,8 +106,7 @@ def _check_remote_args(args):
@staticmethod
def _check_present(args, arg_name):
if arg_name not in args or args[arg_name] is None:
- raise InvalidArgumentException(
- 'Missing required argument [{}].'.format(arg_name))
+ raise InvalidArgumentException('Missing required argument [{}].'.format(arg_name))
@staticmethod
def _check_authorization(arg, extra_allowed: set = None):
@@ -140,8 +117,7 @@ def _check_authorization(arg, extra_allowed: set = None):
if not query:
continue
if query.split(':')[0] not in allowed_fields:
- raise UnauthorizedException(
- 'Query [{}] is not authorized.'.format(query))
+ raise UnauthorizedException('Query [{}] is not authorized.'.format(query))
@staticmethod
def create_tsvb(job, args):
@@ -163,17 +139,14 @@ def create_tsvb(job, args):
vis_state['params']['filter']['query'] += \
' and ({})'.format(args['query'])
# rison-ify and replace
- vis_state = Kibana._regex_process(
- prison.dumps(vis_state), Kibana.RISON_REPLACEMENT
- )
+ vis_state = Kibana._regex_process(prison.dumps(vis_state), Kibana.RISON_REPLACEMENT)
start_time, end_time = Kibana._parse_start_end_time(args)
# a single-item list
return [
- os.path.join(Envs.KIBANA_ADDRESS,
- Kibana.BASIC_QUERY.format(type='metrics',
- start_time=start_time,
- end_time=end_time,
- vis_state=vis_state))
+ os.path.join(
+ Envs.KIBANA_ADDRESS,
+ Kibana.BASIC_QUERY.format(type='metrics', start_time=start_time, end_time=end_time,
+ vis_state=vis_state))
]
@staticmethod
@@ -190,16 +163,11 @@ def create_timelion(job, args):
vis_states, times = Kibana._create_timer_visualization(job, args)
# a generator, rison-ify and replace
- vis_states = (
- Kibana._regex_process(vs, Kibana.RISON_REPLACEMENT)
- for vs in map(prison.dumps, vis_states)
- )
+ vis_states = (Kibana._regex_process(vs, Kibana.RISON_REPLACEMENT) for vs in map(prison.dumps, vis_states))
return [
- os.path.join(Envs.KIBANA_ADDRESS,
- Kibana.BASIC_QUERY.format(type='timelion',
- start_time=start,
- end_time=end,
- vis_state=vis_state))
+ os.path.join(
+ Envs.KIBANA_ADDRESS,
+ Kibana.BASIC_QUERY.format(type='timelion', start_time=start, end_time=end, vis_state=vis_state))
for (start, end), vis_state in zip(times, vis_states)
]
@@ -210,15 +178,13 @@ def _parse_start_end_time(args, use_now=True):
else Kibana._normalize_datetime(
datetime.now(tz=pytz.utc) - timedelta(days=365 * 5))
else:
- st = Kibana._normalize_datetime(
- datetime.fromtimestamp(args['start_time'], tz=pytz.utc))
+ st = Kibana._normalize_datetime(datetime.fromtimestamp(args['start_time'], tz=pytz.utc))
if args['end_time'] < 0:
et = 'now' if use_now \
else Kibana._normalize_datetime(datetime.now(tz=pytz.utc))
else:
- et = Kibana._normalize_datetime(
- datetime.fromtimestamp(args['end_time'], tz=pytz.utc))
+ et = Kibana._normalize_datetime(datetime.fromtimestamp(args['end_time'], tz=pytz.utc))
return st, et
@staticmethod
@@ -238,10 +204,7 @@ def _regex_process(string, replacement):
re_mode = re.IGNORECASE
escaped_keys = map(re.escape, replacement)
pattern = re.compile("|".join(escaped_keys), re_mode)
- return pattern.sub(
- lambda match: replacement[match.group(0).lower()],
- string
- )
+ return pattern.sub(lambda match: replacement[match.group(0).lower()], string)
@staticmethod
def _create_rate_visualization(job, args):
@@ -259,51 +222,46 @@ def _create_rate_visualization(job, args):
params = vis_state['params']
# `w/`, `w/o` = `with`, `without`
# Total w/ Fake series
- twf = Kibana._tsvb_series(
- label='Total w/ Fake',
- metrics={'type': 'count'}
- )
+ twf = Kibana._tsvb_series(label='Total w/ Fake', metrics={'type': 'count'})
# Total w/o Fake series
twof = Kibana._tsvb_series(
label='Total w/o Fake',
metrics={'type': 'count'},
# unjoined and normal joined
- series_filter={'query': 'tags.joined: "-1" or tags.joined: 1'}
- )
+ series_filter={'query': 'tags.joined: "-1" or tags.joined: 1'})
# Joined w/ Fake series
jwf = Kibana._tsvb_series(
label='Joined w/ Fake',
metrics={'type': 'count'},
# faked joined and normal joined
- series_filter={'query': 'tags.joined: 0 or tags.joined: 1'}
- )
+ series_filter={'query': 'tags.joined: 0 or tags.joined: 1'})
# Joined w/o Fake series
jwof = Kibana._tsvb_series(
label='Joined w/o Fake',
metrics={'type': 'count'},
# normal joined
- series_filter={'query': 'tags.joined: 1'}
- )
+ series_filter={'query': 'tags.joined: 1'})
# Join Rate w/ Fake series
jrwf = Kibana._tsvb_series(
series_type='ratio',
label='Join Rate w/ Fake',
- metrics={'numerator': 'tags.joined: 1 or tags.joined: 0',
- 'denominator': '*', # joined == -1 or 0 or 1
- 'type': 'filter_ratio'},
+ metrics={
+ 'numerator': 'tags.joined: 1 or tags.joined: 0',
+ 'denominator': '*', # joined == -1 or 0 or 1
+ 'type': 'filter_ratio'
+ },
line_width='2',
- fill='0'
- )
+ fill='0')
# Join Rate w/o Fake series
- jrwof = Kibana._tsvb_series(
- series_type='ratio',
- label='Join Rate w/o Fake',
- metrics={'numerator': 'tags.joined: 1',
- 'denominator': 'tags.joined: 1 or tags.joined: "-1"',
- 'type': 'filter_ratio'},
- line_width='2',
- fill='0'
- )
+ jrwof = Kibana._tsvb_series(series_type='ratio',
+ label='Join Rate w/o Fake',
+ metrics={
+ 'numerator': 'tags.joined: 1',
+ 'denominator': 'tags.joined: 1 or tags.joined: "-1"',
+ 'type': 'filter_ratio'
+ },
+ line_width='2',
+ fill='0')
series = [twf, twof, jwf, jwof, jrwf, jrwof]
for series_, color in zip(series, Kibana.COLORS):
series_['color'] = color
@@ -323,37 +281,34 @@ def _create_ratio_visualization(job, args):
Returns:
dict. A Kibana vis state dict
+ Raises:
+ ValueError: if a required arg is missing
+
This method will create 3 time series and stack them in vis state.
"""
for k in ('numerator', 'denominator'):
if k not in args or args[k] is None:
- raise ValueError(
- '[{}] should be provided in Ratio visualization'.format(k)
- )
+ raise ValueError('[{}] should be provided in Ratio visualization'.format(k))
vis_state = Kibana._basic_tsvb_vis_state(job, args)
params = vis_state['params']
# Denominator series
- denominator = Kibana._tsvb_series(
- label=args['denominator'],
- metrics={'type': 'count'},
- series_filter={'query': args['denominator']}
- )
+ denominator = Kibana._tsvb_series(label=args['denominator'],
+ metrics={'type': 'count'},
+ series_filter={'query': args['denominator']})
# Numerator series
- numerator = Kibana._tsvb_series(
- label=args['numerator'],
- metrics={'type': 'count'},
- series_filter={'query': args['numerator']}
- )
+ numerator = Kibana._tsvb_series(label=args['numerator'],
+ metrics={'type': 'count'},
+ series_filter={'query': args['numerator']})
# Ratio series
- ratio = Kibana._tsvb_series(
- series_type='ratio',
- label='Ratio',
- metrics={'numerator': args['numerator'],
- 'denominator': args['denominator'],
- 'type': 'filter_ratio'},
- line_width='2',
- fill='0'
- )
+ ratio = Kibana._tsvb_series(series_type='ratio',
+ label='Ratio',
+ metrics={
+ 'numerator': args['numerator'],
+ 'denominator': args['denominator'],
+ 'type': 'filter_ratio'
+ },
+ line_width='2',
+ fill='0')
series = [denominator, numerator, ratio]
for series_, color in zip(series, Kibana.COLORS[1::2]):
series_['color'] = color
@@ -371,6 +326,9 @@ def _create_numeric_visualization(job, args):
Returns:
dict. A Kibana vis state dict
+ Raises:
+ ValueError: if a required arg is missing
+
This method will create 1 time series. The series will filter data
further by `name: args['metric_name']`. Aggregation will be
applied on data's `args['value_field']` field. Aggregation types
@@ -378,21 +336,17 @@ def _create_numeric_visualization(job, args):
"""
for k in ('aggregator', 'value_field'):
if k not in args or args[k] is None:
- raise ValueError(
- '[{}] should be provided in Numeric visualization.'
- .format(k)
- )
+ raise ValueError('[{}] should be provided in Numeric visualization.'.format(k))
assert args['aggregator'] in Kibana.TSVB_AGG_TYPE
vis_state = Kibana._basic_tsvb_vis_state(job, args)
params = vis_state['params']
- series = Kibana._tsvb_series(
- label='{} of {}'.format(args['aggregator'],
- args['value_field']),
- metrics={'type': Kibana.TSVB_AGG_TYPE[args['aggregator']],
- 'field': args['value_field']},
- line_width=2,
- fill='0.5'
- )
+ series = Kibana._tsvb_series(label='{} of {}'.format(args['aggregator'], args['value_field']),
+ metrics={
+ 'type': Kibana.TSVB_AGG_TYPE[args['aggregator']],
+ 'field': args['value_field']
+ },
+ line_width=2,
+ fill='0.5')
series['color'] = Kibana.COLORS[-2]
params['series'] = [series]
return vis_state
@@ -422,23 +376,14 @@ def _create_time_visualization(job, args):
for t1, t2 in ((et, pt), (pt, et)):
# t1 vs t2, max/min/median of t1 as Y axis, t2 as X axis
# aggregate on t1 and histogram on t2
- max_series = Kibana._timelion_series(
- query=query, index=index,
- metric='max:' + t1, timefield=t2
- )
- min_series = Kibana._timelion_series(
- query=query, index=index,
- metric='min:' + t1, timefield=t2
- )
- median_series = Kibana._timelion_series(
- query=query, index=index,
- metric='percentiles:' + t1 + ':50', timefield=t2
- )
+ max_series = Kibana._timelion_series(query=query, index=index, metric='max:' + t1, timefield=t2)
+ min_series = Kibana._timelion_series(query=query, index=index, metric='min:' + t1, timefield=t2)
+ median_series = Kibana._timelion_series(query=query,
+ index=index,
+ metric='percentiles:' + t1 + ':50',
+ timefield=t2)
series = ','.join((max_series, min_series, median_series))
- vis_state = {"type": "timelion",
- "params": {"expression": series,
- "interval": interval},
- "aggs": []}
+ vis_state = {"type": "timelion", "params": {"expression": series, "interval": interval}, "aggs": []}
vis_states.append(vis_state)
by_pt_start = Kibana._get_start_from_job(job)
by_pt_end = 'now'
@@ -451,12 +396,10 @@ def _create_timer_visualization(job, args):
if not names:
return [], []
# split by comma, strip whitespaces of each name, filter out empty ones
- args['timer_names'] = [name for name in
- map(str.strip, names.split(',')) if name]
+ args['timer_names'] = [name for name in map(str.strip, names.split(',')) if name]
if args['aggregator'] not in Kibana.TIMELION_AGG_TYPE:
- raise TypeError('Aggregator [{}] is not supported in Timer '
- 'visualization.'.format(args['aggregator']))
+ raise TypeError('Aggregator [{}] is not supported in Timer visualization.'.format(args['aggregator']))
metric = '{}:value'.format(Kibana.TIMELION_AGG_TYPE[args['aggregator']])
query = 'tags.application_id:{}'.format(job.name)
@@ -465,25 +408,29 @@ def _create_timer_visualization(job, args):
interval = args['interval'] if args['interval'] != '' else 'auto'
series = []
for timer in args['timer_names']:
- s = Kibana._timelion_series(
- query=query + ' AND name:{}'.format(timer), index='metrics*',
- metric=metric, timefield='tags.process_time'
- )
+ s = Kibana._timelion_series(query=query + ' AND name:{}'.format(timer),
+ index='metrics*',
+ metric=metric,
+ timefield='tags.process_time')
series.append(s)
if args['split']:
# split series to different plots
vis_states = [{
"type": "timelion",
- "params": {"expression": s,
- "interval": interval},
+ "params": {
+ "expression": s,
+ "interval": interval
+ },
"aggs": []
} for s in series]
else:
# multiple series in one plot, a single-item list
vis_states = [{
"type": "timelion",
- "params": {"expression": ','.join(series),
- "interval": interval},
+ "params": {
+ "expression": ','.join(series),
+ "interval": interval
+ },
"aggs": []
}]
start = Kibana._get_start_from_job(job)
@@ -508,27 +455,29 @@ def _basic_tsvb_vis_state(job, args):
"""
assert 'x_axis_field' in args and args['x_axis_field']
- vis_state = {"aggs": [],
- "params": {"axis_formatter": "number",
- "axis_min": "",
- "axis_position": "left",
- "axis_scale": "normal",
- "default_index_pattern": "metrics*",
- "filter": {},
- "index_pattern": "",
- "interval": "",
- "isModelInvalid": False,
- "show_grid": 1,
- "show_legend": 1,
- "time_field": "",
- "type": "timeseries"}}
+ vis_state = {
+ "aggs": [],
+ "params": {
+ "axis_formatter": "number",
+ "axis_min": "",
+ "axis_position": "left",
+ "axis_scale": "normal",
+ "default_index_pattern": "metrics*",
+ "filter": {},
+ "index_pattern": "",
+ "interval": "",
+ "isModelInvalid": False,
+ "show_grid": 1,
+ "show_legend": 1,
+ "time_field": "",
+ "type": "timeseries"
+ }
+ }
params = vis_state['params']
params['interval'] = args.get('interval', '')
params['index_pattern'] = Kibana.JOB_INDEX \
.get(job.job_type, 'metrics') + '*'
- params['filter'] = Kibana._filter_query(
- 'tags.application_id:"{}"'.format(job.name)
- )
+ params['filter'] = Kibana._filter_query('tags.application_id:"{}"'.format(job.name))
params['time_field'] = args['x_axis_field']
return vis_state
@@ -547,7 +496,8 @@ def _tsvb_series(series_type='normal', **kwargs):
'series_filter': dict, additional filter on data,
only applied on this series.
- Returns: dict, a Kibana TSVB visualization time series definition
+ Returns:
+ dict, a Kibana TSVB visualization time series definition
"""
# series_id is meaningless and arbitrary to us but necessary
@@ -576,9 +526,7 @@ def _tsvb_series(series_type='normal', **kwargs):
}
if 'series_filter' in kwargs and 'query' in kwargs['series_filter']:
series['split_mode'] = 'filter'
- series['filter'] = Kibana._filter_query(
- kwargs['series_filter']['query']
- )
+ series['filter'] = Kibana._filter_query(kwargs['series_filter']['query'])
if series_type == 'ratio':
# if this is a ratio series, split axis and set axis range
series['separate_axis'] = 1
@@ -591,9 +539,7 @@ def _timelion_series(**kwargs):
assert 'metric' in kwargs
assert 'timefield' in kwargs
# convert all logical `and` and `or` to `AND` and `OR`
- query = Kibana._regex_process(
- kwargs.get('query', '*'), Kibana.TIMELION_QUERY_REPLACEMENT
- )
+ query = Kibana._regex_process(kwargs.get('query', '*'), Kibana.TIMELION_QUERY_REPLACEMENT)
return ".es(q=\"{query}\", index={index}, " \
"metric={metric}, timefield={timefield})" \
".legend(showTime=true)" \
@@ -602,8 +548,10 @@ def _timelion_series(**kwargs):
@staticmethod
def _filter_query(query):
- return {'language': 'kuery', # Kibana query
- 'query': query}
+ return {
+ 'language': 'kuery', # Kibana query
+ 'query': query
+ }
@staticmethod
def _get_start_from_job(job):
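
Note on the reformatted `_regex_process` calls above: the same single-pass, multi-key substitution can be sketched standalone as below (function name reused purely for illustration; this snippet is not part of the patch).

```python
import re

# The RISON-unsafe characters escaped before embedding a vis state in the URL.
RISON_REPLACEMENT = {' ': '%20', '"': '%22', '#': '%23', '%': '%25',
                     '&': '%26', '+': '%2B', '/': '%2F', '=': '%3D'}

def regex_process(string: str, replacement: dict) -> str:
    # One alternation pattern over all keys, substituted in a single pass.
    pattern = re.compile('|'.join(map(re.escape, replacement)), re.IGNORECASE)
    return pattern.sub(lambda m: replacement[m.group(0).lower()], string)

print(regex_process('q="a and b" / c', RISON_REPLACEMENT))  # q%3D%22a%20and%20b%22%20%2F%20c
```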
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/kibana_test.py b/web_console_v2/api/fedlearner_webconsole/utils/kibana_test.py
new file mode 100644
index 000000000..881c93331
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/kibana_test.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=protected-access
+import unittest
+
+from fedlearner_webconsole.exceptions import UnauthorizedException
+from fedlearner_webconsole.utils.kibana import Kibana
+
+
+class KibanaTest(unittest.TestCase):
+
+ def test_auth(self):
+ self.assertRaises(UnauthorizedException, Kibana._check_authorization, 'tags.1')
+ self.assertRaises(UnauthorizedException, Kibana._check_authorization, 'tags.1:2')
+ self.assertRaises(UnauthorizedException, Kibana._check_authorization, 'x:3 and y:4', {'x'})
+ self.assertRaises(UnauthorizedException, Kibana._check_authorization, 'x:3 OR y:4 AND z:5', {'x', 'z'})
+ try:
+ Kibana._check_authorization('x:1', {'x'})
+ Kibana._check_authorization('x:1 AND y:2 OR z:3', {'x', 'y', 'z'})
+ Kibana._check_authorization('x:1 oR y:2 aNd z:3', {'x', 'y', 'z'})
+ Kibana._check_authorization('*', {'x', '*'})
+ Kibana._check_authorization(None, None)
+ except UnauthorizedException:
+ self.fail()
+
+ def test_parse_time(self):
+ dt1 = 0
+ dt2 = 60 * 60 * 24
+ args = {'start_time': dt1, 'end_time': dt2}
+ st, et = Kibana._parse_start_end_time(args)
+ self.assertEqual(st, '1970-01-01T00:00:00Z')
+ self.assertEqual(et, '1970-01-02T00:00:00Z')
+ st, et = Kibana._parse_start_end_time({'start_time': -1, 'end_time': -1})
+ self.assertEqual(st, 'now-5y')
+ self.assertEqual(et, 'now')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/metrics.py b/web_console_v2/api/fedlearner_webconsole/utils/metrics.py
index c7c1971e2..d468e8703 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/metrics.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/metrics.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
@@ -15,11 +15,56 @@
# coding: utf-8
import logging
from abc import ABCMeta, abstractmethod
+import sys
+from typing import Dict, Union
+from threading import Lock
+
+from opentelemetry import trace, _metrics as metrics
+from opentelemetry._metrics.instrument import UpDownCounter
+from opentelemetry._metrics.measurement import Measurement
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk._metrics import MeterProvider
+from opentelemetry.sdk._metrics.export import (PeriodicExportingMetricReader, ConsoleMetricExporter, MetricExporter,
+ MetricExportResult, Metric, Sequence)
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+from opentelemetry.exporter.otlp.proto.grpc._metric_exporter import OTLPMetricExporter
+from opentelemetry.sdk.trace.export import (BatchSpanProcessor, ConsoleSpanExporter, SpanExportResult, SpanExporter,
+ ReadableSpan)
+
+from envs import Envs
+
+
+def _validate_tags(tags: Dict[str, str]):
+ if tags is None:
+ return
+ for k, v in tags.items():
+ if not isinstance(k, str) or not isinstance(v, str):
+ raise TypeError(f'Expected str, actually {type(k)}: {type(v)}')
+
+
+class DevNullSpanExporter(SpanExporter):
+
+ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
+ return SpanExportResult.SUCCESS
+
+ def shutdown(self):
+ pass
+
+
+class DevNullMetricExporter(MetricExporter):
+
+ def export(self, metrics: Sequence[Metric]) -> MetricExportResult: # pylint: disable=redefined-outer-name
+ return MetricExportResult.SUCCESS
+
+ def shutdown(self):
+ pass
class MetricsHandler(metaclass=ABCMeta):
+
@abstractmethod
- def emit_counter(self, name, value: int, tags: dict = None):
+ def emit_counter(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
"""Emits counter metrics which will be accumulated.
Args:
@@ -29,7 +74,7 @@ def emit_counter(self, name, value: int, tags: dict = None):
"""
@abstractmethod
- def emit_store(self, name, value: int, tags: dict = None):
+ def emit_store(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
"""Emits store metrics.
Args:
@@ -41,11 +86,99 @@ def emit_store(self, name, value: int, tags: dict = None):
class _DefaultMetricsHandler(MetricsHandler):
- def emit_counter(self, name, value: int, tags: dict = None):
- logging.info(f'[Metric][Counter] {name}: {value}', extra=tags or {})
-
- def emit_store(self, name, value: int, tags: dict = None):
- logging.info(f'[Metric][Store] {name}: {value}', extra=tags or {})
+ def emit_counter(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ tags = tags or {}
+ logging.info(f'[Metric][Counter] {name}: {value}, tags={tags}')
+
+ def emit_store(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ tags = tags or {}
+ logging.info(f'[Metric][Store] {name}: {value}, tags={tags}')
+
+
+class OpenTelemetryMetricsHandler(MetricsHandler):
+
+ class Callback:
+
+ def __init__(self) -> None:
+ self._measurement_list = []
+
+ def add(self, value: Union[int, float], tags: Dict[str, str]):
+ self._measurement_list.append(Measurement(value=value, attributes=tags))
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if len(self._measurement_list) == 0:
+ raise StopIteration
+ return self._measurement_list.pop(0)
+
+ def __call__(self):
+ return iter(self)
+
+ @classmethod
+ def new_handler(cls) -> 'OpenTelemetryMetricsHandler':
+ instrument_module_name = 'fedlearner_webconsole'
+ resource = Resource.create(attributes={
+ 'service.name': instrument_module_name,
+ 'deployment.environment': Envs.CLUSTER
+ })
+ # initialize trace stuff
+ if Envs.APM_SERVER_ENDPOINT == 'stdout':
+ span_exporter = ConsoleSpanExporter(out=sys.stdout)
+ elif Envs.APM_SERVER_ENDPOINT == '/dev/null':
+ span_exporter = DevNullSpanExporter()
+ else:
+ span_exporter = OTLPSpanExporter(endpoint=Envs.APM_SERVER_ENDPOINT)
+ tracer_provider = TracerProvider(resource=resource)
+ tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter))
+ trace.set_tracer_provider(tracer_provider)
+
+ # initialize meter stuff
+ if Envs.APM_SERVER_ENDPOINT == 'stdout':
+ metric_exporter = ConsoleMetricExporter(out=sys.stdout)
+ elif Envs.APM_SERVER_ENDPOINT == '/dev/null':
+ metric_exporter = DevNullMetricExporter()
+ else:
+ metric_exporter = OTLPMetricExporter(endpoint=Envs.APM_SERVER_ENDPOINT)
+ reader = PeriodicExportingMetricReader(metric_exporter, export_interval_millis=60000)
+ meter_provider = MeterProvider(metric_readers=[reader], resource=resource)
+ metrics.set_meter_provider(meter_provider=meter_provider)
+
+ return cls(tracer=tracer_provider.get_tracer(instrument_module_name),
+ meter=meter_provider.get_meter(instrument_module_name))
+
+ def __init__(self, tracer: trace.Tracer, meter: metrics.Meter):
+ self._tracer = tracer
+ self._meter = meter
+
+ self._lock = Lock()
+ self._cache: Dict[str, Union[UpDownCounter, OpenTelemetryMetricsHandler.Callback]] = {}
+
+ def emit_counter(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
+ # Note that the `values.` prefix is used for Elastic Index Dynamic Inference.
+ # Acquire the lock only when the instrument is not cached yet.
+ if name not in self._cache:
+ with self._lock:
+ # Double check `self._cache` content.
+ if name not in self._cache:
+ counter = self._meter.create_up_down_counter(name=f'values.{name}')
+ self._cache[name] = counter
+ assert isinstance(self._cache[name], UpDownCounter)
+ self._cache[name].add(value, attributes=tags)
+
+ def emit_store(self, name: str, value: Union[int, float], tags: Dict[str, str] = None):
+ # Note that the `values.` prefix is used for Elastic Index Dynamic Inference.
+ # Acquire the lock only when the instrument is not cached yet.
+ if name not in self._cache:
+ with self._lock:
+ # Double check `self._cache` content.
+ if name not in self._cache:
+ cb = OpenTelemetryMetricsHandler.Callback()
+ self._meter.create_observable_gauge(name=f'values.{name}', callback=cb)
+ self._cache[name] = cb
+ assert isinstance(self._cache[name], OpenTelemetryMetricsHandler.Callback)
+ self._cache[name].add(value=value, tags=tags)
class _Client(MetricsHandler):
@@ -57,12 +190,16 @@ class _Client(MetricsHandler):
def __init__(self):
self._handlers.append(_DefaultMetricsHandler())
+ # TODO(wangsen.0914): unify this behaviour to py_libs
+ self._handlers.append(OpenTelemetryMetricsHandler.new_handler())
- def emit_counter(self, name, value: int, tags: dict = None):
+ def emit_counter(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ _validate_tags(tags)
for handler in self._handlers:
handler.emit_counter(name, value, tags)
- def emit_store(self, name, value: int, tags: dict = None):
+ def emit_store(self, name, value: Union[int, float], tags: Dict[str, str] = None):
+ _validate_tags(tags)
for handler in self._handlers:
handler.emit_store(name, value, tags)
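
For reference, the new OpenTelemetry-backed handler is still driven through the module-level helpers exercised by metrics_test.py below; a minimal usage sketch (metric names and tag values are made up for illustration):

```python
from fedlearner_webconsole.utils import metrics

# Tags must be Dict[str, str]; _validate_tags raises TypeError otherwise.
# The metric names below are hypothetical examples, not real instrument names.
metrics.emit_counter('dataset.job_created', 1, tags={'module': 'dataset'})
metrics.emit_store('dataset.queue_size', 42, tags={'module': 'dataset'})
```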
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/metrics_test.py b/web_console_v2/api/fedlearner_webconsole/utils/metrics_test.py
new file mode 100644
index 000000000..dc7179586
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/metrics_test.py
@@ -0,0 +1,204 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import json
+import logging
+import unittest
+from io import StringIO
+from unittest.mock import patch
+from typing import Dict
+
+from opentelemetry import trace as otel_trace, _metrics as otel_metrics
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.sdk._metrics import MeterProvider
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace.export import ConsoleSpanExporter
+from opentelemetry.sdk._metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
+
+from fedlearner_webconsole.utils import metrics
+from fedlearner_webconsole.utils.metrics import _DefaultMetricsHandler, MetricsHandler, OpenTelemetryMetricsHandler
+
+
+class _FakeMetricsHandler(MetricsHandler):
+
+ def emit_counter(self, name, value: int, tags: Dict[str, str] = None):
+ logging.info(f'[Test][Counter] {name} - {value}')
+
+ def emit_store(self, name, value: int, tags: Dict[str, str] = None):
+ logging.info(f'[Test][Store] {name} - {value}')
+
+
+class DefaultMetricsHandler(unittest.TestCase):
+
+ def setUp(self):
+ self._handler = _DefaultMetricsHandler()
+
+ def test_emit_counter(self):
+ with self.assertLogs() as cm:
+ self._handler.emit_counter('test', 1)
+ self._handler.emit_counter('test2', 2)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Counter] test: 1, tags={}', '[Metric][Counter] test2: 2, tags={}'])
+
+ def test_emit_store(self):
+ with self.assertLogs() as cm:
+ self._handler.emit_store('test', 199)
+ self._handler.emit_store('test2', 299)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Store] test: 199, tags={}', '[Metric][Store] test2: 299, tags={}'])
+
+
+class ClientTest(unittest.TestCase):
+
+ def setUp(self):
+ metrics.add_handler(_FakeMetricsHandler())
+
+ def tearDown(self):
+ metrics.reset_handlers()
+
+ def test_emit_counter(self):
+ with self.assertRaises(TypeError):
+ metrics.emit_counter('test', 1, tags={'name': 1})
+
+ with self.assertLogs() as cm:
+ metrics.emit_counter('test', 1)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Counter] test: 1, tags={}', '[Test][Counter] test - 1'])
+
+ def test_emit_store(self):
+ with self.assertRaises(TypeError):
+ metrics.emit_store('test', 1, tags={'name': 1})
+
+ with self.assertLogs() as cm:
+ metrics.emit_store('test', 199)
+ logs = [r.msg for r in cm.records]
+ self.assertEqual(logs, ['[Metric][Store] test: 199, tags={}', '[Test][Store] test - 199'])
+
+
+class OpenTelemetryMetricsHandlerClassMethodTest(unittest.TestCase):
+
+ def setUp(self):
+ self._span_out = StringIO()
+ self._span_exporter_patcher = patch('fedlearner_webconsole.utils.metrics.OTLPSpanExporter',
+ lambda **kwargs: ConsoleSpanExporter(out=self._span_out))
+ self._metric_out = StringIO()
+ self._metric_exporter_patcher = patch('fedlearner_webconsole.utils.metrics.OTLPMetricExporter',
+ lambda **kwargs: ConsoleMetricExporter(out=self._metric_out))
+ self._span_exporter_patcher.start()
+ self._metric_exporter_patcher.start()
+
+ def tearDown(self):
+ self._metric_exporter_patcher.stop()
+ self._span_exporter_patcher.stop()
+
+ def test_new_handler(self):
+ OpenTelemetryMetricsHandler.new_handler()
+ self.assertEqual(
+ otel_trace.get_tracer_provider().resource,
+ Resource(
+ attributes={
+ 'telemetry.sdk.language': 'python',
+ 'telemetry.sdk.name': 'opentelemetry',
+ 'telemetry.sdk.version': '1.10.0',
+ 'service.name': 'fedlearner_webconsole',
+ 'deployment.environment': 'default',
+ }))
+ self.assertEqual(
+ otel_metrics.get_meter_provider()._sdk_config.resource, # pylint: disable=protected-access
+ Resource(
+ attributes={
+ 'telemetry.sdk.language': 'python',
+ 'telemetry.sdk.name': 'opentelemetry',
+ 'telemetry.sdk.version': '1.10.0',
+ 'service.name': 'fedlearner_webconsole',
+ 'deployment.environment': 'default',
+ }))
+
+
+class OpenTelemetryMetricsHandlerTest(unittest.TestCase):
+
+ def setUp(self):
+ self._span_out = StringIO()
+ self._metric_out = StringIO()
+ tracer_provider = TracerProvider()
+ tracer_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter(out=self._span_out)))
+ reader = PeriodicExportingMetricReader(ConsoleMetricExporter(out=self._metric_out),
+ export_interval_millis=60000)
+ meter_provider = MeterProvider(metric_readers=[reader])
+ self._tracer_provider = tracer_provider
+ self._meter_provider = meter_provider
+ self._handler = OpenTelemetryMetricsHandler(tracer=tracer_provider.get_tracer(__file__),
+ meter=meter_provider.get_meter(__file__))
+
+ def _force_flush(self):
+ self._meter_provider.force_flush()
+ self._metric_out.flush()
+ self._tracer_provider.force_flush()
+ self._span_out.flush()
+
+ def test_emit_store(self):
+ # Note that same instrument with different tags won't be aggregated.
+ # Aggregation rule for `emit_store` is delivering the last value of this interval.
+ # If no value at this interval, no `Metric` will be sent.
+ self._handler.emit_store(name='test_store', value=1, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._handler.emit_store(name='test_store', value=5, tags={'module': 'dataset', 'uuid': 'tag2'})
+ self._handler.emit_store(name='test_store', value=2, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self._force_flush()
+ self._force_flush()
+ self._handler.emit_store(name='test_store', value=0, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self.assertEqual(self._span_out.getvalue(), '')
+ self._metric_out.seek(0)
+ lines = self._metric_out.readlines()
+ measurements = []
+ for l in lines:
+ measurement = json.loads(l)
+ measurements.append(measurement)
+ self.assertEqual(len(measurements), 3)
+ self.assertEqual(measurements[0]['attributes'], {'uuid': 'tag1', 'module': 'dataset'})
+ self.assertEqual(measurements[1]['attributes'], {'uuid': 'tag2', 'module': 'dataset'})
+ self.assertEqual(measurements[0]['name'], 'values.test_store')
+ self.assertEqual([m['point']['value'] for m in measurements], [2, 5, 0])
+
+ def test_emit_counter(self):
+ # Note that same instrument with different tags won't be aggregated.
+ # Aggregation rule for `emit_counter` is delivering the accumulated
+ # value with the same tags during this interval.
+ # If no value at this interval, a `Metric` with value of last interval will be sent.
+ self._handler.emit_counter(name='test_counter', value=1, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._handler.emit_counter(name='test_counter', value=5, tags={'module': 'dataset', 'uuid': 'tag2'})
+ self._handler.emit_counter(name='test_counter', value=2, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self._force_flush()
+ self._handler.emit_counter(name='test_counter', value=-1, tags={'module': 'dataset', 'uuid': 'tag1'})
+ self._force_flush()
+ self.assertEqual(self._span_out.getvalue(), '')
+ self._metric_out.seek(0)
+ lines = self._metric_out.readlines()
+ measurements = []
+ for l in lines:
+ measurement = json.loads(l)
+ measurements.append(measurement)
+ self.assertEqual(len(measurements), 6)
+ self.assertEqual(measurements[0]['attributes'], {'uuid': 'tag1', 'module': 'dataset'})
+ self.assertEqual(measurements[1]['attributes'], {'uuid': 'tag2', 'module': 'dataset'})
+ self.assertEqual(measurements[0]['name'], 'values.test_counter')
+ self.assertEqual([m['point']['value'] for m in measurements], [3, 5, 3, 5, 2, 5])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/middlewares.py b/web_console_v2/api/fedlearner_webconsole/utils/middlewares.py
deleted file mode 100644
index cda20152d..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/middlewares.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import logging
-
-
-class _MiddlewareRegistry(object):
- def __init__(self):
- self.middlewares = []
-
- def register(self, middleware):
- self.middlewares.append(middleware)
-
-
-_middleware_registry = _MiddlewareRegistry()
-register = _middleware_registry.register
-
-
-def init_app(app):
- logging.info('Initializing app with middlewares')
- # Wraps app with middlewares
- for middleware in _middleware_registry.middlewares:
- app = middleware(app)
- return app
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/mixins.py b/web_console_v2/api/fedlearner_webconsole/utils/mixins.py
index 6ec201857..6206fbf87 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/mixins.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/mixins.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +14,30 @@
# coding: utf-8
from typing import List, Dict, Callable
-from datetime import datetime, timezone
+from datetime import datetime
from enum import Enum
+
from sqlalchemy.ext.declarative import DeclarativeMeta
from google.protobuf.message import Message
from google.protobuf.json_format import MessageToDict
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+
+
+def _to_dict_value(value):
+ if isinstance(value, datetime):
+ return to_timestamp(value)
+ if isinstance(value, Message):
+ return MessageToDict(value, preserving_proto_field_name=True, including_default_value_fields=True)
+ if isinstance(value, Enum):
+ return value.name
+ if isinstance(value, list):
+ return [_to_dict_value(v) for v in value]
+ if hasattr(value, 'to_dict'):
+ return value.to_dict()
+ return value
+
def to_dict_mixin(ignores: List[str] = None,
extras: Dict[str, Callable] = None,
@@ -40,6 +57,7 @@ def _get_fields(self: object) -> List[str]:
def decorator(cls):
"""A decorator to add a to_dict method to a class."""
+
def to_dict(self: object):
"""A helper function to convert a class to dict."""
dic = {}
@@ -54,27 +72,7 @@ def to_dict(self: object):
dic[extra_key] = func(self)
# Converts type
for key in dic:
- value = dic[key]
- if isinstance(value, datetime):
- # If there is no timezone, we should treat it as
- # UTC datetime,otherwise it will be calculated
- # as local time when converting to timestamp.
- # Context: all datetime in db is UTC datetime,
- # see details in config.py#turn_db_timezone_to_utc
- if value.tzinfo is None:
- dic[key] = int(
- value.replace(tzinfo=timezone.utc).timestamp())
- else:
- dic[key] = int(value.timestamp())
- elif isinstance(value, Message):
- dic[key] = MessageToDict(
- value,
- preserving_proto_field_name=True,
- including_default_value_fields=True)
- elif isinstance(value, Enum):
- dic[key] = value.name
- elif hasattr(value, 'to_dict'):
- dic[key] = value.to_dict()
+ dic[key] = _to_dict_value(dic[key])
# remove None and empty list and dict
if ignore_none:
@@ -86,33 +84,3 @@ def to_dict(self: object):
return cls
return decorator
-
-
-def from_dict_mixin(from_dict_fields: List[str] = None,
- required_fields: List[str] = None):
- if from_dict_fields is None:
- from_dict_fields = []
- if required_fields is None:
- required_fields = []
-
- def decorator(cls: object):
- @classmethod
- def from_dict(cls: object, content: dict):
- obj = cls() # pylint: disable=no-value-for-parameter
- for k in from_dict_fields:
- if k in content:
- current_type = type(getattr(obj, k))
- if hasattr(current_type, 'from_dict'):
- setattr(obj, k, current_type.from_dict(content[k]))
- else:
- setattr(obj, k, content[k])
- for k in required_fields:
- if getattr(obj, k) is None:
- raise ValueError(f'{type(obj)} should have attribute {k}')
-
- return obj
-
- setattr(cls, 'from_dict', from_dict)
- return cls
-
- return decorator
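
A small sketch of what the new `_to_dict_value` helper does with nested values (standalone, illustrative inputs only; not part of the patch):

```python
from datetime import datetime, timezone

from fedlearner_webconsole.utils.mixins import _to_dict_value

# Datetimes become UTC timestamps and lists are converted element-wise.
dt = datetime(2021, 4, 23, 10, 42, 1, tzinfo=timezone.utc)
assert _to_dict_value([dt]) == [1619174521]
assert _to_dict_value('plain string') == 'plain string'
```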
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/mixins_test.py b/web_console_v2/api/fedlearner_webconsole/utils/mixins_test.py
new file mode 100644
index 000000000..6561231be
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/mixins_test.py
@@ -0,0 +1,89 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from datetime import datetime, timezone
+
+from sqlalchemy.ext.declarative import declarative_base
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.utils.mixins import to_dict_mixin
+
+Base = declarative_base()
+
+
+@to_dict_mixin(ignores=['token', 'grpc_spec'],
+ extras={
+ 'grpc_spec': lambda model: model.get_grpc_spec(),
+ 'list': lambda _: ['hello', 'world']
+ })
+class DeclarativeClass(Base):
+ __tablename__ = 'just_a_test'
+
+ id = db.Column(db.Integer, primary_key=True)
+ name = db.Column(db.String(255))
+ token = db.Column('token_string', db.String(64), index=True, key='token')
+ updated_at = db.Column(db.DateTime(timezone=True))
+ grpc_spec = db.Column(db.Text())
+
+ def set_grpc_spec(self, proto):
+ self.grpc_spec = proto.SerializeToString()
+
+ def get_grpc_spec(self):
+ proto = common_pb2.GrpcSpec()
+ proto.ParseFromString(self.grpc_spec)
+ return proto
+
+
+@to_dict_mixin(to_dict_fields=['hhh'])
+class SpecifyColumnsClass(object):
+
+ def __init__(self) -> None:
+ self.hhh = None
+ self.not_include = None
+
+
+class MixinsTest(unittest.TestCase):
+
+ def test_to_dict_declarative_api(self):
+ # 2021/04/23 10:42:01 UTC
+ updated_at = datetime(2021, 4, 23, 10, 42, 1, tzinfo=timezone.utc)
+ updated_at_ts = int(updated_at.timestamp())
+ test_model = DeclarativeClass(id=123, name='test-model', token='test-token', updated_at=updated_at)
+ test_grpc_spec = common_pb2.GrpcSpec(authority='test-authority')
+ test_model.set_grpc_spec(test_grpc_spec)
+
+ self.assertDictEqual(
+ test_model.to_dict(), {
+ 'id': 123,
+ 'name': 'test-model',
+ 'updated_at': updated_at_ts,
+ 'grpc_spec': {
+ 'authority': 'test-authority',
+ },
+ 'list': ['hello', 'world']
+ })
+
+ def test_to_dict_specify_columns(self):
+ obj = SpecifyColumnsClass()
+ obj.hhh = 'hhh'
+ res = obj.to_dict()
+ self.assertEqual(len(res), 1)
+ self.assertTrue('hhh' in res)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/paginate.py b/web_console_v2/api/fedlearner_webconsole/utils/paginate.py
new file mode 100644
index 000000000..2bb6b60c8
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/paginate.py
@@ -0,0 +1,121 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+# Ref: https://github.com/pallets/flask-sqlalchemy/blob/main/src/flask_sqlalchemy/__init__.py
+from typing import Optional
+
+from math import ceil
+
+from sqlalchemy import func
+from sqlalchemy.orm import Query
+
+
+class Pagination:
+
+ def __init__(self, query: Query, page: int, page_size: int):
+ """Constructor for pagination
+
+ Args:
+ query (Query): A SQLAlchemy query
+ page (int): The selected page
+ page_size (int): The number of items on each page
+ """
+ self.query = query
+ self.page = page
+ self.page_size = page_size
+ self._total_of_items = None
+
+ def get_number_of_items(self) -> int:
+ """Get the total number of items in the query.
+
+ Returns:
+ The total number of items from the original query.
+ """
+ if self._total_of_items is None:
+ # A raw query without any WHERE clause will result in a SQL statement without a FROM clause.
+ # Therefore, if no FROM clause is detected, we use a subquery to count the items
+ # Ref: https://stackoverflow.com/questions/12941416/how-to-count-rows-with-select-count-with-sqlalchemy#comment118672248_57934541 # pylint:disable=line-too-long
+ # FYI: Even 1.4.35 did not resolve this issue
+ if ' FROM ' not in str(self.query).upper():
+ self._total_of_items = self.query.count()
+ else:
+ self._total_of_items = self.query.with_entities(func.count()).scalar()
+
+ return self._total_of_items
+
+ def get_items(self) -> iter:
+ """Get a "page" of items.
+ CAUTION: Returns all records if {self.page_size} is 0.
+
+ Returns:
+ An iterable containing {self.page_size} items on page {self.page}.
+ """
+ if self.page_size == 0:
+ return self.query.all()
+ return self.query.limit(self.page_size).offset((self.page - 1) * self.page_size).all()
+
+ def get_number_of_pages(self) -> int:
+ """Get the number of pages of the query according to the specified
+ page_size value.
+ CAUTION: Returns 1 if {self.page_size} is 0 and the query has records.
+
+ Returns:
+ The number of pages of all items from the original query.
+ """
+ if self.get_number_of_items() == 0:
+ return 0
+ if self.page_size == 0:
+ return 1
+ return int(ceil(self.get_number_of_items() / float(self.page_size)))
+
+ def get_metadata(self) -> dict:
+ """Get pagination metadata in a dictionary.
+
+ Returns:
+ A dictionary containing pagination metadata for the current page.
+ """
+ return {
+ 'current_page': self.page,
+ 'page_size': self.page_size,
+ 'total_pages': self.get_number_of_pages(),
+ 'total_items': self.get_number_of_items()
+ }
+
+
+def paginate(query: Query, page: Optional[int] = None, page_size: Optional[int] = None) -> Pagination:
+ """Paginate a query.
+
+ Check if page and page_size are valid and construct a new Pagination
+ object by a SQLAlchemy Query.
+ CAUTION: page starts at one
+
+ Args:
+ query (Query): Query to be paginated
+ page (int): Page selected in pagination (page >= 1)
+ page_size (int): Number of items on each page (0 <= page_size <= 100)
+
+ Returns:
+ A Pagination object containing the selected items and metadata.
+
+ Raises:
+ ValueError: if page < 1 or page_size is not between 0 and 100.
+ """
+ page = page or 1
+ page_size = page_size or 0
+ if not (page >= 1 and 0 <= page_size <= 100):
+ raise ValueError('page should be positive and page_size ranges between 0 and 100')
+
+ return Pagination(query, page, page_size)
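
A hedged usage sketch of the new pagination helper, mirroring paginate_test.py below (the Workflow model and session scope are assumptions taken from that test):

```python
from fedlearner_webconsole.db import db
from fedlearner_webconsole.utils.paginate import paginate
from fedlearner_webconsole.workflow.models import Workflow

with db.session_scope() as session:
    query = session.query(Workflow)
    pagination = paginate(query, page=2, page_size=10)
    items = pagination.get_items()    # rows 11-20 via LIMIT/OFFSET
    meta = pagination.get_metadata()  # total_pages == ceil(total_items / 10)
```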
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/paginate_test.py b/web_console_v2/api/fedlearner_webconsole/utils/paginate_test.py
new file mode 100644
index 000000000..89f91243b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/paginate_test.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+
+import unittest
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+def generate_workflow(state: WorkflowState = WorkflowState.RUNNING) -> Workflow:
+ return Workflow(state=state)
+
+
+class PaginateTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ workflows = [generate_workflow() for _ in range(7)]
+ workflows += [generate_workflow(state=WorkflowState.INVALID) for _ in range(7)]
+
+ with db.session_scope() as session:
+ session.bulk_save_objects(workflows)
+ session.commit()
+
+ def test_paginate(self):
+ with db.session_scope() as session:
+ query = session.query(Workflow).filter(Workflow.state == WorkflowState.RUNNING)
+ pagination = paginate(query, page=1, page_size=3)
+
+ self.assertEqual(3, pagination.get_number_of_pages())
+ self.assertEqual(3, len(pagination.get_items()))
+
+ pagination = paginate(query, page=3, page_size=3)
+
+ self.assertEqual(1, len(pagination.get_items()))
+
+ pagination = paginate(query, page=4, page_size=3)
+
+ self.assertEqual(0, len(pagination.get_items()))
+
+ def test_page_meta(self):
+ with db.session_scope() as session:
+ query = session.query(Workflow)
+ page_meta = paginate(query, page=1, page_size=3).get_metadata()
+ self.assertDictEqual(page_meta, {'current_page': 1, 'page_size': 3, 'total_pages': 5, 'total_items': 14})
+
+ query = session.query(Workflow).filter(Workflow.state == WorkflowState.RUNNING)
+ page_meta = paginate(query, page=1, page_size=3).get_metadata()
+ self.assertDictEqual(page_meta, {'current_page': 1, 'page_size': 3, 'total_pages': 3, 'total_items': 7})
+
+ page_meta = paginate(query, page=4, page_size=10).get_metadata()
+ self.assertDictEqual(page_meta, {'current_page': 4, 'page_size': 10, 'total_pages': 1, 'total_items': 7})
+
+ def test_fallback_page_size(self):
+ with db.session_scope() as session:
+ query = session.query(Workflow).filter(Workflow.state == WorkflowState.RUNNING)
+ pagination = paginate(query)
+
+ self.assertEqual(7, len(pagination.get_items()))
+ self.assertDictEqual(pagination.get_metadata(), {
+ 'current_page': 1,
+ 'page_size': 0,
+ 'total_pages': 1,
+ 'total_items': 7
+ })
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_base64.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_base64.py
new file mode 100644
index 000000000..1a588329c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_base64.py
@@ -0,0 +1,24 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from base64 import b64encode, b64decode
+
+
+def base64encode(s: str) -> str:
+ return b64encode(s.encode('UTF-8')).decode('UTF-8')
+
+
+def base64decode(s: str) -> str:
+ return b64decode(s).decode('UTF-8')
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_base64_test.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_base64_test.py
new file mode 100644
index 000000000..72db50be5
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_base64_test.py
@@ -0,0 +1,37 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+
+from fedlearner_webconsole.utils.pp_base64 import base64encode, base64decode
+
+
+class Base64Test(unittest.TestCase):
+
+ def test_base64encode(self):
+ self.assertEqual(base64encode('hello 1@2'), 'aGVsbG8gMUAy')
+ self.assertEqual(base64encode('😈'), '8J+YiA==')
+
+ def test_base64decode(self):
+ self.assertEqual(base64decode('aGVsbG8gMUAy'), 'hello 1@2')
+ self.assertEqual(base64decode('JjEzOVlUKiYm'), '&139YT*&&')
+
+ def test_base64_encode_and_decode(self):
+ self.assertEqual(base64decode(base64encode('test')), 'test')
+ self.assertEqual(base64encode(base64decode('aGVsbG8gMUAy')), 'aGVsbG8gMUAy')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_datetime.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_datetime.py
new file mode 100644
index 000000000..14c4a382f
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_datetime.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import Optional, Union
+from datetime import datetime, timezone
+from dateutil.parser import isoparse
+
+
+def to_timestamp(dt: Union[datetime, str]) -> int:
+ """Converts DB datetime to timestamp in second."""
+ # If there is no timezone, we should treat it as UTC datetime,
+ # otherwise it will be calculated as local time when converting
+ # to timestamp.
+ # Context: all datetime in db is UTC datetime,
+ # see details in config.py#turn_db_timezone_to_utc
+ if isinstance(dt, str):
+ dt = isoparse(dt)
+ if dt.tzinfo is None:
+ return int(dt.replace(tzinfo=timezone.utc).timestamp())
+ return int(dt.timestamp())
+
+
+def from_timestamp(timestamp: int) -> datetime:
+ """Converts timestamp to datetime with utc timezone."""
+ return datetime.fromtimestamp(timestamp, timezone.utc)
+
+
+def now(tz: Optional[timezone] = timezone.utc) -> datetime:
+ """A wrapper of datetime.now.
+
+ This is for easy testing: datetime.now is referenced by a lot
+ of components, and mocking it directly will break tests easily. Using this
+ wrapper lets developers mock this function to get a fake datetime."""
+ return datetime.now(tz)
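
A minimal sketch of the testing rationale stated in the `now()` docstring: callers patch the wrapper rather than `datetime.now` itself (patch target mirrors the module path above; illustrative only):

```python
from datetime import datetime, timezone
from unittest.mock import patch

from fedlearner_webconsole.utils import pp_datetime

# Patching the wrapper keeps datetime.now itself untouched elsewhere.
fake = datetime(2022, 1, 1, tzinfo=timezone.utc)
with patch.object(pp_datetime, 'now', return_value=fake):
    assert pp_datetime.now() == fake
```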
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_datetime_test.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_datetime_test.py
new file mode 100644
index 000000000..d7c746829
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_datetime_test.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from datetime import datetime, timezone, timedelta
+
+from fedlearner_webconsole.utils.pp_datetime import from_timestamp, to_timestamp
+
+
+class DatetimeTest(unittest.TestCase):
+
+ def test_to_timestamp(self):
+ # 2020/12/17 13:58:59 UTC+8
+ dt_utc8 = datetime(2020, 12, 17, 13, 58, 59, tzinfo=timezone(timedelta(hours=8)))
+ # datetime will be stored without timezone info
+ dt_utc8_ts = int(dt_utc8.timestamp()) + 8 * 60 * 60
+ self.assertEqual(to_timestamp(dt_utc8.replace(tzinfo=None)), dt_utc8_ts)
+ # 2021/04/23 10:42:01 UTC
+ dt_utc = datetime(2021, 4, 23, 10, 42, 1, tzinfo=timezone.utc)
+ dt_utc_ts = int(dt_utc.timestamp())
+ self.assertEqual(to_timestamp(dt_utc), dt_utc_ts)
+
+ def test_from_timestamp(self):
+ # 2020/12/17 13:58:59 UTC+8
+ dt_utc8 = datetime(2020, 12, 17, 13, 58, 59, tzinfo=timezone(timedelta(hours=8)))
+ self.assertEqual(from_timestamp(to_timestamp(dt_utc8)), datetime(2020, 12, 17, 5, 58, 59, tzinfo=timezone.utc))
+ dt_utc = datetime(2021, 4, 23, 10, 42, 1, tzinfo=timezone.utc)
+ self.assertEqual(from_timestamp(to_timestamp(dt_utc)), datetime(2021, 4, 23, 10, 42, 1, tzinfo=timezone.utc))
+
+ def test_to_timestamp_with_str_input(self):
+ dt_str = '2021-04-15T10:43:15Z'
+ real_dt = datetime(2021, 4, 15, 10, 43, 15, tzinfo=timezone.utc)
+ ts = to_timestamp(dt_str)
+ self.assertEqual(real_dt.timestamp(), ts)
+
+ dt_str = '2021-09-24T17:58:27+08:00'
+ real_dt = datetime(2021, 9, 24, 17, 58, 27, tzinfo=timezone(timedelta(hours=8)))
+ ts = to_timestamp(dt_str)
+ self.assertEqual(real_dt.timestamp(), ts)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_flatten_dict.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_flatten_dict.py
new file mode 100644
index 000000000..78e6c7802
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_flatten_dict.py
@@ -0,0 +1,49 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+# pylint: disable=deprecated-class
+from typing import Mapping
+import six
+from flatten_dict.flatten_dict import dot_reducer
+
+
+def flatten(d):
+ """
+ Copied and modified from flatten_dict.flatten_dict, because the original method
+ converts {'a': {'b': 1}} to {'a.b': 1}, but we want {'a': {'b': 1}, 'a.b': 1}.
+
+ Flatten `Mapping` object.
+ """
+ flattenable_types = (Mapping,)
+ if not isinstance(d, flattenable_types):
+ raise ValueError(f'argument type {type(d)} is not in the flattenable types {flattenable_types}')
+
+ reducer = dot_reducer
+ flat_dict = {}
+
+ def _flatten(d, parent=None):
+ key_value_iterable = six.viewitems(d)
+ for key, value in key_value_iterable:
+ flat_key = reducer(parent, key)
+ if isinstance(value, flattenable_types):
+ flat_dict[flat_key] = value
+ if value:
+ # recursively build the result
+ _flatten(value, flat_key)
+ continue
+ flat_dict[flat_key] = value
+
+ _flatten(d)
+ return flat_dict
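+
+
+# Illustrative sketch (editor's addition): unlike the upstream flatten_dict.flatten, the
+# intermediate mappings are kept alongside the dotted keys, e.g.
+#   flatten({'a': {'b': 1}, 'c': 2})  ->  {'a': {'b': 1}, 'a.b': 1, 'c': 2}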
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_flatten_dict_test.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_flatten_dict_test.py
new file mode 100644
index 000000000..82e82fb28
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_flatten_dict_test.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from fedlearner_webconsole.utils.pp_flatten_dict import flatten
+
+
+class FlattenDictTestCase(unittest.TestCase):
+
+ def test_flatten(self):
+ self.assertEqual(flatten({'a': [1], 'b': {'c': 2}, 'd': 3}), {'a': [1], 'b': {'c': 2}, 'b.c': 2, 'd': 3})
+ self.assertEqual(flatten({'a': {
+ 'b': {
+ 'c': {
+ 'd': 1
+ }
+ }
+ }}), {
+ 'a': {
+ 'b': {
+ 'c': {
+ 'd': 1
+ }
+ }
+ },
+ 'a.b': {
+ 'c': {
+ 'd': 1
+ }
+ },
+ 'a.b.c': {
+ 'd': 1
+ },
+ 'a.b.c.d': 1
+ })
+
+ self.assertEqual(flatten({'a': {}}), {'a': {}})
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_time.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_time.py
new file mode 100644
index 000000000..221eee830
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_time.py
@@ -0,0 +1,24 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import time
+
+
+def sleep(seconds: int):
+ """A wrapper of time.sleep.
+ This wrapper exists for easy testing: time.sleep is referenced by a lot
+ of components, and mocking it directly breaks tests easily. Developers can
+ mock this function instead to fake the time tick."""
+ time.sleep(seconds)
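+
+
+# Illustrative sketch (editor's addition): tests can patch this wrapper to skip real waiting, e.g.
+#   from unittest.mock import patch
+#   with patch('fedlearner_webconsole.utils.pp_time.sleep', lambda seconds: None):
+#       ...  # code under test no longer blocks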
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_yaml.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_yaml.py
new file mode 100644
index 000000000..05b805686
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_yaml.py
@@ -0,0 +1,243 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+import json
+from ast import Attribute, Name, Subscript, Add, Call
+from string import Template
+from typing import Callable, List, Optional
+from simpleeval import EvalWithCompoundTypes
+from fedlearner_webconsole.utils.pp_flatten_dict import flatten
+from fedlearner_webconsole.setting.service import SettingService
+from fedlearner_webconsole.utils.const import DEFAULT_OWNER
+from fedlearner_webconsole.utils.system_envs import get_system_envs
+
+
+class _YamlTemplate(Template):
+ delimiter = '$'
+ # Which placeholders in the template should be interpreted
+ idpattern = r'[a-zA-Z_\-\[0-9\]]+(\.[a-zA-Z_\-\[0-9\]]+)*'
+
+
+def _format_yaml(yaml, **kwargs):
+ """Formats a yaml template.
+
+ Example usage:
+ format_yaml('{"abc": ${x.y}}', x={'y': 123})
+ output should be '{"abc": 123}'
+ """
+ template = _YamlTemplate(yaml)
+ try:
+ return template.substitute(flatten(kwargs or {}))
+ except KeyError as e:
+ raise RuntimeError(f'Unknown placeholder: {e.args[0]}') from e
+
+
+def _to_str(x=None) -> str:
+ if x is None:
+ return ''
+ if isinstance(x, dict):
+ return json.dumps(x)
+ return str(x)
+
+
+def _to_int(x=None) -> Optional[int]:
+ if x is None or x == '':
+ return None
+ try:
+ return int(float(x))
+ except Exception as e:
+ raise ValueError(f'{str(e)}. The input is: {x}') from e
+
+
+def _to_float(x=None) -> Optional[float]:
+ if x is None or x == '':
+ return None
+ try:
+ return float(x)
+ except Exception as e:
+ raise ValueError(f'{str(e)}. The input is: {x}') from e
+
+
+def _to_bool(x=None) -> bool:
+ if x is None or x == '':
+ return False
+ if isinstance(x, bool):
+ return x
+ if not isinstance(x, str):
+ raise ValueError(f'{x} can not be converted to boolean')
+ if x.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ if x.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ raise ValueError(f'{x} can not be converted to boolean')
+
+
+def _to_dict(x) -> dict:
+ if isinstance(x, dict):
+ return x
+ raise ValueError(f'{x} is not dict')
+
+
+def _to_list(x) -> list:
+ if isinstance(x, list):
+ return x
+ raise ValueError(f'{x} is not list')
+
+
+def _eval_attribute(self, node):
+ """
+ Copied from simpleeval, with the last exception modified to carry more information about the attribute.
+ Before the change, the exception message would be "can't find 'c'" when the attribute c cannot be found in a.b,
+ e.g. eval('a.b.c', names={'a': {'b': {'d': 1}}}).
+ After the change, the exception message is "can't find 'a.b.c'".
+ """
+ max_depth = 10
+ for prefix in ['_', 'func_']:
+ if node.attr.startswith(prefix):
+ raise ValueError('Sorry, access to __attributes '
+ ' or func_ attributes is not available. '
+ f'({node.attr})')
+ if node.attr in ['format', 'format_map', 'mro']:
+ raise ValueError(f'Sorry, this method is not available. ({node.attr})')
+ # eval node
+ node_evaluated = self._eval(node.value) # pylint: disable=protected-access
+
+ # Maybe the base object is an actual object, not just a dict
+ try:
+ return getattr(node_evaluated, node.attr)
+ except (AttributeError, TypeError):
+ pass
+
+ if self.ATTR_INDEX_FALLBACK:
+ try:
+ return node_evaluated[node.attr]
+ except (KeyError, TypeError):
+ pass
+
+ # If it is neither, raise an exception
+ # Modified(xiangyuxuan.prs) from simpleeval to make the error message has more information.
+ pre_node = node.value
+ attr_chains = [node.attr]
+ for i in range(max_depth):
+ if not isinstance(pre_node, Attribute):
+ break
+ attr_chains.append(pre_node.attr)
+ pre_node = pre_node.value
+ if isinstance(pre_node, Name):
+ attr_chains.append(pre_node.id)
+ raise ValueError('.'.join(attr_chains[::-1]), self.expr)
+
+
+def compile_yaml_template(yaml_template: str,
+ post_processors: List[Callable],
+ ignore_variables: bool = False,
+ use_old_formater: bool = False,
+ **kwargs) -> dict:
+ """
+ Args:
+ yaml_template (str): The original string to format.
+ post_processors (List): List of methods to process the dict generated from yaml_template.
+ ignore_variables (bool): If True, compile the yaml_template without any variables.
+ All variables will be treated as None, e.g. "{var_a.attr_a: 1, 'b': 2}" -> {None: 1, 'b': 2}.
+ use_old_formater (bool): If True, use the old ${} placeholder formatter.
+ **kwargs: variables to format the yaml_template.
+ Raises:
+ ValueError: if formatting fails.
+ Returns:
+ a dict which can be submitted to k8s.
+ """
+ # TODO(xiangyuxuan.prs): this is the old-version formatter; it should be deleted once no FLApp is in use
+ if use_old_formater:
+ yaml_template = _format_yaml(yaml_template, **kwargs)
+ try:
+ # names={'true': True, 'false': False, 'null': None} support json symbol in python
+ eval_with_types = EvalWithCompoundTypes(names={'true': True, 'false': False, 'null': None, **kwargs})
+
+ # replace the built-in functions in eval stage,
+ # Ref: https://github.com/danthedeckie/simpleeval
+ if ignore_variables:
+ eval_with_types.nodes[Attribute] = lambda x: None
+ eval_with_types.nodes[Name] = lambda x: None
+ eval_with_types.nodes[Subscript] = lambda x: None
+ eval_with_types.nodes[Call] = lambda x: None
+ eval_with_types.operators[Add] = lambda x, y: None
+ return eval_with_types.eval(yaml_template)
+ eval_with_types.functions.update(str=_to_str, int=_to_int, bool=_to_bool, dict=_to_dict, list=_to_list)
+
+ # Overwrite to let the exceptions message have more information.
+ eval_with_types.nodes[Attribute] = lambda x: _eval_attribute(eval_with_types, x)
+ loaded_json = eval_with_types.eval(yaml_template)
+ except SyntaxError as e:
+ raise ValueError(f'Invalid python dict syntax error msg: {e.args}') from e
+ except Exception as e: # pylint: disable=broad-except
+ # use args[0] to simplify the error message
+ raise ValueError(f'Invalid python dict placeholder error msg: {e.args[0]}') from e
+ # post processor for flapp yaml
+ for post_processor in post_processors:
+ loaded_json = post_processor(loaded_json)
+ return loaded_json
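+
+
+# Illustrative sketch (editor's addition, hypothetical values): with the new-style templates the
+# template itself is a python-dict expression and variables are resolved by simpleeval, e.g.
+#   compile_yaml_template('{"replicas": int(workflow.variables.num)}', [],
+#                         workflow={'variables': {'num': '2'}})
+#   -> {'replicas': 2}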
+
+
+def add_username_in_label(loaded_json: dict, username: Optional[str] = None) -> dict:
+ if 'labels' not in loaded_json['metadata']:
+ loaded_json['metadata']['labels'] = {}
+ loaded_json['metadata']['labels']['owner'] = username or DEFAULT_OWNER
+ return loaded_json
+
+
+class GenerateDictService:
+
+ def __init__(self, session):
+ self._session = session
+
+ def generate_system_dict(self):
+ sys_vars_dict = SettingService(self._session).get_system_variables_dict()
+ # TODO(xiangyuxuan.prs): basic_envs is the old method to inject the envs; delete it in the future.
+ basic_envs_list = get_system_envs()
+ basic_envs = ','.join([json.dumps(env) for env in basic_envs_list])
+ version = SettingService(self._session).get_application_version().version.version
+ return {
+ 'basic_envs': basic_envs,
+ 'variables': sys_vars_dict,
+ 'basic_envs_list': basic_envs_list,
+ 'version': version
+ }
+
+
+def _envs_to_dict(flapp_envs: List[dict]) -> dict:
+ return {env['name']: env['value'] for env in flapp_envs}
+
+
+def extract_flapp_envs(flapp_json: dict) -> Optional[dict]:
+ """Extract flapp envs
+
+ Returns:
+ a dict of environment variables keyed by pod type, e.g.
+ {'master': {'INPUT_BASE_DIR': '/data'},
+ 'worker': {'INPUT_DATA_FORMAT': 'TF_RECORD'}}
+ """
+ try:
+ if flapp_json['kind'] != 'FLApp':
+ return None
+ flapp_specs = flapp_json['spec']['flReplicaSpecs']
+ flapp_envs = {}
+ for role in flapp_specs:
+ assert len(flapp_specs[role]['template']['spec']['containers']) == 1
+ flapp_envs[role] = _envs_to_dict(flapp_specs[role]['template']['spec']['containers'][0]['env'])
+ return flapp_envs
+ except Exception as e: # pylint: disable=broad-except
+ logging.error(f'extracting environment variables with error {str(e)}')
+ return None
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/pp_yaml_test.py b/web_console_v2/api/fedlearner_webconsole/utils/pp_yaml_test.py
new file mode 100644
index 000000000..117d2b3ab
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/pp_yaml_test.py
@@ -0,0 +1,166 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from fedlearner_webconsole.utils.pp_yaml import compile_yaml_template, extract_flapp_envs, _to_bool, \
+ _to_dict, _to_int, \
+ _to_float, _to_list
+
+
+def test_postprocessor(loaded_json: dict):
+ loaded_json['test'] = 1
+ return loaded_json
+
+
+class YamlTestCase(unittest.TestCase):
+
+ def test_compile_yaml_template(self):
+ result = compile_yaml_template('{"test${a.b}": 1}', [], use_old_formater=True, a={'b': 'placeholder'})
+ self.assertEqual(result, {'testplaceholder': 1})
+
+ def test_compile_yaml_template_with_postprocessor(self):
+ result = compile_yaml_template('{"test${a.b}": 1}', [test_postprocessor],
+ use_old_formater=True,
+ a={'b': 'placeholder'})
+ self.assertEqual(result, {'testplaceholder': 1, 'test': 1})
+
+ def test_compile_yaml_template_with_list_merge(self):
+ result = compile_yaml_template('[{"test": false}]+${a.b}', [], use_old_formater=True, a={'b': '[{True: true}]'})
+ self.assertEqual(result, [{'test': False}, {True: True}])
+ result = compile_yaml_template('[{"test": false}]+${a.b}', [],
+ use_old_formater=True,
+ a={'b': [{
+ 'test': False
+ }]})
+ self.assertEqual(result, [{'test': False}, {'test': False}])
+ result = compile_yaml_template('${a.b}', [], use_old_formater=True, a={'b': {'v': 123}})
+ self.assertEqual(result, {'v': 123})
+
+ def test_extract_flapp_envs(self):
+ flapp_json = {
+ 'kind': 'FLApp',
+ 'spec': {
+ 'flReplicaSpecs': {
+ 'master': {
+ 'template': {
+ 'spec': {
+ 'containers': [{
+ 'env': [{
+ 'name': 'CODE_KEY',
+ 'value': 'test-code-key'
+ }]
+ }]
+ }
+ }
+ },
+ 'worker': {
+ 'template': {
+ 'spec': {
+ 'containers': [{
+ 'env': [{
+ 'name': 'CODE_TAR',
+ 'value': 'test-code-tar'
+ }, {
+ 'name': 'EPOCH_NUM',
+ 'value': '3'
+ }]
+ }]
+ }
+ }
+ }
+ }
+ }
+ }
+ flapp_envs = extract_flapp_envs(flapp_json)
+ expected_flapp_envs = {
+ 'master': {
+ 'CODE_KEY': 'test-code-key'
+ },
+ 'worker': {
+ 'CODE_TAR': 'test-code-tar',
+ 'EPOCH_NUM': '3'
+ }
+ }
+ self.assertEqual(flapp_envs, expected_flapp_envs)
+
+ def test_convert_built_in_functions(self):
+ self.assertEqual(_to_int(''), None)
+ self.assertEqual(_to_int('1.6'), 1)
+ self.assertEqual(_to_float(''), None)
+ self.assertEqual(_to_float('1.9'), 1.9)
+ self.assertEqual(_to_bool('0'), False)
+ self.assertEqual(_to_bool('false'), False)
+ with self.assertRaises(ValueError):
+ _to_dict('{}')
+ with self.assertRaises(ValueError):
+ _to_list('[]')
+ self.assertEqual(_to_list([]), [])
+ self.assertEqual(_to_dict({}), {})
+
+ def test_eval_attribute_exception(self):
+ with self.assertRaises(ValueError) as e:
+ compile_yaml_template(yaml_template='a.b.c', post_processors=[], a={'b': 'd'})
+ self.assertEqual(str(e.exception), 'Invalid python dict placeholder error msg: a.b.c')
+ self.assertEqual(compile_yaml_template(yaml_template='a.b.c', post_processors=[], a={'b': {'c': 1}}), 1)
+
+ def test_eval_syntax_exception(self):
+ with self.assertRaises(ValueError) as e:
+ compile_yaml_template(yaml_template='{,,}', post_processors=[], a={'b': 'd'})
+ self.assertEqual(
+ str(e.exception),
+ """Invalid python dict syntax error msg: ('invalid syntax', ('', 1, 2, '{,,}\\n'))""")
+
+ def test_compile_yaml_template_ignore_variables(self):
+ self.assertEqual(compile_yaml_template('jaweof', [], True), None)
+ self.assertEqual(compile_yaml_template('{asdf: 12312, aaa:333}', [], True), {None: 333})
+ self.assertEqual(compile_yaml_template('{asdf.a[1].b: 12312, "a": 3}', [], True), {None: 12312, 'a': 3})
+ test_yaml_tpl = """
+ {
+ "apiVersion": "fedlearner.k8s.io/v1alpha1",
+ "kind": "FLApp",
+ "metadata": {
+ "name": self.name,
+ "namespace": system.variables.namespace,
+ "labels": dict(system.variables.labels)
+ },
+ "containers": [
+ {
+ "env": system.basic_envs_list + [
+ {
+ "name": "EGRESS_HOST",
+ "value": project.participants[0].egress_host.lower()
+ },
+ {
+ "name": "OUTPUT_PARTITION_NUM",
+ "value": str(int(workflow.variables.partition_num))
+ },
+ {
+ "name": "OUTPUT_BASE_DIR",
+ "value": project.variables.storage_root_path + "/raw_data/" + self.name
+ },
+ {
+ "name": "RAW_DATA_METRICS_SAMPLE_RATE",
+ "value": str(asdfasdf)
+ }
+ ] + list(system.variables.volumes_list),
+ }
+ ]
+ }
+ """
+ self.assertEqual(compile_yaml_template(test_yaml_tpl, [], True).get('kind'), 'FLApp')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/process_utils.py b/web_console_v2/api/fedlearner_webconsole/utils/process_utils.py
new file mode 100644
index 000000000..446b85873
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/process_utils.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import queue
+from multiprocessing import get_context
+from typing import Optional, Callable, Any, Dict
+
+from fedlearner_webconsole.utils.hooks import pre_start_hook
+
+
+def _sub_process_wrapper(target: Optional[Callable[..., Any]], kwargs: Dict[str, Any]):
+ pre_start_hook()
+ target(**kwargs)
+
+
+def get_result_by_sub_process(name: str, target: Optional[Callable[..., Any]], kwargs: Dict[str, Any]):
+ context = get_context('spawn')
+ internal_queue = context.Queue()
+ kwargs['q'] = internal_queue
+ wrapper_args = {'target': target, 'kwargs': kwargs}
+ sub_process = context.Process(target=_sub_process_wrapper, kwargs=wrapper_args, daemon=True)
+ sub_process.start()
+ try:
+ result = internal_queue.get(timeout=60)
+ except queue.Empty as e:
+ sub_process.terminate()
+ raise RuntimeError(f'[subprocess] {name} task failed') from e
+ finally:
+ sub_process.join()
+ sub_process.close()
+ internal_queue.close()
+ logging.info(f'[subprocess]: {name} task finished')
+ return result
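+
+
+# Illustrative sketch (editor's addition): the target callable must accept the injected `q` kwarg
+# (a multiprocessing queue) and put its result there, e.g.
+#   def _double(num: int, q):
+#       q.put(num * 2)
+#   get_result_by_sub_process(name='double', target=_double, kwargs={'num': 21})  # -> 42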
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/process_utils_test.py b/web_console_v2/api/fedlearner_webconsole/utils/process_utils_test.py
new file mode 100644
index 000000000..071618054
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/process_utils_test.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from multiprocessing import Queue
+
+from fedlearner_webconsole.utils.process_utils import get_result_by_sub_process
+
+
+def _fake_sub_process(num: int, q: Queue):
+ q.put([num, num + 1])
+
+
+class SubProcessTestCase(unittest.TestCase):
+
+ def test_sub_process(self):
+ result = get_result_by_sub_process(name='fake sub process', target=_fake_sub_process, kwargs={
+ 'num': 2,
+ })
+ self.assertEqual(result, [2, 3])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/proto.py b/web_console_v2/api/fedlearner_webconsole/utils/proto.py
new file mode 100644
index 000000000..faf814c88
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/proto.py
@@ -0,0 +1,141 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import copy
+from typing import Dict, Any
+
+from google.protobuf import json_format
+from google.protobuf.descriptor import FieldDescriptor
+from google.protobuf.message import Message
+from google.protobuf.struct_pb2 import Struct, Value
+
+from fedlearner_webconsole.proto.common.extension_pb2 import secret
+
+
+def _is_map(descriptor: FieldDescriptor) -> bool:
+ """Checks if a field is map or normal repeated field.
+
+ Inspired by https://github.com/protocolbuffers/protobuf/blob/3.6.x/python/google/protobuf/json_format.py#L159
+ """
+ return (descriptor.type == FieldDescriptor.TYPE_MESSAGE and descriptor.message_type.has_options and
+ descriptor.message_type.GetOptions().map_entry)
+
+
+def remove_secrets(proto: Message) -> Message:
+ """Removes secrete fields in proto."""
+ proto = copy.copy(proto)
+ field: FieldDescriptor
+ for field, value in proto.ListFields():
+ if field.type != FieldDescriptor.TYPE_MESSAGE:
+ # Clears the field if it has the secret annotation and its type is not message (no matter repeated or not)
+ if field.GetOptions().Extensions[secret]:
+ proto.ClearField(field.name)
+ continue
+ if field.label != FieldDescriptor.LABEL_REPEATED:
+ # Nested message
+ value.CopyFrom(remove_secrets(value))
+ continue
+ if _is_map(field):
+ # Checks value type
+ map_value_field: FieldDescriptor = field.message_type.fields_by_name['value']
+ for k in list(value.keys()):
+ if map_value_field.type == FieldDescriptor.TYPE_MESSAGE:
+ value[k].CopyFrom(remove_secrets(value[k]))
+ else:
+ value[k] = map_value_field.default_value
+ else:
+ # Replace the repeated field (list of message)
+ new_protos = [remove_secrets(m) for m in value]
+ del value[:]
+ value.extend(new_protos)
+
+ return proto
+
+
+_INT_TYPES = frozenset(
+ [FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_UINT32])
+
+
+def _normalize_singular(value: Any, proto_type: int) -> Any:
+ if proto_type in _INT_TYPES:
+ return int(value)
+ return value
+
+
+def _normalize_dict(dct: Dict, message: Message):
+ """Normalizes the dict in place.
+
+ Converts int64 types in the dict to python int instead of string. Currently the python
+ proto lib converts int64 to str, ref: https://github.com/protocolbuffers/protobuf/issues/2954
+
+ So this is a hack to make the dict conversion work as we expect. If you do not want this
+ behavior someday, you can use an extension on the field case by case."""
+ if isinstance(message, (Struct, Value)):
+ # For those well-known protobuf types, we do not normalize it as
+ # there are some magics.
+ return
+ descriptors = message.DESCRIPTOR.fields_by_name
+ for key in dct:
+ descriptor = descriptors.get(key)
+ # Defensive: the key may not correspond to a known field
+ if descriptor is None:
+ continue
+ # Repeated field
+ if descriptor.label == FieldDescriptor.LABEL_REPEATED:
+ nested = getattr(message, key)
+ if _is_map(descriptor):
+ # 1. Map
+ map_key_type: FieldDescriptor = descriptor.message_type.fields_by_name['key'].type
+ map_value_type: FieldDescriptor = descriptor.message_type.fields_by_name['value'].type
+ for k, v in dct[key].items():
+ if map_value_type == FieldDescriptor.TYPE_MESSAGE:
+ # If the map key type is int,
+ # convert it from string back to int before fetching information from the Message
+ k = _normalize_singular(k, map_key_type)
+ _normalize_dict(v, nested[k])
+ else:
+ dct[key][k] = _normalize_singular(v, map_value_type)
+ else:
+ # 2. List
+ for i, v in enumerate(dct[key]):
+ if descriptor.type == FieldDescriptor.TYPE_MESSAGE:
+ _normalize_dict(v, nested[i])
+ else:
+ dct[key][i] = _normalize_singular(v, descriptor.type)
+ continue
+ # Nested message
+ if descriptor.type == FieldDescriptor.TYPE_MESSAGE:
+ _normalize_dict(dct[key], getattr(message, key))
+ continue
+ # Singular field
+ dct[key] = _normalize_singular(dct[key], descriptor.type)
+
+
+def to_dict(proto: Message, with_secret: bool = True):
+ if not with_secret:
+ proto = remove_secrets(proto)
+ dct = json_format.MessageToDict(proto, preserving_proto_field_name=True, including_default_value_fields=True)
+ _normalize_dict(dct, proto)
+ return dct
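+
+
+# Illustrative sketch (editor's addition), assuming a hypothetical single-field message `Foo`
+# with an int64 `id`:
+#   json_format.MessageToDict(Foo(id=987), preserving_proto_field_name=True)  -> {'id': '987'}
+#   to_dict(Foo(id=987))                                                      -> {'id': 987}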
+
+
+def to_json(proto: Message) -> str:
+ """Converts proto to json string."""
+ return json_format.MessageToJson(proto, preserving_proto_field_name=True)
+
+
+def parse_from_json(json_str: str, proto: Message) -> Message:
+ """Parses json string to a proto."""
+ return json_format.Parse(json_str or '{}', proto, ignore_unknown_fields=True)
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/proto_test.py b/web_console_v2/api/fedlearner_webconsole/utils/proto_test.py
new file mode 100644
index 000000000..24e6709b0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/proto_test.py
@@ -0,0 +1,211 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=unsupported-assignment-operation
+import json
+import unittest
+
+from google.protobuf.struct_pb2 import Struct, Value, ListValue
+
+from fedlearner_webconsole.proto.testing.testing_pb2 import PrivateInfo, RichMessage, Tdata, Int64Message, StructWrapper
+from fedlearner_webconsole.utils.proto import remove_secrets, to_dict, to_json, parse_from_json
+
+
+class ProtoTest(unittest.TestCase):
+
+ def test_remove_secrets(self):
+ proto = RichMessage(
+ field1='f1',
+ field2=123,
+ pinfo=PrivateInfo(pii='pii', non_pii='non pii'),
+ infos=[PrivateInfo(pii='only pii'), PrivateInfo(non_pii='only non pii')],
+ pinfo_map={
+ 'k1': PrivateInfo(non_pii='hello non pii'),
+ 'k2': PrivateInfo(pii='hello pii')
+ },
+ pstring_map={'s1': 'v1'},
+ pstring_list=['p1'],
+ )
+ proto_without_secret = RichMessage(
+ field1='f1',
+ pinfo=PrivateInfo(non_pii='non pii'),
+ infos=[PrivateInfo(), PrivateInfo(non_pii='only non pii')],
+ pinfo_map={
+ 'k1': PrivateInfo(non_pii='hello non pii'),
+ 'k2': PrivateInfo()
+ },
+ pstring_map={'s1': ''},
+ )
+ self.assertEqual(remove_secrets(proto), proto_without_secret)
+
+ def test_to_dict_with_secret(self):
+ proto = RichMessage(field1='f1',
+ field2=123,
+ pinfo=PrivateInfo(pii='pii', non_pii='non pii'),
+ infos=[PrivateInfo(pii='only pii'),
+ PrivateInfo(non_pii='only non pii')],
+ pinfo_map={
+ 'k1': PrivateInfo(non_pii='hello non pii'),
+ 'k2': PrivateInfo(pii='hello pii')
+ },
+ pstring_map={'s1': 'v1'},
+ pstring_list=['p1'])
+ self.assertEqual(
+ to_dict(proto), {
+ 'field1': 'f1',
+ 'field2': 123,
+ 'infos': [{
+ 'pii': 'only pii',
+ 'non_pii': ''
+ }, {
+ 'pii': '',
+ 'non_pii': 'only non pii'
+ }],
+ 'pinfo': {
+ 'pii': 'pii',
+ 'non_pii': 'non pii'
+ },
+ 'pinfo_map': {
+ 'k1': {
+ 'non_pii': 'hello non pii',
+ 'pii': ''
+ },
+ 'k2': {
+ 'non_pii': '',
+ 'pii': 'hello pii'
+ }
+ },
+ 'pstring_map': {
+ 's1': 'v1'
+ },
+ 'pstring_list': ['p1']
+ })
+ self.assertEqual(
+ to_dict(proto, with_secret=False), {
+ 'field1': 'f1',
+ 'field2': 0,
+ 'infos': [{
+ 'pii': '',
+ 'non_pii': ''
+ }, {
+ 'pii': '',
+ 'non_pii': 'only non pii'
+ }],
+ 'pinfo': {
+ 'pii': '',
+ 'non_pii': 'non pii'
+ },
+ 'pinfo_map': {
+ 'k1': {
+ 'non_pii': 'hello non pii',
+ 'pii': ''
+ },
+ 'k2': {
+ 'non_pii': '',
+ 'pii': ''
+ }
+ },
+ 'pstring_map': {
+ 's1': ''
+ },
+ 'pstring_list': []
+ })
+
+ def test_to_dict_int64(self):
+ proto = Int64Message(id=123456789,
+ uuid='123123',
+ project_id=666,
+ data=[Tdata(id=987), Tdata(projects=[1, 2, 3])])
+ self.assertEqual(
+ to_dict(proto), {
+ 'uuid':
+ '123123',
+ 'project_id':
+ 666,
+ 'id':
+ 123456789,
+ 'data': [
+ {
+ 'id': 987,
+ 'mappers': {},
+ 'projects': [],
+ 'tt': 'UNSPECIFIED',
+ },
+ {
+ 'id': 0,
+ 'mappers': {},
+ 'projects': [1, 2, 3],
+ 'tt': 'UNSPECIFIED',
+ },
+ ]
+ })
+
+ def test_to_dict_struct(self):
+ list_value = ListValue(values=[Value(string_value='string in list')])
+ nested_struct = Struct()
+ nested_struct['haha'] = 2.33
+ struct = Struct()
+ struct['nested_list'] = list_value
+ struct['nested_struct'] = nested_struct
+
+ struct_wrapper = StructWrapper(typed_value=Value(string_value='str'), struct=struct)
+ self.assertEqual(to_dict(struct_wrapper), {
+ 'typed_value': 'str',
+ 'struct': {
+ 'nested_list': ['string in list'],
+ 'nested_struct': {
+ 'haha': 2.33
+ }
+ }
+ })
+
+ def test_to_json(self):
+ proto = RichMessage(field1='field1',
+ field2=123123,
+ pinfo=PrivateInfo(pii='pii', non_pii='non pii'),
+ pstring_map={'s1': 'v1'},
+ pstring_list=['p1'])
+ self.assertEqual(
+ json.loads(to_json(proto)), {
+ 'field1': 'field1',
+ 'field2': 123123,
+ 'pinfo': {
+ 'pii': 'pii',
+ 'non_pii': 'non pii'
+ },
+ 'pstring_map': {
+ 's1': 'v1'
+ },
+ 'pstring_list': ['p1']
+ })
+
+ def test_parse_from_json(self):
+ proto = RichMessage(field1='field1', field2=123123, pstring_map={'s1': 'v1'}, pstring_list=['p1'])
+ self.assertEqual(
+ to_dict(
+ parse_from_json(
+ json.dumps({
+ 'field1': 'field1',
+ 'field2': 123123,
+ 'pstring_map': {
+ 's1': 'v1'
+ },
+ 'pstring_list': ['p1'],
+ 'unknown_f': '123'
+ }), RichMessage())), to_dict(proto))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/resource_name.py b/web_console_v2/api/fedlearner_webconsole/utils/resource_name.py
new file mode 100644
index 000000000..d2668332b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/resource_name.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from uuid import uuid4
+
+
+def resource_uuid() -> str:
+ """Build resource uuid
+ Returns:
+ A DNS-1035 label. A DNS-1035 label must start with an
+ alphabetic character. Since k8s resource name is limited to 64 chars,
+ job_def name is limited to 24 chars and pod name suffix is limit to
+ 19 chars, 20 chars are left for uuid.
+ substring uuid[:19] has no collision in 10 million draws.
+ """
+ return f'u{uuid4().hex[:19]}'
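+
+
+# Illustrative sketch (editor's addition, hypothetical value): resource_uuid() returns strings
+# such as 'u1a2b3c4d5e6f7a8b9c0' -- 20 chars, starting with a letter as DNS-1035 requires.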
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/resource_name_test.py b/web_console_v2/api/fedlearner_webconsole/utils/resource_name_test.py
new file mode 100644
index 000000000..c1818a5f2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/resource_name_test.py
@@ -0,0 +1,29 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+
+
+class UtilsTest(unittest.TestCase):
+
+ def test_resource_uuid(self):
+ uuid = resource_uuid()
+ self.assertEqual(len(uuid), 20)
+ self.assertEqual(uuid[0], 'u')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/schema.py b/web_console_v2/api/fedlearner_webconsole/utils/schema.py
new file mode 100644
index 000000000..c8161c53b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/schema.py
@@ -0,0 +1,104 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+from fedlearner_webconsole.exceptions import InternalException
+
+_SPARK_TO_JSON = {
+ 'integer': 'integer',
+ 'long': 'integer',
+ 'short': 'integer',
+ 'float': 'number',
+ 'double': 'number',
+ 'string': 'string',
+ 'binary': 'string',
+ 'boolean': 'boolean',
+ 'null': 'null',
+}
+
+
+def spark_schema_to_json_schema(spark_schema: Optional[dict]):
+ """
+ All fields in the spark schema are deemed required in the json schema.
+ Any field not in the spark schema is deemed forbidden in the json schema.
+ Types are converted from spark schema to json schema via _SPARK_TO_JSON.
+ Ref: https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.types.StructField.html
+ Ref: https://json-schema.org/learn/getting-started-step-by-step.html
+
+ e.g.
+ [from] spark schema:
+ {
+ 'type': 'struct',
+ 'fields': [
+ {
+ 'name': 'raw_id',
+ 'type': 'integer',
+ 'nullable': True,
+ 'metadata': {}
+ },
+ {
+ 'name': 'f01',
+ 'type': 'float',
+ 'nullable': True,
+ 'metadata': {}
+ },
+ {
+ 'name': 'image',
+ 'type': 'binary',
+ 'nullable': True,
+ 'metadata': {}
+ }
+ ]
+ }
+
+ [to] json schema:
+ {
+ 'type': 'object',
+ 'properties':{
+ 'raw_id': {
+ 'type': 'integer'
+ },
+ 'f01': {
+ 'type': 'number'
+ },
+ 'image': {
+ 'type': 'string'
+ }
+ },
+ 'additionalProperties': False,
+ 'required': [
+ 'raw_id',
+ 'f01',
+ 'image'
+ ]
+ }
+ """
+ if spark_schema is None:
+ return {}
+ properties = {}
+ required = []
+ fields = spark_schema.get('fields')
+ for field in fields:
+ name = field.get('name')
+ field_type = field.get('type')
+ json_type = _SPARK_TO_JSON.get(field_type)
+ if json_type is None:
+ raise InternalException(
+ f'spark schema to json schema conversion failed! reason: unrecognized type [{field_type}]')
+ properties[name] = {'type': json_type}
+ required.append(name)
+ json_schema = {'type': 'object', 'properties': properties, 'additionalProperties': False, 'required': required}
+ return json_schema
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/schema_test.py b/web_console_v2/api/fedlearner_webconsole/utils/schema_test.py
new file mode 100644
index 000000000..a32f18fd2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/schema_test.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.utils.schema import spark_schema_to_json_schema
+
+
+class SchemaConvertTest(unittest.TestCase):
+
+ def test_spark_schema_to_json_schema(self):
+ spark_schema = {
+ 'type':
+ 'struct',
+ 'fields': [{
+ 'name': 'raw_id',
+ 'type': 'integer',
+ 'nullable': True,
+ 'metadata': {}
+ }, {
+ 'name': 'f01',
+ 'type': 'float',
+ 'nullable': True,
+ 'metadata': {}
+ }, {
+ 'name': 'image',
+ 'type': 'binary',
+ 'nullable': True,
+ 'metadata': {}
+ }, {
+ 'name': 'hight',
+ 'type': 'long',
+ 'nullable': True,
+ 'metadata': {}
+ }]
+ }
+
+ json_schema = {
+ 'type': 'object',
+ 'properties': {
+ 'raw_id': {
+ 'type': 'integer'
+ },
+ 'f01': {
+ 'type': 'number'
+ },
+ 'image': {
+ 'type': 'string'
+ },
+ 'hight': {
+ 'type': 'integer'
+ }
+ },
+ 'additionalProperties': False,
+ 'required': ['raw_id', 'f01', 'image', 'hight']
+ }
+ res = spark_schema_to_json_schema(spark_schema)
+ self.assertEqual(res, json_schema)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/sorting.py b/web_console_v2/api/fedlearner_webconsole/utils/sorting.py
new file mode 100644
index 000000000..31867083b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/sorting.py
@@ -0,0 +1,60 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import re
+from typing import NamedTuple, Type, List
+
+from sqlalchemy import asc, desc
+from sqlalchemy.orm import Query
+
+from fedlearner_webconsole.db import db
+
+_REGEX = re.compile(r'^([a-zA-Z0-9._\-]+)\s(asc|desc)$')
+
+
+class SortExpression(NamedTuple):
+ is_asc: bool
+ field: str
+
+
+def parse_expression(exp: str) -> SortExpression:
+ matches = _REGEX.match(exp)
+ if not matches:
+ error_message = f'[SortExpression] unsupported expression {exp}'
+ logging.error(error_message)
+ raise ValueError(error_message)
+ is_asc = True
+ if matches.group(2) == 'desc':
+ is_asc = False
+ return SortExpression(field=matches.group(1), is_asc=is_asc)
+
+
+class SorterBuilder(object):
+
+ def __init__(self, model_class: Type[db.Model], supported_fields: List[str]):
+ self.model_class = model_class
+ for field in supported_fields:
+ assert getattr(self.model_class, field, None) is not None, f'{field} is not a column key'
+ self.supported_fields = set(supported_fields)
+
+ def build_query(self, query: Query, exp: SortExpression) -> Query:
+ if exp.field not in self.supported_fields:
+ raise ValueError(f'[SortExpression] unsupported field: {exp.field}')
+ column = getattr(self.model_class, exp.field)
+ order_fn = asc
+ if not exp.is_asc:
+ order_fn = desc
+ return query.order_by(order_fn(column))
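+
+
+# Illustrative sketch (editor's addition), assuming a hypothetical `Job` model with a
+# `created_at` column:
+#   builder = SorterBuilder(model_class=Job, supported_fields=['created_at'])
+#   query = builder.build_query(session.query(Job), parse_expression('created_at desc'))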
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/sorting_test.py b/web_console_v2/api/fedlearner_webconsole/utils/sorting_test.py
new file mode 100644
index 000000000..0b1984afe
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/sorting_test.py
@@ -0,0 +1,76 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.db import db, default_table_args
+from fedlearner_webconsole.utils.sorting import parse_expression, SortExpression, SorterBuilder
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class ParseExpressionTest(unittest.TestCase):
+
+ def test_invalid_exp(self):
+ # No space
+ with self.assertRaises(ValueError):
+ parse_expression('fieldasc')
+ # Invalid field name
+ with self.assertRaises(ValueError):
+ parse_expression('field你 asc')
+ # Invalid asc sign
+ with self.assertRaises(ValueError):
+ parse_expression('fiele dasc')
+
+ def test_valid_exp(self):
+ self.assertEqual(parse_expression('ff_GG-1 asc'), SortExpression(field='ff_GG-1', is_asc=True))
+ self.assertEqual(parse_expression('f.a.b desc'), SortExpression(field='f.a.b', is_asc=False))
+
+
+class TestModel(db.Model):
+ __tablename__ = 'test_table'
+ __table_args__ = (default_table_args('Test table'))
+ id = db.Column(db.Integer, primary_key=True, autoincrement=True)
+ amount = db.Column(db.Float, default=0)
+
+
+class SorterBuilderTest(NoWebServerTestCase):
+
+ def test_supported_field(self):
+ with self.assertRaises(AssertionError):
+ SorterBuilder(model_class=TestModel, supported_fields=['id', 'amount', 'non-existing'])
+
+ def test_build_query(self):
+ self.maxDiff = None
+ builder = SorterBuilder(model_class=TestModel, supported_fields=['id', 'amount'])
+ with db.session_scope() as session:
+ query = session.query(TestModel)
+ # Invalid one
+ sort_exp = SortExpression(field='f1', is_asc=True)
+ with self.assertRaisesRegex(ValueError, 'unsupported field: f1'):
+ builder.build_query(query, sort_exp)
+ # Valid ones
+ sort_exp = SortExpression(field='id', is_asc=True)
+ statement = self.generate_mysql_statement(builder.build_query(query, sort_exp))
+ self.assertEqual(statement, 'SELECT test_table.id, test_table.amount \n'
+ 'FROM test_table ORDER BY test_table.id ASC')
+ sort_exp = SortExpression(field='amount', is_asc=False)
+ statement = self.generate_mysql_statement(builder.build_query(query, sort_exp))
+ self.assertEqual(
+ statement, 'SELECT test_table.id, test_table.amount \n'
+ 'FROM test_table ORDER BY test_table.amount DESC')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/stream_tars.py b/web_console_v2/api/fedlearner_webconsole/utils/stream_tars.py
new file mode 100644
index 000000000..7087c99d1
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/stream_tars.py
@@ -0,0 +1,166 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# coding: utf-8
+# pylint: disable=redefined-builtin, no-else-continue, broad-except, consider-using-with
+import os
+from io import BytesIO
+from tarfile import TarFile, NUL, BLOCKSIZE, TarInfo
+import tempfile
+import gzip
+import logging
+from typing import BinaryIO, AnyStr, Union
+
+from fedlearner_webconsole.utils.file_manager import FileManager, FILE_PREFIX
+
+CHUNK_SIZE = 1 << 22
+
+
+class FileStream:
+
+ def __init__(self):
+ self.buffer = BytesIO()
+ self.offset = 0
+
+ def write(self, s: AnyStr):
+ self.buffer.write(s)
+ self.offset += len(s)
+
+ def tell(self):
+ return self.offset
+
+ def close(self):
+ self.buffer.close()
+
+ def read_all(self):
+ try:
+ return self.buffer.getvalue()
+ finally:
+ self.buffer.close()
+ self.buffer = BytesIO()
+
+
+class _TarFileWithoutCache(TarFile):
+ """ Building a tar file chunk-by-chunk.
+ """
+
+ def __init__(self, directories: Union[str, list], file_chunk_size: int = CHUNK_SIZE): # pylint: disable=super-init-not-called
+ self._contents = [directories]
+ self._file_chunk_size = file_chunk_size
+ self._is_multiple = False
+ if isinstance(directories, list):
+ self._is_multiple = True
+ self._contents = directories
+
+ @staticmethod
+ def _stream_file_into_tar(tarinfo: TarInfo, tar: TarFile, fh: BinaryIO, buf_size: int):
+ out = tar.fileobj
+
+ for b in iter(lambda: fh.read(buf_size), b''):
+ out.write(b)
+ yield
+
+ blocks, remainder = divmod(tarinfo.size, BLOCKSIZE)
+ if remainder > 0:
+ out.write(NUL * (BLOCKSIZE - remainder))
+ blocks += 1
+ tar.offset += blocks * BLOCKSIZE
+ yield
+
+ def __iter__(self):
+ out = FileStream()
+ tar = TarFile(fileobj=out, mode='w')
+ for content in self._contents:
+ if os.path.isdir(content):
+ prefix, name = os.path.split(content)
+ prefix_len = len(prefix) + len(os.path.sep)
+ tar.add(name=content, arcname=name, recursive=False)
+ for path, dirs, files in os.walk(content):
+ arcpath = path[prefix_len:]
+ # Add files
+ # Use this logic instead of tar.add() to avoid the unbounded memory usage caused by invoking
+ # tar.addfile(), which caches tarinfo in TarFile.members
+ for f in files:
+ filepath = os.path.join(path, f)
+ with open(filepath, 'rb') as fh:
+ tarinfo = tar.gettarinfo(name=filepath, arcname=os.path.join(arcpath, f), fileobj=fh)
+ tar.addfile(tarinfo)
+ for _ in self._stream_file_into_tar(tarinfo, tar, fh, self._file_chunk_size):
+ yield out.read_all()
+
+ # Add directories
+ for d in dirs:
+ tar.add(name=os.path.join(path, d), arcname=os.path.join(arcpath, d), recursive=False)
+ yield out.read_all()
+ else:
+ filepath = content
+ filename = os.path.basename(filepath)
+ with open(filepath, 'rb') as fh:
+ tarinfo = tar.gettarinfo(name=filepath, arcname=filename, fileobj=fh)
+ tar.addfile(tarinfo)
+ for _ in self._stream_file_into_tar(tarinfo, tar, fh, self._file_chunk_size):
+ yield out.read_all()
+
+ tar.close()
+ yield out.read_all()
+ out.close()
+
+
+class StreamingTar(object):
+ """ Building a tar file chunk-by-chunk.
+ """
+
+ def __init__(self, fm: FileManager, chunksize: int = CHUNK_SIZE) -> None:
+ super().__init__()
+ self._fm = fm
+ self.chunksize = chunksize
+
+ def _archive(self, source_path: Union[str, list], target_path: str):
+ logging.info(f'will archive {source_path} to {target_path}')
+ tarfile = _TarFileWithoutCache(source_path, self.chunksize)
+ with open(target_path, 'wb') as target_f:
+ for chunk in tarfile:
+ target_f.write(chunk)
+
+ def _compress(self, filename: str, target_path: str):
+ with open(filename, 'rb') as tar_f:
+ with gzip.GzipFile(target_path, 'wb') as gzip_f:
+ stream = tar_f.read(self.chunksize)
+ while stream:
+ gzip_f.write(stream)
+ stream = tar_f.read(self.chunksize)
+
+ # TODO(lixiaoguang.01): remove this function after using FileManager
+ def _trim_prefix(self, path: str) -> str:
+ if path.startswith(FILE_PREFIX):
+ return path.split(FILE_PREFIX, 1)[1]
+ return path
+
+ # TODO(zeju): provide tar file in-memory option
+ def archive(self, source_path: Union[str, list], target_path: str, gzip_compress: bool = False):
+ # TODO(lixiaoguang.01): use FileManager in archive and compress
+ if isinstance(source_path, str):
+ source_path = self._trim_prefix(source_path)
+ else: # list
+ trimmed_source_path = []
+ for single_path in source_path:
+ trimmed_source_path.append(self._trim_prefix(single_path))
+ source_path = trimmed_source_path
+ target_path = self._trim_prefix(target_path)
+
+ with tempfile.NamedTemporaryFile('wb') as temp:
+ if gzip_compress:
+ self._archive(source_path, temp.name)
+ self._compress(temp.name, target_path)
+ else:
+ self._archive(source_path, target_path)
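+
+
+# Illustrative sketch (editor's addition), assuming local paths and a FileManager instance:
+#   StreamingTar(FileManager()).archive('/data/output', '/data/output.tar.gz', gzip_compress=True)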
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/stream_untars.py b/web_console_v2/api/fedlearner_webconsole/utils/stream_untars.py
new file mode 100644
index 000000000..28a48f33e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/stream_untars.py
@@ -0,0 +1,145 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# coding: utf-8
+# pylint: disable=redefined-builtin, no-else-continue, broad-except, consider-using-f-string, consider-using-with
+import tarfile
+import gzip
+from tarfile import (BLOCKSIZE, TarFile, ReadError, EOFHeaderError, InvalidHeaderError, EmptyHeaderError,
+ TruncatedHeaderError, SubsequentHeaderError)
+import tempfile
+import logging
+from typing import BinaryIO
+
+from fedlearner_webconsole.utils.file_manager import FileManager, FILE_PREFIX
+
+CHUNK_SIZE = 1 << 22
+
+TAR_SUFFIX = ('.tar',)
+GZIP_SUFFIX = ('.gz', '.tgz')
+
+
+class _TarFileWithoutCache(TarFile):
+
+ def next(self):
+ self._check('ra')
+ if self.firstmember is not None:
+ m = self.firstmember
+ self.firstmember = None
+ return m
+
+ # Advance the file pointer.
+ if self.offset != self.fileobj.tell():
+ self.fileobj.seek(self.offset - 1)
+ if not self.fileobj.read(1):
+ raise tarfile.ReadError('unexpected end of data')
+
+ # Read the next block.
+ tarinfo = None
+ while True:
+ try:
+ tarinfo = self.tarinfo.fromtarfile(self)
+ except EOFHeaderError as e:
+ if self.ignore_zeros:
+ self._dbg(2, '0x%X: %s' % (self.offset, e))
+ self.offset += BLOCKSIZE
+ continue
+ except InvalidHeaderError as e:
+ if self.ignore_zeros:
+ self._dbg(2, '0x%X: %s' % (self.offset, e))
+ self.offset += BLOCKSIZE
+ continue
+ elif self.offset == 0:
+ raise ReadError(str(e)) from e
+ except EmptyHeaderError as e:
+ if self.offset == 0:
+ raise ReadError('empty file') from e
+ except TruncatedHeaderError as e:
+ if self.offset == 0:
+ raise ReadError(str(e)) from e
+ except SubsequentHeaderError as e:
+ raise ReadError(str(e)) from e
+ break
+
+ if tarinfo is None:
+ self._loaded = True
+
+ return tarinfo
+
+
+class StreamingUntar(object):
+ """
+ A class used to support decompressing .tar.gz files in a streaming fashion.
+ 1. First, decompress the gzip file, chunk by chunk, into a tarball.
+ 2. Then use _TarFileWithoutCache to untar the tarball with a fixed memory usage.
+ 3. _TarFileWithoutCache is a subclass of TarFile that removes the cache in its next() function.
+ e.g.
+ convert xxx.tar.gz -> xxx
+ """
+
+ def __init__(self, fm: FileManager, chunksize: int = CHUNK_SIZE) -> None:
+ super().__init__()
+ self._fm = fm
+ self.chunksize = chunksize
+
+ def _uncompressed(self, source: str, temp_file: BinaryIO) -> str:
+ try:
+ with gzip.GzipFile(source, 'rb') as gf:
+ stream = gf.read(self.chunksize)
+ while stream:
+ temp_file.write(stream)
+ stream = gf.read(self.chunksize)
+ except Exception as e:
+ logging.error(f'failed to streaming decompress file from:{source}, ex: {e}')
+ return temp_file.name
+
+ def _untar(self, source: str, dest: str) -> None:
+ tar = _TarFileWithoutCache.open(source)
+ try:
+ entry = tar.next()
+ while entry:
+ tar.extract(entry, path=dest)
+ entry = tar.next()
+ except Exception as e:
+ logging.error(f'failed to streaming untar file, from {source} to {dest}, ex: {e}')
+ finally:
+ tar.close()
+
+ # TODO(lixiaoguang.01): remove this function after using FileManager
+ def _trim_prefix(self, path: str) -> str:
+ if path.startswith(FILE_PREFIX):
+ return path.split(FILE_PREFIX, 1)[1]
+ return path
+
+ def untar(self, source: str, dest: str) -> None:
+ """
+ Untars the source .tar(.gz) into the dest directory, with a fixed memory usage.
+
+ Args:
+ source: source path; only the local file system is supported
+ dest: destination path; only the local file system is supported
+
+ Raises:
+ ValueError: if the tarfile does not end with .tar/.tar.gz
+ Exception: if an io operation failed
+ """
+ # TODO(lixiaoguang.01): use FileManager in untar and uncompressed
+ source = self._trim_prefix(source)
+ dest = self._trim_prefix(dest)
+
+ if not source.endswith(TAR_SUFFIX + GZIP_SUFFIX):
+ raise ValueError(f'{source} does not end with a tar or gzip extension')
+ with tempfile.NamedTemporaryFile('wb') as temp:
+ if source.endswith(GZIP_SUFFIX):
+ source = self._uncompressed(source, temp)
+ self._untar(source, dest)
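+
+
+# Illustrative sketch (editor's addition), assuming local paths and a FileManager instance:
+#   StreamingUntar(FileManager()).untar('/data/output.tar.gz', '/data/restored')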
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/swagger.py b/web_console_v2/api/fedlearner_webconsole/utils/swagger.py
new file mode 100644
index 000000000..540b4170c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/swagger.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pathlib import Path
+
+
+def replace_ref_name(schema: dict, ref_name: str, message_name: str) -> dict:
+ for k, v in schema.items():
+ if isinstance(v, dict):
+ schema[k] = replace_ref_name(v, ref_name, message_name)
+ if '$ref' in schema and schema['$ref'] == f'#/definitions/{message_name}':
+ schema['$ref'] = f'#/definitions/{ref_name}'
+ return schema
+
+
+def remove_title(schema: dict) -> dict:
+ for k, v in schema.items():
+ if isinstance(v, dict):
+ schema[k] = remove_title(v)
+ if 'title' in schema:
+ del schema['title']
+ return schema
+
+
+def normalize_schema(definitions: dict, jsonschema_path: Path) -> dict:
+ # "prefix_schema_files_with_package" option in Makefile will generate a directory with
+ # the name of the corresponding package name, therefore the full name of a message is
+ # {directory_name}.{message_name}
+ package_name = jsonschema_path.parent.name
+ message_name = jsonschema_path.stem
+ ref_name = f'{package_name}.{message_name}'
+
+ # Title gets generated in newer versions of the jsonschema plugin; just remove it manually
+ definitions = remove_title(definitions)
+
+ # The name of the first message defined in the .proto file will be used as the generated
+ # json file's name, which does not have a package name. Therefore, we prepend the package
+ # name for it
+ definitions[ref_name] = replace_ref_name(definitions[message_name], ref_name, message_name)
+ del definitions[message_name]
+ return definitions
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/swagger_test.py b/web_console_v2/api/fedlearner_webconsole/utils/swagger_test.py
new file mode 100644
index 000000000..6b4960712
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/swagger_test.py
@@ -0,0 +1,163 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pathlib import Path
+
+from fedlearner_webconsole.utils.swagger import remove_title, replace_ref_name, normalize_schema
+
+
+class SwaggerTest(unittest.TestCase):
+
+ def test_replace_ref_name(self):
+ candidate = {
+ '$ref': '#/definitions/no',
+ 'hello': {
+ '$ref': '#/definitions/no',
+ 'world': {
+ '$ref': '#/definitions/no'
+ }
+ }
+ }
+ candidate = replace_ref_name(candidate, ref_name='yes', message_name='no')
+ self.assertDictEqual(
+ {
+ '$ref': '#/definitions/yes',
+ 'hello': {
+ '$ref': '#/definitions/yes',
+ 'world': {
+ '$ref': '#/definitions/yes'
+ }
+ }
+ }, candidate)
+
+ def test_remove_title(self):
+ candidate = {'title': 'hello', 'inner': {'title': 'world', 'inner': {'title': '!',}}}
+ candidate = remove_title(candidate)
+ self.assertDictEqual({'inner': {'inner': {}}}, candidate)
+
+ def test_normalize_schema(self):
+ candidate = {
+ 'FileTreeNode': {
+ 'properties': {
+ 'files': {
+ 'items': {
+ '$ref': '#/definitions/FileTreeNode'
+ },
+ 'additionalProperties': False,
+ 'type': 'array'
+ }
+ },
+ 'additionalProperties': False,
+ 'type': 'object',
+ 'title': 'File Tree Node'
+ }
+ }
+
+ candidate = normalize_schema(candidate, Path('aaa/FileTreeNode.json'))
+ self.assertEqual(
+ {
+ # here
+ 'aaa.FileTreeNode': {
+ 'properties': {
+ 'files': {
+ 'items': {
+ # here
+ '$ref': '#/definitions/aaa.FileTreeNode'
+ },
+ 'additionalProperties': False,
+ 'type': 'array'
+ }
+ },
+ 'additionalProperties': False,
+ 'type': 'object',
+ # no title
+ }
+ },
+ candidate)
+
+ candidate = {
+ 'AlgorithmData': {
+ 'properties': {
+ 'version': {
+ '$ref': '#/definitions/AlgorithmData',
+ },
+ 'parameter': {
+ '$ref': '#/definitions/fedlearner_webconsole.proto.AlgorithmParameter',
+ 'additionalProperties': False
+ },
+ },
+ 'additionalProperties': False,
+ 'type': 'object',
+ 'title': 'Algorithm Data'
+ },
+ 'fedlearner_webconsole.proto.AlgorithmParameter': {
+ 'properties': {
+ 'variables': {
+ 'items': {
+ '$ref': '#/definitions/fedlearner_webconsole.proto.AlgorithmVariable'
+ },
+ 'additionalProperties': False,
+ 'type': 'array'
+ }
+ },
+ 'additionalProperties': False,
+ 'type': 'object',
+ 'title': 'Algorithm Parameter'
+ },
+ }
+
+ candidate = normalize_schema(candidate, Path('aaa/AlgorithmData.json'))
+ self.assertDictEqual(
+ {
+ # here
+ 'aaa.AlgorithmData': {
+ 'properties': {
+ 'version': {
+ # here
+ '$ref': '#/definitions/aaa.AlgorithmData',
+ },
+ 'parameter': {
+ # this does not change
+ '$ref': '#/definitions/fedlearner_webconsole.proto.AlgorithmParameter',
+ 'additionalProperties': False
+ },
+ },
+ 'additionalProperties': False,
+ 'type': 'object',
+ # no title
+ },
+ 'fedlearner_webconsole.proto.AlgorithmParameter': {
+ 'properties': {
+ 'variables': {
+ 'items': {
+ # this does not change
+ '$ref': '#/definitions/fedlearner_webconsole.proto.AlgorithmVariable'
+ },
+ 'additionalProperties': False,
+ 'type': 'array'
+ }
+ },
+ 'additionalProperties': False,
+ 'type': 'object',
+ # no title
+ }
+ },
+ candidate)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/system_envs.py b/web_console_v2/api/fedlearner_webconsole/utils/system_envs.py
index b75f607a6..6431a70b8 100644
--- a/web_console_v2/api/fedlearner_webconsole/utils/system_envs.py
+++ b/web_console_v2/api/fedlearner_webconsole/utils/system_envs.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
# limitations under the License.
# coding: utf-8
-import json
-import os
+
+from envs import Envs
def _is_valid_env(env: dict) -> bool:
@@ -22,105 +22,118 @@ def _is_valid_env(env: dict) -> bool:
env.get('value', None) is not None
+def _normalize_env(env: dict) -> dict:
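+ # Kubernetes env var values must be strings, so stringify any non-string value.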
+ if 'value' in env:
+ env['value'] = str(env['value'])
+ return env
+
+
def get_system_envs():
"""Gets a JSON string to represent system envs."""
# Most envs should be from pod's env
- envs = [
- {
- 'name': 'POD_IP',
- 'valueFrom': {
- 'fieldRef': {
- 'fieldPath': 'status.podIP'
- }
+ envs = [{
+ 'name': 'POD_IP',
+ 'valueFrom': {
+ 'fieldRef': {
+ 'fieldPath': 'status.podIP'
}
- },
- {
- 'name': 'POD_NAME',
- 'valueFrom': {
- 'fieldRef': {
- 'fieldPath': 'metadata.name'
- }
+ }
+ }, {
+ 'name': 'POD_NAME',
+ 'valueFrom': {
+ 'fieldRef': {
+ 'fieldPath': 'metadata.name'
}
- },
- {
- 'name': 'CPU_REQUEST',
- 'valueFrom': {
- 'resourceFieldRef': {
- 'resource': 'requests.cpu'
- }
+ }
+ }, {
+ 'name': 'CPU_REQUEST',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'requests.cpu'
}
- },
- {
- 'name': 'MEM_REQUEST',
- 'valueFrom': {
- 'resourceFieldRef': {
- 'resource': 'requests.memory'
- }
+ }
+ }, {
+ 'name': 'MEM_REQUEST',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'requests.memory'
}
- },
- {
- 'name': 'CPU_LIMIT',
- 'valueFrom': {
- 'resourceFieldRef': {
- 'resource': 'limits.cpu'
- }
+ }
+ }, {
+ 'name': 'CPU_LIMIT',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'limits.cpu'
}
- },
- {
- 'name': 'MEM_LIMIT',
- 'valueFrom': {
- 'resourceFieldRef': {
- 'resource': 'limits.memory'
- }
+ }
+ }, {
+ 'name': 'MEM_LIMIT',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'limits.memory'
}
- },
- {
- 'name': 'ES_HOST',
- 'value': os.getenv('ES_HOST')
- },
- {
- 'name': 'ES_PORT',
- 'value': os.getenv('ES_PORT')
- },
- {
- 'name': 'DB_HOST',
- 'value': os.getenv('DB_HOST')
- },
- {
- 'name': 'DB_PORT',
- 'value': os.getenv('DB_PORT')
- },
- {
- 'name': 'DB_DATABASE',
- 'value': os.getenv('DB_DATABASE')
- },
- {
- 'name': 'DB_USERNAME',
- 'value': os.getenv('DB_USERNAME')
- },
- {
- 'name': 'DB_PASSWORD',
- 'value': os.getenv('DB_PASSWORD')
- },
- {
- 'name': 'KVSTORE_TYPE',
- 'value': os.getenv('KVSTORE_TYPE')
- },
- {
- 'name': 'ETCD_NAME',
- 'value': os.getenv('ETCD_NAME')
- },
- {
- 'name': 'ETCD_ADDR',
- 'value': os.getenv('ETCD_ADDR')
- },
- {
- 'name': 'ETCD_BASE_DIR',
- 'value': os.getenv('ETCD_BASE_DIR')
}
- ]
- return ','.join([json.dumps(env)
- for env in envs if _is_valid_env(env)])
+ }, {
+ 'name': 'ES_HOST',
+ 'value': Envs.ES_HOST
+ }, {
+ 'name': 'ES_PORT',
+ 'value': Envs.ES_PORT
+ }, {
+ 'name': 'DB_HOST',
+ 'value': Envs.DB_HOST
+ }, {
+ 'name': 'DB_PORT',
+ 'value': Envs.DB_PORT
+ }, {
+ 'name': 'DB_DATABASE',
+ 'value': Envs.DB_DATABASE
+ }, {
+ 'name': 'DB_USERNAME',
+ 'value': Envs.DB_USERNAME
+ }, {
+ 'name': 'DB_PASSWORD',
+ 'value': Envs.DB_PASSWORD
+ }, {
+ 'name': 'KVSTORE_TYPE',
+ 'value': Envs.KVSTORE_TYPE
+ }, {
+ 'name': 'ETCD_NAME',
+ 'value': Envs.ETCD_NAME
+ }, {
+ 'name': 'ETCD_ADDR',
+ 'value': Envs.ETCD_ADDR
+ }, {
+ 'name': 'ETCD_BASE_DIR',
+ 'value': Envs.ETCD_BASE_DIR
+ }, {
+ 'name': 'ROBOT_USERNAME',
+ 'value': Envs.ROBOT_USERNAME
+ }, {
+ 'name': 'ROBOT_PWD',
+ 'value': Envs.ROBOT_PWD
+ }, {
+ 'name': 'WEB_CONSOLE_V2_ENDPOINT',
+ 'value': Envs.WEB_CONSOLE_V2_ENDPOINT
+ }, {
+ 'name': 'HADOOP_HOME',
+ 'value': Envs.HADOOP_HOME
+ }, {
+ 'name': 'JAVA_HOME',
+ 'value': Envs.JAVA_HOME
+ }, {
+ 'name': 'PRE_START_HOOK',
+ 'value': Envs.PRE_START_HOOK
+ }, {
+ 'name': 'METRIC_COLLECTOR_EXPORT_ENDPOINT',
+ 'value': Envs.APM_SERVER_ENDPOINT
+ }, {
+ 'name': 'CLUSTER',
+ 'value': Envs.CLUSTER
+ }]
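+ # Keep only envs that pass _is_valid_env (unset optional envs are dropped),
+ # then stringify the remaining values.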
+ valid_envs = [env for env in envs if _is_valid_env(env)]
+ envs = [_normalize_env(env) for env in valid_envs]
+ return envs
if __name__ == '__main__':
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/system_envs_test.py b/web_console_v2/api/fedlearner_webconsole/utils/system_envs_test.py
new file mode 100644
index 000000000..ca628cb47
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/system_envs_test.py
@@ -0,0 +1,133 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import patch
+
+from fedlearner_webconsole.utils.system_envs import get_system_envs
+
+
+class _FakeEnvs(object):
+ ES_HOST = 'test es host'
+ ES_PORT = '9200'
+ DB_HOST = 'test db host'
+ DB_PORT = '3306'
+ DB_DATABASE = 'fedlearner'
+ DB_USERNAME = 'username'
+ DB_PASSWORD = 'password'
+ KVSTORE_TYPE = 'mysql'
+ ETCD_NAME = 'fedlearner'
+ ETCD_ADDR = 'fedlearner-stack-etcd.default.svc.cluster.local:2379'
+ ETCD_BASE_DIR = 'fedlearner'
+ APM_SERVER_ENDPOINT = 'http://apm-server-apm-server:8200'
+ CLUSTER = 'cloudnative-hl'
+ ROBOT_USERNAME = None
+ ROBOT_PWD = None
+ WEB_CONSOLE_V2_ENDPOINT = None
+ HADOOP_HOME = None
+ JAVA_HOME = None
+ PRE_START_HOOK = None
+
+
+class SystemEnvsTest(unittest.TestCase):
+
+ @patch('fedlearner_webconsole.utils.system_envs.Envs', _FakeEnvs)
+ def test_get_available_envs(self):
+ self.assertEqual(get_system_envs(), [{
+ 'name': 'POD_IP',
+ 'valueFrom': {
+ 'fieldRef': {
+ 'fieldPath': 'status.podIP'
+ }
+ }
+ }, {
+ 'name': 'POD_NAME',
+ 'valueFrom': {
+ 'fieldRef': {
+ 'fieldPath': 'metadata.name'
+ }
+ }
+ }, {
+ 'name': 'CPU_REQUEST',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'requests.cpu'
+ }
+ }
+ }, {
+ 'name': 'MEM_REQUEST',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'requests.memory'
+ }
+ }
+ }, {
+ 'name': 'CPU_LIMIT',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'limits.cpu'
+ }
+ }
+ }, {
+ 'name': 'MEM_LIMIT',
+ 'valueFrom': {
+ 'resourceFieldRef': {
+ 'resource': 'limits.memory'
+ }
+ }
+ }, {
+ 'name': 'ES_HOST',
+ 'value': 'test es host'
+ }, {
+ 'name': 'ES_PORT',
+ 'value': '9200'
+ }, {
+ 'name': 'DB_HOST',
+ 'value': 'test db host'
+ }, {
+ 'name': 'DB_PORT',
+ 'value': '3306'
+ }, {
+ 'name': 'DB_DATABASE',
+ 'value': 'fedlearner'
+ }, {
+ 'name': 'DB_USERNAME',
+ 'value': 'username'
+ }, {
+ 'name': 'DB_PASSWORD',
+ 'value': 'password'
+ }, {
+ 'name': 'KVSTORE_TYPE',
+ 'value': 'mysql'
+ }, {
+ 'name': 'ETCD_NAME',
+ 'value': 'fedlearner'
+ }, {
+ 'name': 'ETCD_ADDR',
+ 'value': 'fedlearner-stack-etcd.default.svc.cluster.local:2379'
+ }, {
+ 'name': 'ETCD_BASE_DIR',
+ 'value': 'fedlearner'
+ }, {
+ 'name': 'METRIC_COLLECTOR_EXPORT_ENDPOINT',
+ 'value': 'http://apm-server-apm-server:8200'
+ }, {
+ 'name': 'CLUSTER',
+ 'value': 'cloudnative-hl'
+ }])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/tars.py b/web_console_v2/api/fedlearner_webconsole/utils/tars.py
deleted file mode 100644
index 3e7a59ea1..000000000
--- a/web_console_v2/api/fedlearner_webconsole/utils/tars.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# coding: utf-8
-import tarfile
-
-
-class TarCli:
- @staticmethod
- def untar_file(tar_name, extract_path_prefix):
- with tarfile.open(tar_name, 'r:*') as tar_pack:
- tar_pack.extractall(extract_path_prefix)
-
- return True
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/tars_test.py b/web_console_v2/api/fedlearner_webconsole/utils/tars_test.py
new file mode 100644
index 000000000..6a646acea
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/tars_test.py
@@ -0,0 +1,69 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+import os
+import tempfile
+import shutil
+from pathlib import Path
+from fedlearner_webconsole.utils.stream_untars import StreamingUntar
+from fedlearner_webconsole.utils.stream_tars import StreamingTar
+from fedlearner_webconsole.utils.file_manager import FileManager
+
+
+class StreamingTarTest(unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._file_manager = FileManager()
+ self._tempdir = os.path.join(tempfile.gettempdir(), 'tar_dir')
+ os.makedirs(self._tempdir, exist_ok=True)
+
+ def _get_temp_path(self, file_path: str = None) -> str:
+ return str(Path(self._tempdir, file_path or '').absolute())
+
+ def test_untar(self):
+
+ # init a dir with some files
+ tar_path = os.path.join(self._tempdir, 'tar')
+ self._file_manager.mkdir(tar_path)
+ file1_path = os.path.join(tar_path, 'test-tar1.py')
+ file2_path = os.path.join(tar_path, 'test-tar2.py')
+ file3_path = os.path.join(tar_path, 'new/test-tar3.py')
+
+ self._file_manager.write(file1_path, 'abc')
+ self._file_manager.write(file2_path, 'abc')
+ self._file_manager.write(file3_path, 'abc')
+
+ # Create a tar file
+ tar_file_path = os.path.join(tar_path, 'test-tar.tar.gz')
+ StreamingTar(self._file_manager).archive(source_path=[file1_path, file2_path, file3_path],
+ target_path=tar_file_path,
+ gzip_compress=True)
+
+ # test streaming untar file
+ untar_dir = os.path.join(tar_path, 'untar')
+ StreamingUntar(self._file_manager).untar(tar_file_path, untar_dir)
+
+ self.assertTrue(self._file_manager.exists(os.path.join(tar_path, os.path.basename(file1_path))))
+ self.assertTrue(self._file_manager.exists(os.path.join(tar_path, os.path.basename(file2_path))))
+ self.assertTrue(self._file_manager.exists(os.path.join(tar_path, os.path.basename(file3_path))))
+
+ def tearDown(self):
+ shutil.rmtree(self._tempdir, ignore_errors=True)
+ super().tearDown()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/tfrecord_test.py b/web_console_v2/api/fedlearner_webconsole/utils/tfrecord_test.py
new file mode 100644
index 000000000..3494b9ede
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/tfrecord_test.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import json
+import unittest
+from testing.common import BaseTestCase
+from http import HTTPStatus
+
+from envs import Envs
+
+
+class TfRecordReaderTest(BaseTestCase):
+
+ def test_reader(self):
+ matrix = [['id', 'x_1', 'x_2', 'x_3', 'x_4'],
+ [['0'], [0.4660772681236267], [0.9965257048606873], [0.15621308982372284], [0.9282205700874329]],
+ [['1'], [0.04800121858716011], [0.1965402364730835], [0.6086887121200562], [0.9214732646942139]],
+ [['2'], [0.05255622789263725], [0.8994112610816956], [0.6675127744674683], [0.577964186668396]],
+ [['3'], [0.7057438492774963], [0.5592560172080994], [0.6767191886901855], [0.6311695575714111]],
+ [['4'], [0.9203364253044128], [0.9567945599555969], [0.19533273577690125], [0.17610156536102295]]]
+ data = {
+ 'path': f'{Envs.BASE_DIR}/testing/test_data/'
+ f'tfrecord_test.xx.aaa.data',
+ 'wrong_path': 'adsad.data',
+ 'lines': 5
+ }
+ # test right path
+ resp = self.get_helper('/api/v2/debug/tfrecord?path={}&lines={}'.format(data['path'], data['lines'])) # pylint: disable=consider-using-f-string
+ my_data = json.loads(resp.data).get('data')
+ self.assertEqual(matrix, my_data)
+ self.assertEqual(HTTPStatus.OK, resp.status_code)
+
+ # test None path
+ resp = self.get_helper('/api/v2/debug/tfrecord')
+ self.assertEqual(HTTPStatus.BAD_REQUEST, resp.status_code)
+
+ # test wrong path
+ resp = self.get_helper('/api/v2/debug/tfrecord?path={}&lines={}'.format(data['wrong_path'], data['lines'])) # pylint: disable=consider-using-f-string
+ self.assertEqual(HTTPStatus.BAD_REQUEST, resp.status_code)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/tfrecords_reader.py b/web_console_v2/api/fedlearner_webconsole/utils/tfrecords_reader.py
new file mode 100644
index 000000000..77e659c39
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/tfrecords_reader.py
@@ -0,0 +1,89 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import itertools
+import tensorflow.compat.v1 as tf
+from typing import List
+
+
+def _parse_tfrecord(record) -> dict:
+ example = tf.train.Example()
+ example.ParseFromString(record)
+
+ parsed = {}
+ for key, value in example.features.feature.items():
+ kind = value.WhichOneof('kind')
+ if kind == 'float_list':
+ parsed[key] = [float(num) for num in value.float_list.value]
+ elif kind == 'int64_list':
+ parsed[key] = [int(num) for num in value.int64_list.value]
+ elif kind == 'bytes_list':
+ parsed[key] = [byte.decode() for byte in value.bytes_list.value]
+ else:
+ raise ValueError('Invalid tfrecord format')
+
+ return parsed
+
+
+def _get_data(path: str, max_lines: int) -> List:
+ reader = tf.io.tf_record_iterator(path)
+ reader, _ = itertools.tee(reader)
+ records = []
+ counter = 0
+ for line in reader:
+ features = _parse_tfrecord(line)
+ records.append(features)
+ counter += 1
+ if counter >= max_lines:
+ break
+ return records
+
+
+def _convert_to_matrix_view(records: List[dict]) -> List:
+ first_line = set()
+ for features in records:
+ first_line = first_line.union(features.keys())
+ sort_first_line = list(first_line)
+ sort_first_line.sort()
+ matrix = [sort_first_line]
+ for features in records:
+ current_line = []
+ for column in sort_first_line:
+ if column in features:
+ current_line.append(features[column])
+ else:
+ current_line.append('N/A')
+ matrix.append(current_line)
+ return matrix
+
+
+def tf_record_reader(path: str, max_lines: int = 10, matrix_view: bool = False) -> List:
+ """Read tfrecord from given path
+
+ Args:
+ path: the path of tfrecord file
+ max_lines: the maximum number of lines read from file
+ matrix_view: whether convert the data to csv-like matrix
+ Returns:
+ Dictionary or csv-like data
+ """
+ # read data from tfrecord
+ records = _get_data(path, max_lines)
+ if not matrix_view:
+ return records
+ # convert records to a csv-like matrix; its first row is the sorted column names
+ matrix = _convert_to_matrix_view(records)
+ return matrix
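+
+
+# Usage sketch (the path is hypothetical):
+#   tf_record_reader('/data/sample.tfrecord', max_lines=5, matrix_view=True)
+# returns a matrix whose first row is the sorted feature names and whose following
+# rows hold per-record values ('N/A' where a record lacks a feature).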
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/validator.py b/web_console_v2/api/fedlearner_webconsole/utils/validator.py
new file mode 100644
index 000000000..ac5cb6e18
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/validator.py
@@ -0,0 +1,49 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# coding: utf-8
+from typing import Callable, Optional, TypeVar, List, Dict, Tuple
+
+T = TypeVar('T')
+
+
+class Validator:
+
+ def __init__(self, name: str, is_valid: Callable[[T], bool]):
+ self.name = name
+ self._is_valid = is_valid
+
+ def is_valid(self, candidate: Optional[T]) -> Tuple[bool, Optional[str]]:
+ if candidate is None:
+ return False, f'"{self.name}" is required.'
+
+ if not self._is_valid(candidate):
+ return False, f'"{candidate}" is not a valid "{self.name}".'
+
+ return True, None
+
+ @staticmethod
+ def validate(candidates: Dict[str, T],
+ validators: List['Validator']) -> Tuple[bool, List[str]]:
+ flag = True
+ error_messages = []
+
+ for validator in validators:
+ passed, error_message = validator.is_valid(candidates.get(validator.name))
+ flag = passed and flag
+ if not passed:
+ error_messages.append(error_message)
+
+ return flag, error_messages
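+
+
+# Usage sketch (field names are illustrative; see validator_test.py):
+#   validators = [Validator('field_1', lambda x: x > 0)]
+#   ok, errors = Validator.validate({'field_1': 1}, validators)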
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/validator_test.py b/web_console_v2/api/fedlearner_webconsole/utils/validator_test.py
new file mode 100644
index 000000000..3f118c850
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/validator_test.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from fedlearner_webconsole.utils.validator import Validator
+
+
+class MyTestCase(unittest.TestCase):
+
+ def test_validator(self):
+ validators = [
+ Validator('field_1', lambda x: x > 0),
+ Validator('field_2', lambda x: x > 0),
+ Validator('field_3', lambda x: x > 0)
+ ]
+
+ dct_1 = {'field_1': 1, 'field_2': 2, 'field_3': 3}
+
+ dct_2 = {'field_1': -1, 'field_2': 2, 'field_3': 3}
+
+ dct_3 = {'field_1': 1, 'field_2': 2}
+
+ res_1, err_1 = Validator.validate(dct_1, validators)
+ res_2, err_2 = Validator.validate(dct_2, validators)
+ res_3, err_3 = Validator.validate(dct_3, validators)
+
+ self.assertTrue(res_1)
+ self.assertFalse(res_2)
+ self.assertFalse(res_3)
+
+ self.assertEqual(0, len(err_1))
+ self.assertEqual(1, len(err_2))
+ self.assertEqual(1, len(err_3))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/workflow.py b/web_console_v2/api/fedlearner_webconsole/utils/workflow.py
new file mode 100644
index 000000000..19e426854
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/workflow.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import Generator, List
+
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+
+
+def build_job_name(workflow_uuid: str, job_def_name: str) -> str:
+ return f'{workflow_uuid}-{job_def_name}'
+
+
+def zip_workflow_variables(config: WorkflowDefinition) -> Generator[Variable, None, None]:
+ for v in config.variables:
+ yield v
+ for job in config.job_definitions:
+ for v in job.variables:
+ yield v
+
+
+def fill_variables(config: WorkflowDefinition,
+ variables: List[Variable],
+ *,
+ dry_run: bool = False) -> WorkflowDefinition:
+ variables_mapper = {v.name: v for v in variables}
+ for slot_variable in zip_workflow_variables(config):
+ variable = variables_mapper.get(slot_variable.name)
+ if variable is None:
+ continue
+ if variable.value_type != slot_variable.value_type:
+ raise TypeError(f'unmatched variable type! {variable.value_type} != {slot_variable.value_type}')
+ if dry_run:
+ continue
+ slot_variable.typed_value.MergeFrom(variable.typed_value)
+ slot_variable.value = variable.value
+
+ return config
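+
+
+# Usage sketch (mirrors workflow_test.py): fill_variables(config, variables) copies each
+# provided variable's value into the workflow (or job) variable with the same name,
+# raising TypeError on a value_type mismatch; dry_run=True validates without mutating.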
diff --git a/web_console_v2/api/fedlearner_webconsole/utils/workflow_test.py b/web_console_v2/api/fedlearner_webconsole/utils/workflow_test.py
new file mode 100644
index 000000000..9df21222a
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/utils/workflow_test.py
@@ -0,0 +1,111 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import unittest
+
+from google.protobuf.struct_pb2 import Value
+
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition, WorkflowDefinition
+from fedlearner_webconsole.utils.workflow import build_job_name, fill_variables, zip_workflow_variables
+
+
+class UtilsTest(unittest.TestCase):
+
+ def test_build_job_name(self):
+ self.assertEqual(build_job_name('uuid', 'job_name'), 'uuid-job_name')
+
+ def test_zip_workflow_variables(self):
+ config = WorkflowDefinition(
+ variables=[
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ Variable(name='hello', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=1))
+ ],
+ job_definitions=[
+ JobDefinition(variables=[
+ Variable(
+ name='hello_from_job', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=3))
+ ])
+ ])
+ self.assertEqual(sum(1 for v in zip_workflow_variables(config)), 3)
+
+ def test_fill_variables(self):
+ config = WorkflowDefinition(
+ variables=[
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ Variable(name='hello', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=1))
+ ],
+ job_definitions=[
+ JobDefinition(variables=[
+ Variable(
+ name='hello_from_job', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=3))
+ ])
+ ])
+ variables = [
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='new_test_value'))
+ ]
+ config = fill_variables(config, variables)
+ self.assertEqual(config.variables[0].typed_value.string_value, 'new_test_value')
+
+ def test_fill_variables_invalid(self):
+ config = WorkflowDefinition(
+ variables=[
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ Variable(name='hello', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=1))
+ ],
+ job_definitions=[
+ JobDefinition(variables=[
+ Variable(
+ name='hello_from_job', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=3))
+ ])
+ ])
+ variables = [Variable(name='test', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=2))]
+ with self.assertRaises(TypeError):
+ fill_variables(config, variables)
+
+ def test_fill_variables_dry_run(self):
+ config = WorkflowDefinition(
+ variables=[
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='test_value')),
+ Variable(name='hello', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=1))
+ ],
+ job_definitions=[
+ JobDefinition(variables=[
+ Variable(
+ name='hello_from_job', value_type=Variable.ValueType.NUMBER, typed_value=Value(number_value=3))
+ ])
+ ])
+ variables = [
+ Variable(name='test',
+ value_type=Variable.ValueType.STRING,
+ typed_value=Value(string_value='new_test_value'))
+ ]
+ config = fill_variables(config, variables, dry_run=True)
+ self.assertEqual(config.variables[0].typed_value.string_value, 'test_value')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/workflow/BUILD.bazel
new file mode 100644
index 000000000..98dda7ae2
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/BUILD.bazel
@@ -0,0 +1,342 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "utils_test",
+ size = "small",
+ srcs = [
+ "utils_test.py",
+ ],
+ imports = ["../.."],
+ main = "utils_test.py",
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:mixins_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "resource_manager_lib",
+ srcs = ["resource_manager.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_library(
+ name = "service_lib",
+ srcs = ["service.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:composer_service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:yaml_formatter_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:metrics_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:resource_name_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:workflow_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "service_test",
+ size = "small",
+ srcs = [
+ "service_test.py",
+ ],
+ imports = ["../.."],
+ main = "service_test.py",
+ deps = [
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "workflow_controller_lib",
+ srcs = ["workflow_controller.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole/notification:notification_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "workflow_controller_test",
+ size = "small",
+ srcs = [
+ "workflow_controller_test.py",
+ ],
+ imports = ["../.."],
+ main = "workflow_controller_test.py",
+ deps = [
+ ":workflow_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/notification:notification_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "workflow_job_controller_lib",
+ srcs = [
+ "workflow_job_controller.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":workflow_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/two_pc:transaction_manager_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "workflow_job_controller_test",
+ size = "small",
+ srcs = [
+ "workflow_job_controller_test.py",
+ ],
+ imports = ["../.."],
+ main = "workflow_job_controller_test.py",
+ deps = [
+ ":workflow_job_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "cronjob_lib",
+ srcs = [
+ "cronjob.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":workflow_job_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "cronjob_test",
+ size = "small",
+ srcs = [
+ "cronjob_test.py",
+ ],
+ imports = ["../.."],
+ main = "cronjob_test.py",
+ deps = [
+ ":cronjob_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:fake_lib",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "workflow_scheduler_lib",
+ srcs = [
+ "workflow_scheduler.py",
+ ],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ ":workflow_job_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:const_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "workflow_scheduler_test",
+ size = "small",
+ srcs = [
+ "workflow_scheduler_test.py",
+ ],
+ imports = ["../.."],
+ main = "workflow_scheduler_test.py",
+ deps = [
+ ":workflow_scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:initial_db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ ":workflow_job_controller_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:iam_required_lib",
+ "//web_console_v2/api/fedlearner_webconsole/iam:permission_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:services_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc:client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/scheduler:scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:service_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ data = [
+ "//web_console_v2/api/testing/test_data",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/composer:common_lib",
+ "//web_console_v2/api/fedlearner_webconsole/dataset:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/job:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/participant:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/project:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/scheduler:scheduler_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/apis.py b/web_console_v2/api/fedlearner_webconsole/workflow/apis.py
index abc5c19a5..5e7012d59 100644
--- a/web_console_v2/api/fedlearner_webconsole/workflow/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,423 +14,698 @@
# pylint: disable=global-statement
# coding: utf-8
-import logging
import json
-from uuid import uuid4
+import logging
from http import HTTPStatus
-from flask_restful import Resource, reqparse, request
+from typing import Optional, List
+
+from flask_restful import Resource
from google.protobuf.json_format import MessageToDict
-from fedlearner_webconsole.composer.models import ItemStatus
-from fedlearner_webconsole.utils.decorators import jwt_required
-from fedlearner_webconsole.workflow.models import (
- Workflow, WorkflowState, TransactionState
-)
-from fedlearner_webconsole.job.yaml_formatter import generate_job_run_yaml
-from fedlearner_webconsole.proto import common_pb2
-from fedlearner_webconsole.workflow_template.apis import \
- dict_to_workflow_definition
+from sqlalchemy.orm import Session
+from marshmallow import Schema, fields, validate, post_load
+
+from fedlearner_webconsole.audit.decorators import emits_event
from fedlearner_webconsole.db import db
-from fedlearner_webconsole.exceptions import (
- NotFoundException, ResourceConflictException, InvalidArgumentException,
- InternalException, NoAccessException)
-from fedlearner_webconsole.scheduler.scheduler import scheduler
+from fedlearner_webconsole.exceptions import (NotFoundException, InvalidArgumentException, InternalException)
+from fedlearner_webconsole.iam.permission import Permission
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
from fedlearner_webconsole.rpc.client import RpcClient
-from fedlearner_webconsole.composer.composer import composer
-from fedlearner_webconsole.workflow.cronjob import WorkflowCronJobItem
-from fedlearner_webconsole.utils.metrics import emit_counter
+from fedlearner_webconsole.scheduler.scheduler import scheduler
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.utils.decorators.pp_flask import input_validator, use_kwargs
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.flask_utils import download_json, get_current_user, make_flask_response, FilterExpField
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.workflow.service import WorkflowService, \
+ ForkWorkflowParams, CreateNewWorkflowParams
+from fedlearner_webconsole.workflow_template.service import \
+ dict_to_workflow_definition
+
+from fedlearner_webconsole.iam.iam_required import iam_required
+from fedlearner_webconsole.workflow.workflow_job_controller import start_workflow, stop_workflow, \
+ invalidate_workflow_job
+from fedlearner_webconsole.proto.audit_pb2 import Event
-def _get_workflow(workflow_id) -> Workflow:
- result = Workflow.query.filter_by(id=workflow_id).first()
- if result is None:
+def _get_workflow(workflow_id: int, project_id: int, session: Session) -> Workflow:
+ workflow_query = session.query(Workflow)
+ # project_id 0 means search in all projects
+ if project_id != 0:
+ workflow_query = workflow_query.filter_by(project_id=project_id)
+ workflow = workflow_query.filter_by(id=workflow_id).first()
+ if workflow is None:
raise NotFoundException()
- return result
-
-def start_or_stop_cronjob(batch_update_interval: int, workflow: Workflow):
- """start a cronjob for workflow if batch_update_interval is valid
-
- Args:
- batch_update_interval (int): restart workflow interval, unit is minutes
-
- Returns:
- raise when workflow is_left is False
- """
- item_name = f'workflow_cron_job_{workflow.id}'
- batch_update_interval = batch_update_interval * 60
- if workflow.get_config().is_left and batch_update_interval > 0:
- status = composer.get_item_status(name=item_name)
- # create a cronjob
- if not status:
- composer.collect(name=item_name,
- items=[WorkflowCronJobItem(workflow.id)],
- metadata={},
- interval=batch_update_interval)
- return
- if status == ItemStatus.OFF:
- raise InvalidArgumentException(
- f'cannot set item [{item_name}], since item is off')
- # patch a cronjob
- try:
- composer.patch_item_attr(name=item_name,
- key='interval_time',
- value=batch_update_interval)
- except ValueError as err:
- raise InvalidArgumentException(details=repr(err))
-
-
- elif batch_update_interval < 0:
- composer.finish(name=item_name)
- elif not workflow.get_config().is_left:
- raise InvalidArgumentException('Only left can operate this')
- else:
- logging.info('skip cronjob since batch_update_interval is -1')
-
-def is_peer_job_inheritance_matched(workflow):
- # TODO: Move it to workflow service
- if workflow.forked_from is None:
- return True
- job_flags = workflow.get_create_job_flags()
- peer_job_flags = workflow.get_peer_create_job_flags()
- job_defs = workflow.get_config().job_definitions
- project = workflow.project
- if project is None:
- return True
- project_config = project.get_config()
- # TODO: Fix for multi-peer
- client = RpcClient(project_config, project_config.participants[0])
- parent_workflow = db.session.query(Workflow).get(workflow.forked_from)
- resp = client.get_workflow(parent_workflow.name)
- if resp.status.code != common_pb2.STATUS_SUCCESS:
- emit_counter('get_workflow_failed', 1)
- raise InternalException(resp.status.msg)
- peer_job_defs = resp.config.job_definitions
- for i, job_def in enumerate(job_defs):
- if job_def.is_federated:
- for j, peer_job_def in enumerate(peer_job_defs):
- if job_def.name == peer_job_def.name:
- if job_flags[i] != peer_job_flags[j]:
- return False
- return True
+ return workflow
+
+
+class GetWorkflowsParameter(Schema):
+ keyword = fields.String(required=False, load_default=None)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
+ states = fields.List(fields.String(required=False,
+ validate=validate.OneOf([
+ 'completed', 'failed', 'stopped', 'running', 'warmup', 'pending', 'ready',
+ 'configuring', 'invalid'
+ ])),
+ required=False,
+ load_default=None)
+ favour = fields.Integer(required=False, load_default=None, validate=validate.OneOf([0, 1]))
+ uuid = fields.String(required=False, load_default=None)
+ name = fields.String(required=False, load_default=None)
+ template_revision_id = fields.Integer(required=False, load_default=None)
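+ # Exposed to clients as the "filter" query parameter (via data_key) and parsed into a FilterExpression.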
+ filter_exp = FilterExpField(data_key='filter', required=False, load_default=None)
+
+
+class PostWorkflowsParameter(Schema):
+ name = fields.Str(required=True)
+ config = fields.Dict(required=True)
+ template_id = fields.Int(required=False, load_default=None)
+ template_revision_id = fields.Int(required=False, load_default=None)
+ forkable = fields.Bool(required=True)
+ forked_from = fields.Int(required=False, load_default=None)
+ create_job_flags = fields.List(required=False, load_default=None, cls_or_instance=fields.Int)
+ peer_create_job_flags = fields.List(required=False, load_default=None, cls_or_instance=fields.Int)
+ fork_proposal_config = fields.Dict(required=False, load_default=None)
+ comment = fields.Str(required=False, load_default=None)
+ cron_config = fields.Str(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['config'] = dict_to_workflow_definition(data['config'])
+ data['fork_proposal_config'] = dict_to_workflow_definition(data['fork_proposal_config'])
+ return data
+
+
+class PutWorkflowParameter(Schema):
+ config = fields.Dict(required=True)
+ template_id = fields.Integer(required=False, load_default=None)
+ template_revision_id = fields.Integer(required=False, load_default=None)
+ forkable = fields.Boolean(required=True)
+ create_job_flags = fields.List(required=False, load_default=None, cls_or_instance=fields.Integer)
+ comment = fields.String(required=False, load_default=None)
+ cron_config = fields.String(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['config'] = dict_to_workflow_definition(data['config'])
+ return data
+
+
+class PatchWorkflowParameter(Schema):
+ config = fields.Dict(required=False, load_default=None)
+ template_id = fields.Integer(required=False, load_default=None)
+ template_revision_id = fields.Integer(required=False, load_default=None)
+ forkable = fields.Boolean(required=False, load_default=None)
+ create_job_flags = fields.List(required=False, load_default=None, cls_or_instance=fields.Integer)
+ cron_config = fields.String(required=False, load_default=None)
+ favour = fields.Boolean(required=False, load_default=None)
+ metric_is_public = fields.Boolean(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['config'] = data['config'] and dict_to_workflow_definition(data['config'])
+ return data
+
+
+class PatchPeerWorkflowParameter(Schema):
+ config = fields.Dict(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['config'] = data['config'] and dict_to_workflow_definition(data['config'])
+ return data
-class WorkflowsApi(Resource):
- @jwt_required()
- def get(self):
- result = Workflow.query
- if 'project' in request.args and request.args['project'] is not None:
- project_id = request.args['project']
- result = result.filter_by(project_id=project_id)
- if 'keyword' in request.args and request.args['keyword'] is not None:
- keyword = request.args['keyword']
- result = result.filter(Workflow.name.like(
- '%{}%'.format(keyword)))
- if 'uuid' in request.args and request.args['uuid'] is not None:
- uuid = request.args['uuid']
- result = result.filter_by(uuid=uuid)
- res = []
- for row in result.order_by(Workflow.created_at.desc()).all():
- try:
- wf_dict = row.to_dict()
- except Exception as e: # pylint: disable=broad-except
- wf_dict = {
- 'id': row.id,
- 'name': row.name,
- 'uuid': row.uuid,
- 'error': f'Failed to get workflow state {repr(e)}'
- }
- res.append(wf_dict)
- return {'data': res}, HTTPStatus.OK
-
- @jwt_required()
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('name', required=True, help='name is empty')
- parser.add_argument('project_id', type=int, required=True,
- help='project_id is empty')
- # TODO: should verify if the config is compatible with
- # workflow template
- parser.add_argument('config', type=dict, required=True,
- help='config is empty')
- parser.add_argument('forkable', type=bool, required=True,
- help='forkable is empty')
- parser.add_argument('forked_from', type=int, required=False,
- help='fork from base workflow')
- parser.add_argument('create_job_flags', type=list, required=False,
- location='json',
- help='flags in common.CreateJobFlag')
- parser.add_argument('peer_create_job_flags', type=list,
- required=False, location='json',
- help='peer flags in common.CreateJobFlag')
- parser.add_argument('fork_proposal_config', type=dict, required=False,
- help='fork and edit peer config')
- parser.add_argument('batch_update_interval',
- type=int,
- required=False,
- help='interval for workflow cronjob in minute')
- parser.add_argument('extra',
- type=str,
- required=False,
- help='extra json string that needs send to peer')
-
- parser.add_argument('comment')
- data = parser.parse_args()
- name = data['name']
- if Workflow.query.filter_by(name=name).first() is not None:
- raise ResourceConflictException(
- 'Workflow {} already exists.'.format(name))
-
- # form to proto buffer
- template_proto = dict_to_workflow_definition(data['config'])
- workflow = Workflow(name=name,
- # 20 bytes
- # a DNS-1035 label must start with an
- # alphabetic character. substring uuid[:19] has
- # no collision in 10 million draws
- uuid=f'u{uuid4().hex[:19]}',
- comment=data['comment'],
- project_id=data['project_id'],
- forkable=data['forkable'],
- forked_from=data['forked_from'],
- state=WorkflowState.NEW,
- target_state=WorkflowState.READY,
- transaction_state=TransactionState.READY,
- extra=data['extra']
- )
- workflow.set_config(template_proto)
- workflow.set_create_job_flags(data['create_job_flags'])
-
- if workflow.forked_from is not None:
- fork_config = dict_to_workflow_definition(
- data['fork_proposal_config'])
- # TODO: more validations
- if len(fork_config.job_definitions) != \
- len(template_proto.job_definitions):
- raise InvalidArgumentException(
- 'Forked workflow\'s template does not match base workflow')
- workflow.set_fork_proposal_config(fork_config)
- workflow.set_peer_create_job_flags(data['peer_create_job_flags'])
- if not is_peer_job_inheritance_matched(workflow):
- raise InvalidArgumentException('Forked workflow has federated \
- job with unmatched inheritance')
-
- db.session.add(workflow)
- db.session.commit()
- logging.info('Inserted a workflow to db')
- scheduler.wakeup(workflow.id)
-
- # start cronjob every interval time
- # should start after inserting to db
- batch_update_interval = data['batch_update_interval']
- if batch_update_interval:
- start_or_stop_cronjob(batch_update_interval, workflow)
-
- return {'data': workflow.to_dict()}, HTTPStatus.CREATED
+class WorkflowsApi(Resource):
-class WorkflowApi(Resource):
- @jwt_required()
- def get(self, workflow_id):
- workflow = _get_workflow(workflow_id)
- result = workflow.to_dict()
- result['jobs'] = [job.to_dict() for job in workflow.get_jobs()]
- result['owned_jobs'] = [job.to_dict() for job in workflow.owned_jobs]
- result['config'] = None
- if workflow.get_config() is not None:
- result['config'] = MessageToDict(
- workflow.get_config(),
- preserving_proto_field_name=True,
- including_default_value_fields=True)
- return {'data': result}, HTTPStatus.OK
-
- @jwt_required()
- def put(self, workflow_id):
- parser = reqparse.RequestParser()
- parser.add_argument('config', type=dict, required=True,
- help='config is empty')
- parser.add_argument('forkable', type=bool, required=True,
- help='forkable is empty')
- parser.add_argument('create_job_flags', type=list, required=False,
- location='json',
- help='flags in common.CreateJobFlag')
- parser.add_argument(
- 'batch_update_interval',
- type=int,
- required=False,
- help='interval time for cronjob of workflow in minute')
- parser.add_argument('comment')
- data = parser.parse_args()
-
- workflow = _get_workflow(workflow_id)
- if workflow.config:
- raise ResourceConflictException(
- 'Resetting workflow is not allowed')
-
- batch_update_interval = data['batch_update_interval']
- if batch_update_interval:
- start_or_stop_cronjob(batch_update_interval, workflow)
-
- workflow.comment = data['comment']
- workflow.forkable = data['forkable']
- workflow.set_config(dict_to_workflow_definition(data['config']))
- workflow.set_create_job_flags(data['create_job_flags'])
- workflow.update_target_state(WorkflowState.READY)
- db.session.commit()
- scheduler.wakeup(workflow_id)
- logging.info('update workflow %d target_state to %s',
- workflow.id, workflow.target_state)
- return {'data': workflow.to_dict()}, HTTPStatus.OK
-
- @jwt_required()
- def patch(self, workflow_id):
- parser = reqparse.RequestParser()
- parser.add_argument('target_state', type=str, required=False,
- default=None, help='target_state is empty')
- parser.add_argument('state',
- type=str,
- required=False,
- help='state is empty')
- parser.add_argument('forkable', type=bool)
- parser.add_argument('metric_is_public', type=bool)
- parser.add_argument('config',
- type=dict,
- required=False,
- help='updated config')
- parser.add_argument('create_job_flags', type=list, required=False,
- location='json',
- help='flags in common.CreateJobFlag')
- parser.add_argument('batch_update_interval',
- type=int,
- required=False,
- help='interval for restart workflow in minute')
- data = parser.parse_args()
-
- workflow = _get_workflow(workflow_id)
-
- # start workflow every interval time
- batch_update_interval = data['batch_update_interval']
- if batch_update_interval:
- start_or_stop_cronjob(batch_update_interval, workflow)
-
- forkable = data['forkable']
- if forkable is not None:
- workflow.forkable = forkable
- db.session.flush()
-
- metric_is_public = data['metric_is_public']
- if metric_is_public is not None:
- workflow.metric_is_public = metric_is_public
- db.session.flush()
-
- target_state = data['target_state']
- if target_state:
+ @credentials_required
+ @use_kwargs(GetWorkflowsParameter(), location='query')
+ def get(
+ self,
+ page: Optional[int],
+ page_size: Optional[int],
+ name: Optional[str],
+ uuid: Optional[str],
+ keyword: Optional[str],
+ favour: Optional[bool],
+ template_revision_id: Optional[int],
+ states: Optional[List[str]],
+ filter_exp: Optional[FilterExpression],
+ project_id: int,
+ ):
+ """Get workflows.
+ ---
+ tags:
+ - workflow
+ description: Get workflows.
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the project. 0 means get all workflows.
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ - in: query
+ name: name
+ schema:
+ type: string
+ - in: query
+ name: uuid
+ schema:
+ type: string
+ - in: query
+ name: keyword
+ schema:
+ type: string
+ - in: query
+ name: favour
+ schema:
+ type: boolean
+ - in: query
+ name: template_revision_id
+ schema:
+ type: integer
+ - in: query
+ name: states
+ schema:
+ type: array
+ collectionFormat: multi
+ items:
+ type: string
+ enum: [completed, failed, stopped, running, warmup, pending, ready, configuring, invalid]
+ - in: query
+ name: filter
+ schema:
+ type: string
+ responses:
+ 200:
+ description: list of workflows.
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowRef'
+ """
+ with db.session_scope() as session:
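+            # Each optional query parameter below narrows the base query; omitted
+            # parameters leave the query unchanged, so the filters compose with AND.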
+ result = session.query(Workflow)
+ if project_id != 0:
+ result = result.filter_by(project_id=project_id)
+ if name is not None:
+ result = result.filter_by(name=name)
+ if keyword is not None:
+ result = result.filter(Workflow.name.like(f'%{keyword}%'))
+ if uuid is not None:
+ result = result.filter_by(uuid=uuid)
+ if favour is not None:
+ result = result.filter_by(favour=favour)
+ if states is not None:
+ result = WorkflowService.filter_workflows(result, states)
+ if template_revision_id is not None:
+ result = result.filter_by(template_revision_id=template_revision_id)
+ if filter_exp is not None:
+ result = WorkflowService(session).build_filter_query(result, filter_exp)
+ result = result.order_by(Workflow.id.desc())
+ pagination = paginate(result, page, page_size)
+ res = []
+ for item in pagination.get_items():
+ try:
+ wf_dict = to_dict(item.to_workflow_ref())
+ except Exception as e: # pylint: disable=broad-except
+ wf_dict = {
+ 'id': item.id,
+ 'name': item.name,
+ 'uuid': item.uuid,
+ 'error': f'Failed to get workflow state {repr(e)}'
+ }
+ res.append(wf_dict)
+            # Work around MySQL 8's issue that SELECT COUNT(*) over the whole table is very slow
+            # https://bugs.mysql.com/bug.php?id=97709
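+            # (Adding a trivially-true predicate presumably lets the count go through
+            # the primary-key index instead of a bare full-table COUNT(*).)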
+ pagination.query = pagination.query.filter(Workflow.id > -1)
+ page_meta = pagination.get_metadata()
+ return make_flask_response(data=res, page_meta=page_meta)
+
+ @input_validator
+ @credentials_required
+ @iam_required(Permission.WORKFLOWS_POST)
+ @emits_event(resource_type=Event.ResourceType.WORKFLOW, audit_fields=['forkable'])
+ @use_kwargs(PostWorkflowsParameter(), location='json')
+ def post(
+ self,
+ name: str,
+ comment: Optional[str],
+ forkable: bool,
+            forked_from: Optional[int],
+ create_job_flags: Optional[List[int]],
+ peer_create_job_flags: Optional[List[int]],
+ # Peer config
+ fork_proposal_config: Optional[WorkflowDefinition],
+ template_id: Optional[int],
+ config: WorkflowDefinition,
+ cron_config: Optional[str],
+ template_revision_id: Optional[int],
+ project_id: int):
+ """Create workflows.
+ ---
+ tags:
+ - workflow
+        description: Create workflows.
+ parameters:
+ - in: path
+ description: The ID of the project.
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/PostWorkflowsParameter'
+ responses:
+ 201:
+ description: detail of workflows.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
+ with db.session_scope() as session:
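+            # Forking reuses the parent workflow via fork_from_id together with the
+            # peer-side proposal config; otherwise a brand-new workflow is created
+            # from the given template (and optional revision) under this project.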
+ if forked_from:
+ params = ForkWorkflowParams(fork_from_id=forked_from,
+ fork_proposal_config=fork_proposal_config,
+ peer_create_job_flags=peer_create_job_flags)
+ else:
+ params = CreateNewWorkflowParams(project_id=project_id,
+ template_id=template_id,
+ template_revision_id=template_revision_id)
try:
- if WorkflowState[target_state] == WorkflowState.RUNNING:
- for job in workflow.owned_jobs:
- try:
- generate_job_run_yaml(job)
- # TODO: check if peer variables is valid
- except Exception as e: # pylint: disable=broad-except
- raise ValueError(
- f'Invalid Variable when try '
- f'to format the job {job.name}:{str(e)}')
- workflow.update_target_state(WorkflowState[target_state])
- db.session.flush()
- logging.info('updated workflow %d target_state to %s',
- workflow.id, workflow.target_state)
+ workflow = WorkflowService(session).create_workflow(name=name,
+ comment=comment,
+ forkable=forkable,
+ config=config,
+ create_job_flags=create_job_flags,
+ cron_config=cron_config,
+ params=params,
+ creator_username=get_current_user().username)
except ValueError as e:
raise InvalidArgumentException(details=str(e)) from e
+ session.commit()
+ logging.info('Inserted a workflow to db')
+ scheduler.wakeup(workflow.id)
+ return make_flask_response(data=workflow.to_proto(), status=HTTPStatus.CREATED)
+
- state = data['state']
- if state:
+class WorkflowApi(Resource):
+
+ @credentials_required
+ @use_kwargs({'download': fields.Bool(required=False, load_default=False)}, location='query')
+ def get(self, download: Optional[bool], project_id: int, workflow_id: int):
+ """Get workflow and with jobs.
+ ---
+ tags:
+ - workflow
+ description: Get workflow.
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+          description: The ID of the project (not used to look up the workflow).
+ - in: path
+ name: workflow_id
+ schema:
+ type: integer
+ required: true
+ - in: query
+ name: download
+ schema:
+ type: boolean
+ responses:
+ 200:
+ description: detail of workflow.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
+ del project_id
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ if workflow is None:
+ raise NotFoundException(f'workflow {workflow_id} is not found')
+ result = workflow.to_proto()
+ result.jobs.extend([job.to_proto() for job in workflow.get_jobs(session)])
+ if download:
+ return download_json(content=to_dict(result), filename=workflow.name)
+ return make_flask_response(data=result)
+
+ @credentials_required
+ @iam_required(Permission.WORKFLOW_PUT)
+ @emits_event(resource_type=Event.ResourceType.WORKFLOW, audit_fields=['forkable'])
+ @use_kwargs(PutWorkflowParameter(), location='json')
+ def put(self, config: WorkflowDefinition, template_id: Optional[int], forkable: bool,
+ create_job_flags: Optional[List[int]], cron_config: Optional[str], comment: Optional[str],
+ template_revision_id: Optional[int], project_id: int, workflow_id: int):
+ """Config workflow.
+ ---
+ tags:
+ - workflow
+        description: Configure workflow.
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the project.
+ - in: path
+ name: workflow_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the workflow.
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/PutWorkflowParameter'
+ responses:
+ 200:
+ description: detail of workflow.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
+ with db.session_scope() as session:
+ workflow = _get_workflow(workflow_id, project_id, session)
try:
- assert state == 'INVALID', \
- 'Can only set state to INVALID for invalidation'
- workflow.invalidate()
- db.session.flush()
- logging.info('invalidate workflow %d', workflow.id)
+ WorkflowService(session).config_workflow(workflow=workflow,
+ template_id=template_id,
+ config=config,
+ forkable=forkable,
+ comment=comment,
+ cron_config=cron_config,
+ create_job_flags=create_job_flags,
+ creator_username=get_current_user().username,
+ template_revision_id=template_revision_id)
except ValueError as e:
raise InvalidArgumentException(details=str(e)) from e
-
- config = data['config']
- if config:
+ session.commit()
+ scheduler.wakeup(workflow_id)
+ logging.info('update workflow %d target_state to %s', workflow.id, workflow.target_state)
+ return make_flask_response(data=workflow.to_proto())
+
+ @input_validator
+ @credentials_required
+ @iam_required(Permission.WORKFLOW_PATCH)
+ @emits_event(resource_type=Event.ResourceType.WORKFLOW, audit_fields=['forkable', 'metric_is_public'])
+ @use_kwargs(PatchWorkflowParameter(), location='json')
+ def patch(self, forkable: Optional[bool], metric_is_public: Optional[bool], config: Optional[WorkflowDefinition],
+ template_id: Optional[int], create_job_flags: Optional[List[int]], cron_config: Optional[str],
+ favour: Optional[bool], template_revision_id: Optional[int], project_id: int, workflow_id: int):
+ """Patch workflow.
+ ---
+ tags:
+ - workflow
+ description: Patch workflow.
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the project.
+ - in: path
+ name: workflow_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the workflow.
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/PatchWorkflowParameter'
+ responses:
+ 200:
+ description: detail of workflow.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
+ with db.session_scope() as session:
+ workflow = _get_workflow(workflow_id, project_id, session)
try:
- if workflow.target_state != WorkflowState.INVALID or \
- workflow.state not in \
- [WorkflowState.READY, WorkflowState.STOPPED]:
- raise NoAccessException('Cannot edit running workflow')
- config_proto = dict_to_workflow_definition(data['config'])
- workflow.set_config(config_proto)
- db.session.flush()
+ WorkflowService(session).patch_workflow(workflow=workflow,
+ forkable=forkable,
+ metric_is_public=metric_is_public,
+ config=config,
+ template_id=template_id,
+ create_job_flags=create_job_flags,
+ cron_config=cron_config,
+ favour=favour,
+ template_revision_id=template_revision_id)
+ session.commit()
except ValueError as e:
raise InvalidArgumentException(details=str(e)) from e
-
- create_job_flags = data['create_job_flags']
- if create_job_flags:
- jobs = workflow.get_jobs()
- if len(create_job_flags) != len(jobs):
- raise InvalidArgumentException(
- details='Number of job defs does not match number '
- f'of create_job_flags {len(jobs)} '
- f'vs {len(create_job_flags)}')
- workflow.set_create_job_flags(create_job_flags)
- flags = workflow.get_create_job_flags()
- for i, job in enumerate(jobs):
- if job.workflow_id == workflow.id:
- job.is_disabled = flags[i] == \
- common_pb2.CreateJobFlag.DISABLED
-
- db.session.commit()
- scheduler.wakeup(workflow.id)
- return {'data': workflow.to_dict()}, HTTPStatus.OK
+ return make_flask_response(data=workflow.to_proto())
class PeerWorkflowsApi(Resource):
- @jwt_required()
- def get(self, workflow_id):
- workflow = _get_workflow(workflow_id)
- project_config = workflow.project.get_config()
+
+ @credentials_required
+ def get(self, project_id: int, workflow_id: int):
+ """Get peer workflow and with jobs.
+ ---
+ tags:
+ - workflow
+ description: Get peer workflow.
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ - in: path
+ name: workflow_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 200:
+ description: detail of workflow.
+ content:
+ application/json:
+ schema:
+ type: object
+ additionalProperties:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
peer_workflows = {}
- for party in project_config.participants:
- client = RpcClient(project_config, party)
- # TODO(xiangyxuan): use uuid to identify the workflow
- resp = client.get_workflow(workflow.name)
- if resp.status.code != common_pb2.STATUS_SUCCESS:
- raise InternalException(resp.status.msg)
- peer_workflow = MessageToDict(
- resp,
- preserving_proto_field_name=True,
- including_default_value_fields=True)
- for job in peer_workflow['jobs']:
- if 'pods' in job:
- job['pods'] = json.loads(job['pods'])
- peer_workflows[party.name] = peer_workflow
- return {'data': peer_workflows}, HTTPStatus.OK
-
- @jwt_required()
- def patch(self, workflow_id):
- parser = reqparse.RequestParser()
- parser.add_argument('config', type=dict, required=True,
- help='new config for peer')
- data = parser.parse_args()
- config_proto = dict_to_workflow_definition(data['config'])
-
- workflow = _get_workflow(workflow_id)
- project_config = workflow.project.get_config()
+ with db.session_scope() as session:
+ workflow = _get_workflow(workflow_id, project_id, session)
+ service = ParticipantService(session)
+ participants = service.get_platform_participants_by_project(workflow.project.id)
+
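+            # Query every platform participant of the project over RPC and collect
+            # their view of this workflow, keyed by participant name.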
+ for participant in participants:
+ client = RpcClient.from_project_and_participant(workflow.project.name, workflow.project.token,
+ participant.domain_name)
+ # TODO(xiangyxuan): use uuid to identify the workflow
+ resp = client.get_workflow(workflow.uuid, workflow.name)
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ raise InternalException(resp.status.msg)
+ peer_workflow = MessageToDict(resp,
+ preserving_proto_field_name=True,
+ including_default_value_fields=True)
+ for job in peer_workflow['jobs']:
+ if 'pods' in job:
+ job['pods'] = json.loads(job['pods'])
+ peer_workflows[participant.name] = peer_workflow
+ return make_flask_response(peer_workflows)
+
+ @credentials_required
+ @iam_required(Permission.WORKFLOW_PATCH)
+ @use_kwargs(PatchPeerWorkflowParameter(), location='json')
+ def patch(self, config: WorkflowDefinition, project_id: int, workflow_id: int):
+ """Patch peer workflow.
+ ---
+ tags:
+ - workflow
+        description: Patch peer workflow.
+ parameters:
+ - in: path
+ name: project_id
+ required: true
+ schema:
+ type: integer
+ - in: path
+ name: workflow_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/PatchPeerWorkflowParameter'
+ responses:
+ 200:
+ description: detail of workflow.
+ content:
+ application/json:
+ schema:
+ type: object
+ additionalProperties:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
peer_workflows = {}
- for party in project_config.participants:
- client = RpcClient(project_config, party)
- resp = client.update_workflow(
- workflow.name, config_proto)
- if resp.status.code != common_pb2.STATUS_SUCCESS:
- raise InternalException(resp.status.msg)
- peer_workflows[party.name] = MessageToDict(
- resp,
- preserving_proto_field_name=True,
- including_default_value_fields=True)
- return {'data': peer_workflows}, HTTPStatus.OK
+ with db.session_scope() as session:
+ workflow = _get_workflow(workflow_id, project_id, session)
+ service = ParticipantService(session)
+ participants = service.get_platform_participants_by_project(workflow.project.id)
+ for participant in participants:
+ client = RpcClient.from_project_and_participant(workflow.project.name, workflow.project.token,
+ participant.domain_name)
+ resp = client.update_workflow(workflow.uuid, workflow.name, config)
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ raise InternalException(resp.status.msg)
+ peer_workflows[participant.name] = MessageToDict(resp,
+ preserving_proto_field_name=True,
+ including_default_value_fields=True)
+ return make_flask_response(peer_workflows)
+
+
+class WorkflowInvalidateApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.WORKFLOW, op_type=Event.OperationType.INVALIDATE)
+ def post(self, project_id: int, workflow_id: int):
+ """Invalidates the workflow job.
+ ---
+ tags:
+ - workflow
+ description: Invalidates the workflow job.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: workflow_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: Invalidated workflow
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
+ with db.session_scope() as session:
+ workflow = _get_workflow(workflow_id, project_id, session)
+ invalidate_workflow_job(session, workflow)
+ session.commit()
+ return make_flask_response(workflow.to_proto())
+
+
+class WorkflowStartApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.WORKFLOW, op_type=Event.OperationType.UPDATE)
+ def post(self, project_id: int, workflow_id: int):
+ """Starts the workflow job.
+ ---
+ tags:
+ - workflow
+ description: Starts the workflow job.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: workflow_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: Started workflow
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
+ start_workflow(workflow_id)
+ with db.session_scope() as session:
+ workflow = _get_workflow(workflow_id, project_id, session)
+ return make_flask_response(workflow.to_proto())
+
+
+class WorkflowStopApi(Resource):
+
+ @credentials_required
+ @emits_event(resource_type=Event.ResourceType.WORKFLOW, op_type=Event.OperationType.UPDATE)
+ def post(self, project_id: int, workflow_id: int):
+ """Stops the workflow job.
+ ---
+ tags:
+ - workflow
+ description: Stops the workflow job.
+ parameters:
+ - in: path
+ name: project_id
+ schema:
+ type: integer
+ - in: path
+ name: workflow_id
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: Stopped workflow
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowPb'
+ """
+ stop_workflow(workflow_id)
+ with db.session_scope() as session:
+ workflow = _get_workflow(workflow_id, project_id, session)
+ return make_flask_response(workflow.to_proto())
def initialize_workflow_apis(api):
-    api.add_resource(WorkflowsApi, '/workflows')
-    api.add_resource(WorkflowApi, '/workflows/<int:workflow_id>')
-    api.add_resource(PeerWorkflowsApi,
-                     '/workflows/<int:workflow_id>/peer_workflows')
+    api.add_resource(WorkflowsApi, '/projects/<int:project_id>/workflows')
+    api.add_resource(WorkflowApi, '/projects/<int:project_id>/workflows/<int:workflow_id>')
+    api.add_resource(PeerWorkflowsApi, '/projects/<int:project_id>/workflows/<int:workflow_id>/peer_workflows')
+    api.add_resource(WorkflowInvalidateApi, '/projects/<int:project_id>/workflows/<int:workflow_id>:invalidate')
+    api.add_resource(WorkflowStartApi, '/projects/<int:project_id>/workflows/<int:workflow_id>:start')
+    api.add_resource(WorkflowStopApi, '/projects/<int:project_id>/workflows/<int:workflow_id>:stop')
+
+ # if a schema is used, one has to append it to schema_manager so Swagger knows there is a schema available
+ schema_manager.append(PostWorkflowsParameter)
+ schema_manager.append(PutWorkflowParameter)
+ schema_manager.append(PatchWorkflowParameter)
+ schema_manager.append(PatchPeerWorkflowParameter)
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/apis_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/apis_test.py
new file mode 100644
index 000000000..27bbf363c
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/apis_test.py
@@ -0,0 +1,772 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import random
+import string
+import time
+import json
+import unittest
+import urllib.parse
+from http import HTTPStatus
+from pathlib import Path
+from unittest.mock import (patch, call)
+from google.protobuf.json_format import ParseDict
+from fedlearner_webconsole.utils.const import SYSTEM_WORKFLOW_CREATOR_USERNAME
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import ItemStatus
+from fedlearner_webconsole.dataset.models import Dataset, DatasetType
+from fedlearner_webconsole.participant.models import Participant, ProjectParticipant
+from fedlearner_webconsole.proto.composer_pb2 import WorkflowCronJobInput, RunnerInput
+from fedlearner_webconsole.proto.workflow_definition_pb2 import \
+ WorkflowDefinition
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.scheduler.transaction import TransactionState
+from fedlearner_webconsole.proto import (project_pb2, service_pb2, common_pb2)
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from testing.common import BaseTestCase
+
+
+class WorkflowsApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ self.maxDiff = None
+ super().setUp()
+ # Inserts data
+ template1 = WorkflowTemplate(name='t1', comment='comment for t1', group_alias='g1')
+ template1.set_config(WorkflowDefinition(group_alias='g1',))
+ workflow1 = Workflow(name='workflow_key_get1',
+ project_id=1,
+ state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY,
+ creator=SYSTEM_WORKFLOW_CREATOR_USERNAME,
+ favour=True)
+ workflow2 = Workflow(name='workflow_key_get2',
+ project_id=2,
+ state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.COORDINATOR_COMMITTABLE)
+ workflow3 = Workflow(name='workflow_key_get3', project_id=2)
+ workflow4 = Workflow(name='workflow_key_get4',
+ project_id=4,
+ state=WorkflowState.INVALID,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY,
+ favour=True)
+ project = Project(id=123, name='project_123')
+ dataset1 = Dataset(
+ name='default dataset1',
+ dataset_type=DatasetType.STREAMING,
+ comment='test comment1',
+ path='/data/dataset/123',
+ project_id=1,
+ )
+ with db.session_scope() as session:
+ session.add(project)
+ session.add(workflow1)
+ session.add(workflow2)
+ session.add(workflow3)
+ session.add(workflow4)
+ session.add(template1)
+ session.add(dataset1)
+ session.commit()
+
+ def test_get_with_name(self):
+ response = self.get_helper('/api/v2/projects/0/workflows?name=workflow_key_get3')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'workflow_key_get3')
+
+ def test_get_with_project(self):
+ response = self.get_helper('/api/v2/projects/1/workflows')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual(data[0]['name'], 'workflow_key_get1')
+
+ def test_get_with_keyword(self):
+ response = self.get_helper('/api/v2/projects/0/workflows?keyword=key')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 4)
+ self.assertEqual(data[0]['name'], 'workflow_key_get4')
+
+ def test_get_with_states(self):
+ response = self.get_helper('/api/v2/projects/0/workflows?states=configuring&states=ready')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual(data[0]['name'], 'workflow_key_get2')
+
+ def test_get_with_state_invalid(self):
+ response = self.get_helper('/api/v2/projects/0/workflows?states=invalid')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual('workflow_key_get4', data[0]['name'])
+
+ def test_get_with_favour(self):
+ response = self.get_helper('/api/v2/projects/0/workflows?favour=1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 2)
+ self.assertEqual('workflow_key_get4', data[0]['name'])
+
+ def test_get_with_filter(self):
+ filter_exp = urllib.parse.quote('(system=true)')
+ response = self.get_helper(f'/api/v2/projects/0/workflows?filter={filter_exp}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 1)
+ self.assertEqual('workflow_key_get1', data[0]['name'])
+ filter_exp = urllib.parse.quote('(system=false)')
+ response = self.get_helper(f'/api/v2/projects/0/workflows?filter={filter_exp}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(len(data), 3)
+
+ def test_get_workflows(self):
+        # Sleep 1 second so the new workflow's created_at is strictly greater
+ time.sleep(1)
+ workflow = Workflow(name='last', project_id=1)
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.commit()
+ response = self.get_helper('/api/v2/projects/0/workflows')
+ data = self.get_response_data(response)
+ self.assertEqual(data[0]['name'], 'last')
+
+ @patch('fedlearner_webconsole.workflow.apis.scheduler.wakeup')
+ @patch('fedlearner_webconsole.workflow.service.resource_uuid')
+ def test_create_new_workflow(self, mock_resource_uuid, mock_wakeup):
+ mock_resource_uuid.return_value = 'test-uuid'
+ with open(Path(__file__, '../../../testing/test_data/workflow_config.json').resolve(),
+ encoding='utf-8') as workflow_config:
+ config = json.load(workflow_config)
+ # TODO(hangweiqiang): remove this in workflow test
+ extra = ''.join(random.choice(string.ascii_lowercase) for _ in range(10))
+ # extra should be a valid json string so we mock one
+ extra = f'{{"parent_job_name":"{extra}"}}'
+
+ local_extra = ''.join(random.choice(string.ascii_lowercase) for _ in range(10))
+ # local_extra should be a valid json string so we mock one
+ local_extra = f'{{"model_desc":"{local_extra}"}}'
+
+ workflow = {
+ 'name': 'test-workflow',
+ 'project_id': 1234567,
+ 'forkable': True,
+ 'comment': 'test-comment',
+ 'config': config,
+ 'extra': extra,
+ 'local_extra': local_extra,
+ 'template_id': 1,
+ 'template_revision_id': 1
+ }
+ response = self.post_helper('/api/v2/projects/1234567/workflows', data=workflow)
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ created_workflow = self.get_response_data(response)
+ # Check scheduler
+ mock_wakeup.assert_called_once_with(created_workflow['id'])
+ self.assertIsNotNone(created_workflow['id'])
+ self.assertIsNotNone(created_workflow['created_at'])
+ self.assertIsNotNone(created_workflow['updated_at'])
+ self.assertResponseDataEqual(response, {
+ 'cron_config': '',
+ 'name': 'test-workflow',
+ 'project_id': 1234567,
+ 'forkable': True,
+ 'forked_from': 0,
+ 'is_local': False,
+ 'metric_is_public': False,
+ 'comment': 'test-comment',
+ 'state': 'PARTICIPANT_CONFIGURING',
+ 'create_job_flags': [1, 1, 1],
+ 'peer_create_job_flags': [],
+ 'job_ids': [1, 2, 3],
+ 'uuid': 'test-uuid',
+ 'template_revision_id': 1,
+ 'template_id': 1,
+ 'creator': 'ada',
+ 'favour': False,
+ 'jobs': []
+ },
+ ignore_fields=[
+ 'id', 'created_at', 'updated_at', 'start_at', 'stop_at', 'config',
+ 'editor_info', 'template_info'
+ ])
+ # Check DB
+ with db.session_scope() as session:
+ self.assertEqual(len(session.query(Workflow).all()), 5)
+
+ # Post again
+ mock_wakeup.reset_mock()
+ response = self.post_helper('/api/v2/projects/1234567/workflows', data=workflow)
+ self.assertEqual(response.status_code, HTTPStatus.CONFLICT)
+ # Check mock
+ mock_wakeup.assert_not_called()
+ # Check DB
+ with db.session_scope() as session:
+ self.assertEqual(len(session.query(Workflow).all()), 5)
+
+ @patch('fedlearner_webconsole.participant.services.ParticipantService.get_platform_participants_by_project')
+ @patch('fedlearner_webconsole.workflow.utils.is_peer_job_inheritance_matched')
+ def test_fork_local_workflow(self, mock_is_peer_job_inheritance_matched, mock_get_platform_participants_by_project):
+ config = {
+ 'groupAlias': 'test',
+ 'job_definitions': [{
+ 'name': 'raw-data-job',
+ 'is_federated': False,
+ 'yaml_template': '{}',
+ }]
+ }
+ config_proto = ParseDict(config, WorkflowDefinition())
+ with db.session_scope() as session:
+ project = Project(name='test project')
+ session.add(project)
+ template = WorkflowTemplate(group_alias='test')
+ template.set_config(config_proto)
+ session.add(template)
+ session.flush()
+ parent_workflow = Workflow(name='local-workflow',
+ state=WorkflowState.READY,
+ forkable=True,
+ project_id=project.id,
+ template_id=template.id,
+ template_revision_id=1)
+ parent_workflow.set_config(config_proto)
+ session.add(parent_workflow)
+ session.commit()
+
+ fork_request = {
+ 'name': 'test-fork-local-workflow',
+ 'project_id': project.id,
+ 'forkable': True,
+ 'config': config,
+ 'comment': 'test-comment',
+ 'forked_from': parent_workflow.id,
+ 'fork_proposal_config': config,
+ }
+ response = self.post_helper(f'/api/v2/projects/{project.id}/workflows', data=fork_request)
+ mock_get_platform_participants_by_project.assert_not_called()
+ mock_is_peer_job_inheritance_matched.assert_not_called()
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ self.assertResponseDataEqual(
+ response, {
+ 'name': 'test-fork-local-workflow',
+ 'project_id': project.id,
+ 'template_id': template.id,
+ 'template_revision_id': 1,
+ 'comment': 'test-comment',
+ 'metric_is_public': False,
+ 'create_job_flags': [1],
+ 'job_ids': [1],
+ 'forkable': True,
+ 'forked_from': parent_workflow.id,
+ 'peer_create_job_flags': [],
+ 'state': 'PARTICIPANT_CONFIGURING',
+ 'start_at': 0,
+ 'stop_at': 0,
+ 'cron_config': '',
+ 'is_local': True,
+ 'creator': 'ada',
+ 'favour': False,
+ },
+ ignore_fields=['id', 'uuid', 'created_at', 'updated_at', 'config', 'template_info', 'editor_info', 'jobs'])
+
+ @patch('fedlearner_webconsole.participant.services.ParticipantService.get_platform_participants_by_project')
+ @patch('fedlearner_webconsole.workflow.service.is_peer_job_inheritance_matched')
+ def test_fork_workflow(self, mock_is_peer_job_inheritance_matched, mock_get_platform_participants_by_project):
+ # Prepares data
+ with open(Path(__file__, '../../../testing/test_data/workflow_config.json').resolve(),
+ encoding='utf-8') as workflow_config:
+ config = json.load(workflow_config)
+ with db.session_scope() as session:
+ project = Project(id=1, name='test project')
+ session.add(project)
+ config_proto = ParseDict(config, WorkflowDefinition())
+ template = WorkflowTemplate(name='parent-template', group_alias=config['group_alias'])
+ template.set_config(config_proto)
+ session.add(template)
+ session.flush()
+ parent_workflow = Workflow(name='parent_workflow',
+ project_id=1,
+ template_id=template.id,
+ state=WorkflowState.READY)
+ parent_workflow.set_config(config_proto)
+ session.add(parent_workflow)
+ session.commit()
+ fork_request = {
+ 'name': 'test-fork-workflow',
+ 'project_id': project.id,
+ 'forkable': True,
+ 'config': config,
+ 'comment': 'test-comment',
+ 'forked_from': parent_workflow.id,
+ 'fork_proposal_config': config,
+ 'peer_create_job_flags': [1, 1, 1],
+ }
+ # By default it is not forkable
+ response = self.post_helper(f'/api/v2/projects/{project.id}/workflows', data=fork_request)
+ self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
+ self.assertEqual(json.loads(response.data)['details'], 'workflow not forkable')
+
+ # Forks after parent workflow is forkable
+ with db.session_scope() as session:
+ parent_workflow = session.query(Workflow).get(parent_workflow.id)
+ parent_workflow.forkable = True
+ session.commit()
+ mock_get_platform_participants_by_project.return_value = None
+ mock_is_peer_job_inheritance_matched.return_value = True
+ response = self.post_helper(f'/api/v2/projects/{project.id}/workflows', data=fork_request)
+ mock_is_peer_job_inheritance_matched.assert_called_once()
+ self.assertEqual(response.status_code, HTTPStatus.CREATED)
+ self.assertResponseDataEqual(response, {
+ 'cron_config': '',
+ 'name': 'test-fork-workflow',
+ 'project_id': project.id,
+ 'forkable': True,
+ 'forked_from': parent_workflow.id,
+ 'is_local': False,
+ 'metric_is_public': False,
+ 'comment': 'test-comment',
+ 'state': 'PARTICIPANT_CONFIGURING',
+ 'create_job_flags': [1, 1, 1],
+ 'peer_create_job_flags': [1, 1, 1],
+ 'job_ids': [1, 2, 3],
+ 'template_id': template.id,
+ 'template_revision_id': 0,
+ 'creator': 'ada',
+ 'favour': False,
+ },
+ ignore_fields=[
+ 'id', 'created_at', 'updated_at', 'start_at', 'stop_at', 'uuid', 'config',
+ 'editor_info', 'template_info', 'jobs'
+ ])
+
+ @patch('fedlearner_webconsole.composer.composer_service.ComposerService.get_item_status')
+ @patch('fedlearner_webconsole.composer.composer_service.ComposerService.collect_v2')
+ @patch('fedlearner_webconsole.workflow.apis.scheduler.wakeup')
+ def test_post_cron_job(self, mock_wakeup, mock_collect, mock_get_item_status):
+ mock_get_item_status.return_value = None
+ with open(Path(__file__, '../../../testing/test_data/workflow_config.json').resolve(),
+ encoding='utf-8') as workflow_config:
+ config = json.load(workflow_config)
+ workflow = {
+ 'name': 'test-workflow-left',
+ 'project_id': 123,
+ 'forkable': True,
+ 'config': config,
+ 'cron_config': '*/10 * * * *',
+ 'template_id': 1
+ }
+        response = self.post_helper('/api/v2/projects/123/workflows', data=workflow)
+        self.assertEqual(response.status_code, HTTPStatus.CREATED)
+
+ with open(Path(__file__, '../../../testing/test_data/workflow_config_right.json').resolve(),
+ encoding='utf-8') as workflow_config:
+ config = json.load(workflow_config)
+ workflow = {
+ 'name': 'test-workflow-right',
+ 'project_id': 1234567,
+ 'forkable': True,
+ 'config': config,
+ 'cron_config': '*/10 * * * *',
+ }
+        response = self.post_helper('/api/v2/projects/1234567/workflows', data=workflow)
+        self.assertEqual(response.status_code, HTTPStatus.CREATED)
+
+ mock_collect.assert_called()
+ mock_wakeup.assert_called()
+
+
+class WorkflowApiTest(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ self._project = Project(id=123, name='project_123')
+ self._template1 = WorkflowTemplate(name='t1', comment='comment for t1', group_alias='g1')
+ self._template1.set_config(WorkflowDefinition(group_alias='g1',))
+ session.add(self._project)
+ session.add(self._template1)
+ session.commit()
+ self.signin_as_admin()
+
+ def test_get_workflow(self):
+ workflow = Workflow(name='test-workflow',
+ project_id=self._project.id,
+ config=WorkflowDefinition(group_alias='g1',).SerializeToString(),
+ template_id=self._template1.id,
+ forkable=False,
+ state=WorkflowState.RUNNING,
+ job_ids='1')
+ job1 = Job(name='job 1', workflow_id=3, project_id=self._project.id, job_type=JobType.RAW_DATA)
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.add(job1)
+ session.commit()
+
+ response = self.get_helper(f'/api/v2/projects/{self._project.id}/workflows/{workflow.id}')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ workflow_data = self.get_response_data(response)
+ self.assertEqual(workflow_data['name'], 'test-workflow')
+ self.assertEqual(len(workflow_data['jobs']), 1)
+ self.assertEqual(workflow_data['jobs'][0]['name'], 'job 1')
+ response = self.get_helper(f'/api/v2/projects/{self._project.id}/workflows/6666')
+ self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+
+ @patch('fedlearner_webconsole.scheduler.scheduler.Scheduler.wakeup')
+ def test_put_successfully(self, mock_wake_up):
+ config = {
+ 'variables': [{
+ 'name': 'namespace',
+ 'value': 'leader'
+ }, {
+ 'name': 'basic_envs',
+ 'value': '{}'
+ }, {
+ 'name': 'storage_root_dir',
+ 'value': '/'
+ }]
+ }
+ with db.session_scope() as session:
+ project = Project(id=1,
+ name='test',
+ config=ParseDict(config, project_pb2.ProjectConfig()).SerializeToString())
+ participant = Participant(name='party_leader', host='127.0.0.1', port=5000, domain_name='fl-leader.com')
+ relationship = ProjectParticipant(project_id=1, participant_id=1)
+ session.add(project)
+ session.add(participant)
+ session.add(relationship)
+ workflow = Workflow(name='test-workflow',
+ project_id=project.id,
+ state=WorkflowState.NEW,
+ transaction_state=TransactionState.PARTICIPANT_PREPARE,
+ target_state=WorkflowState.READY)
+ session.add(workflow)
+ session.commit()
+
+ response = self.put_helper(f'/api/v2/projects/{project.id}/workflows/{workflow.id}',
+ data={
+ 'forkable': True,
+ 'config': {
+ 'group_alias': 'test-template'
+ },
+ 'comment': 'test comment',
+ 'template_id': 1,
+ 'template_revision_id': 1
+ })
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_wake_up.assert_called_with(workflow.id)
+ with db.session_scope() as session:
+ updated_workflow = session.query(Workflow).get(workflow.id)
+ self.assertIsNotNone(updated_workflow.config)
+ self.assertTrue(updated_workflow.forkable)
+ self.assertEqual(updated_workflow.comment, 'test comment')
+ self.assertEqual(updated_workflow.target_state, WorkflowState.READY)
+ self.assertEqual(updated_workflow.template_revision_id, 1)
+
+ def test_put_resetting(self):
+ with db.session_scope() as session:
+ project_id = 123
+ workflow = Workflow(
+ name='test-workflow',
+ project_id=project_id,
+ config=WorkflowDefinition(group_alias='test-template').SerializeToString(),
+ state=WorkflowState.NEW,
+ )
+ session.add(workflow)
+ session.commit()
+ session.refresh(workflow)
+
+ response = self.put_helper(f'/api/v2/projects/{project_id}/workflows/{workflow.id}',
+ data={
+ 'forkable': True,
+ 'config': {
+ 'group_alias': 'test-template'
+ },
+ 'template_id': 1
+ })
+ self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST)
+
+ @patch('fedlearner_webconsole.composer.composer_service.ComposerService.get_item_status')
+ @patch('fedlearner_webconsole.composer.composer_service.ComposerService.patch_item_attr')
+ @patch('fedlearner_webconsole.composer.composer_service.ComposerService.finish')
+ @patch('fedlearner_webconsole.composer.composer_service.ComposerService.collect_v2')
+ def test_patch_cron_config(self, mock_collect, mock_finish, mock_patch_item, mock_get_item_status):
+ mock_get_item_status.side_effect = [None, ItemStatus.ON]
+ project_id = 123
+ workflow = Workflow(
+ name='test-workflow-left',
+ project_id=project_id,
+ config=WorkflowDefinition().SerializeToString(),
+ forkable=False,
+ state=WorkflowState.STOPPED,
+ )
+ cron_config = '*/20 * * * *'
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.commit()
+ session.refresh(workflow)
+
+ # test create cronjob
+ response = self.patch_helper(f'/api/v2/projects/{project_id}/workflows/{workflow.id}',
+ data={'cron_config': cron_config})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+
+ mock_collect.assert_called_with(
+ name=f'workflow_cron_job_{workflow.id}',
+ items=[(ItemType.WORKFLOW_CRON_JOB,
+ RunnerInput(workflow_cron_job_input=WorkflowCronJobInput(workflow_id=workflow.id)))],
+ cron_config=cron_config)
+
+ # patch new config for cronjob
+ cron_config = '*/30 * * * *'
+ response = self.patch_helper(f'/api/v2/projects/{project_id}/workflows/{workflow.id}',
+ data={'cron_config': cron_config})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_patch_item.assert_called_with(name=f'workflow_cron_job_{workflow.id}',
+ key='cron_config',
+ value=cron_config)
+
+ # test stop cronjob
+ response = self.patch_helper(f'/api/v2/projects/{project_id}/workflows/{workflow.id}', data={'cron_config': ''})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_finish.assert_called_with(name=f'workflow_cron_job_{workflow.id}')
+
+ def test_patch_not_found(self):
+ response = self.patch_helper('/api/v2/projects/123/workflows/1', data={'target_state': 'RUNNING'})
+ self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND)
+
+ def test_patch_create_job_flags(self):
+ with db.session_scope() as session:
+ workflow, job = add_fake_workflow(session)
+ job_id = job.id
+ response = self.patch_helper(f'/api/v2/projects/{workflow.project_id}/workflows/{workflow.id}',
+ data={'create_job_flags': [3]})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ patched_job = session.query(Job).get(job_id)
+ self.assertEqual(patched_job.is_disabled, True)
+ response = self.patch_helper(f'/api/v2/projects/{workflow.project_id}/workflows/{workflow.id}',
+ data={'create_job_flags': [1]})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ with db.session_scope() as session:
+ patched_job = session.query(Job).get(job_id)
+ self.assertEqual(patched_job.is_disabled, False)
+
+ def test_patch_favour(self):
+ with db.session_scope() as session:
+ workflow, job = add_fake_workflow(session)
+ response = self.patch_helper(f'/api/v2/projects/{workflow.project_id}/workflows/{workflow.id}',
+ data={'favour': True})
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(data['favour'], True)
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow.id)
+ self.assertEqual(workflow.favour, True)
+
+    def test_patch_template(self):
+ with db.session_scope() as session:
+ workflow, job = add_fake_workflow(session)
+ response = self.patch_helper(f'/api/v2/projects/{workflow.project_id}/workflows/{workflow.id}',
+ data={
+ 'config': to_dict(workflow.get_config()),
+ 'template_revision_id': 1
+ })
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(data['template_revision_id'], 1)
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow.id)
+ self.assertEqual(workflow.template_revision_id, 1)
+
+ def test_is_local(self):
+ with db.session_scope() as session:
+ workflow, job = add_fake_workflow(session)
+ self.assertTrue(workflow.is_local())
+ config = workflow.get_config()
+ config.job_definitions[0].is_federated = True
+ workflow.set_config(config)
+ self.assertFalse(workflow.is_local())
+
+
+class WorkflowInvalidateApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project_1')
+ participant1 = Participant(name='party_1', id=1, host='127.0.0.1', port=1997, domain_name='fl-party1.com')
+
+ participant2 = Participant(name='party_2', id=2, host='127.0.0.1', port=1998, domain_name='fl-party2.com')
+ relationship1 = ProjectParticipant(project_id=1, participant_id=1)
+ relationship2 = ProjectParticipant(project_id=1, participant_id=2)
+ ready_workflow = Workflow(name='workflow_invalidate1',
+ project_id=1,
+ uuid='11111',
+ state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.READY)
+ session.add(project)
+ session.add(participant1)
+ session.add(participant2)
+ session.add(relationship1)
+ session.add(relationship2)
+ session.add(ready_workflow)
+ session.commit()
+ self.signin_as_admin()
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.invalidate_workflow')
+ def test_invalidate_after_created(self, mock_invalidate_workflow):
+ mock_invalidate_workflow.return_value = service_pb2.InvalidateWorkflowResponse(
+ status=common_pb2.Status(code=common_pb2.STATUS_SUCCESS, msg=''),
+ succeeded=True,
+ )
+ response = self.post_helper('/api/v2/projects/1/workflows/1:invalidate')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ expected = [call('11111'), call('11111')]
+ self.assertEqual(mock_invalidate_workflow.call_args_list, expected)
+ response = self.get_helper('/api/v2/projects/0/workflows/1')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ data = self.get_response_data(response)
+ self.assertEqual(data['state'], WorkflowState.INVALID.name)
+
+
+class WorkflowStartAndStopApiTest(BaseTestCase):
+
+ class Config(BaseTestCase.Config):
+ START_SCHEDULER = False
+
+ def setUp(self):
+ super().setUp()
+ with db.session_scope() as session:
+ project = Project(id=1, name='project_1')
+ participant1 = Participant(name='party_1', id=1, host='127.0.0.1', port=1997, domain_name='fl-party1.com')
+
+ participant2 = Participant(name='party_2', id=2, host='127.0.0.1', port=1998, domain_name='fl-party2.com')
+ relationship1 = ProjectParticipant(project_id=1, participant_id=1)
+ relationship2 = ProjectParticipant(project_id=1, participant_id=2)
+ workflow_test_start_fed = Workflow(name='workflow_test_start_fed',
+ project_id=1,
+ uuid='11111',
+ state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY)
+ workflow_test_stop_fed = Workflow(name='workflow_test_stop_fed',
+ project_id=1,
+ uuid='22222',
+ state=WorkflowState.RUNNING,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY)
+ workflow_test_start_local = Workflow(name='workflow_test_start_local',
+ project_id=1,
+ uuid='33333',
+ state=WorkflowState.STOPPED,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY)
+ workflow_test_stop_local = Workflow(name='workflow_test_stop_local',
+ project_id=1,
+ uuid='44444',
+ state=WorkflowState.RUNNING,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY)
+ session.add(project)
+ session.add(participant1)
+ session.add(participant2)
+ session.add(relationship1)
+ session.add(relationship2)
+ session.add(workflow_test_start_fed)
+ session.add(workflow_test_stop_fed)
+ session.add(workflow_test_start_local)
+ session.add(workflow_test_stop_local)
+ session.commit()
+ self.signin_as_admin()
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager.run')
+ def test_start_workflow_fed(self, mock_run):
+ mock_run.return_value = (True, '')
+ response = self.post_helper('/api/v2/projects/1/workflows/1:start')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_run.assert_called_once()
+
+ @patch('fedlearner_webconsole.two_pc.transaction_manager.TransactionManager.run')
+ def test_stop_workflow_fed(self, mock_run):
+ mock_run.return_value = (True, '')
+ response = self.post_helper('/api/v2/projects/1/workflows/2:stop')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_run.assert_called_once()
+
+ @patch('fedlearner_webconsole.workflow.models.Workflow.is_local')
+ @patch('fedlearner_webconsole.workflow.workflow_job_controller.start_workflow_locally')
+ def test_start_workflow_local(self, mock_start_workflow_locally, mock_is_local):
+ mock_is_local.return_value = True
+ response = self.post_helper('/api/v2/projects/1/workflows/3:start')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_start_workflow_locally.assert_called_once()
+
+ @patch('fedlearner_webconsole.workflow.models.Workflow.is_local')
+ @patch('fedlearner_webconsole.workflow.workflow_job_controller.stop_workflow_locally')
+ def test_stop_workflow_local(self, mock_stop_workflow_locally, mock_is_local):
+ mock_is_local.return_value = True
+ response = self.post_helper('/api/v2/projects/1/workflows/4:stop')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+ mock_stop_workflow_locally.assert_called_once()
+ response = self.post_helper('/api/v2/projects/1/workflows/4:stop')
+ self.assertEqual(response.status_code, HTTPStatus.OK)
+
+
+def add_fake_workflow(session):
+ wd = WorkflowDefinition()
+ jd = wd.job_definitions.add()
+ jd.yaml_template = '{}'
+ workflow = Workflow(
+ name='test-workflow',
+ project_id=123,
+ config=wd.SerializeToString(),
+ forkable=False,
+ state=WorkflowState.READY,
+ )
+ session.add(workflow)
+ session.flush()
+ job = Job(name='test_job',
+ job_type=JobType(1),
+ config=jd.SerializeToString(),
+ workflow_id=workflow.id,
+ project_id=123,
+ state=JobState.STOPPED,
+ is_disabled=False)
+ session.add(job)
+ session.flush()
+ workflow.job_ids = str(job.id)
+ session.commit()
+ return workflow, job
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/cronjob.py b/web_console_v2/api/fedlearner_webconsole/workflow/cronjob.py
index 184393e32..58df1d82e 100644
--- a/web_console_v2/api/fedlearner_webconsole/workflow/cronjob.py
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/cronjob.py
@@ -1,94 +1,46 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
# coding: utf-8
-
+import logging
from typing import Tuple
-from time import sleep
-
-from fedlearner_webconsole.composer.interface import IItem, IRunner, ItemType
-from fedlearner_webconsole.composer.models import Context, RunnerStatus
-from fedlearner_webconsole.db import get_session
-from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
-
-
-class WorkflowCronJobItem(IItem):
- def __init__(self, task_id: int):
- self.id = task_id
- def type(self) -> ItemType:
- return ItemType.WORKFLOW_CRON_JOB
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import IRunnerV2
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, WorkflowCronJobOutput
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowExternalState
+from fedlearner_webconsole.workflow.workflow_job_controller import start_workflow
- def get_id(self) -> int:
- return self.id
- def __eq__(self, obj: IItem):
- return self.id == obj.id and self.type() == obj.type()
-
-
-class WorkflowCronJob(IRunner):
- """ start workflow every intervals
+class WorkflowCronJob(IRunnerV2):
+ """Starts workflow periodically.
"""
- def __init__(self, task_id: int):
- self._workflow_id = task_id
- self._msg = None
-
- def start(self, context: Context):
- with get_session(context.db_engine) as session:
- try:
- workflow: Workflow = session.query(Workflow).filter_by(
- id=self._workflow_id).one()
- # TODO: This is a hack!!! Templatelly use this method
- # cc @hangweiqiang: Transaction State Refactor
- state = workflow.get_state_for_frontend()
- if state in ('COMPLETED', 'FAILED', 'READY', 'STOPPED', 'NEW'):
- if state in ('COMPLETED', 'FAILED'):
- workflow.update_target_state(
- target_state=WorkflowState.STOPPED)
- session.commit()
- # check workflow stopped
- # TODO: use composer timeout cc @yurunyu
- for _ in range(24):
- # use session refresh to get the latest info
- # otherwise it'll use the indentity map locally
- session.refresh(workflow)
- if workflow.state == WorkflowState.STOPPED:
- break
- sleep(5)
- else:
- self._msg = f'failed to stop \
- workflow[{self._workflow_id}]'
- return
- workflow.update_target_state(
- target_state=WorkflowState.RUNNING)
- session.commit()
- self._msg = f'restarted workflow[{self._workflow_id}]'
- elif state == 'RUNNING':
- self._msg = f'skip restarting workflow[{self._workflow_id}]'
- elif state == 'INVALID':
- self._msg = f'current workflow[{self._workflow_id}] \
- is invalid'
- else:
- self._msg = f'workflow[{self._workflow_id}] \
- state is {state}, which is out of expection'
-
- except Exception as err: # pylint: disable=broad-except
- self._msg = f'exception of workflow[{self._workflow_id}], \
- details is {err}'
-
- def result(self, context: Context) -> Tuple[RunnerStatus, dict]:
- del context # unused by result
- if self._msg is None:
- return RunnerStatus.RUNNING, {}
- output = {'msg': self._msg}
- return RunnerStatus.DONE, output
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ output = WorkflowCronJobOutput()
+ with db.session_scope() as session:
+ workflow_id = context.input.workflow_cron_job_input.workflow_id
+ workflow: Workflow = session.query(Workflow).get(workflow_id)
+ state = workflow.get_state_for_frontend()
+ logging.info(f'[WorkflowCronJob] Try to start workflow {workflow_id}, state: {state.name}')
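+            # Only idle workflows are (re)started; any other state (e.g. RUNNING or
+            # INVALID) is skipped and reported via the runner output message.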
+ if state in (WorkflowExternalState.READY_TO_RUN, WorkflowExternalState.COMPLETED,
+ WorkflowExternalState.FAILED, WorkflowExternalState.STOPPED):
+ start_workflow(workflow_id)
+ output.message = 'Restarted workflow'
+ else:
+ output.message = f'Skip starting workflow, state is {state.name}'
+ return RunnerStatus.DONE, RunnerOutput(workflow_cron_job_output=output)
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/cronjob_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/cronjob_test.py
new file mode 100644
index 000000000..113036f11
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/cronjob_test.py
@@ -0,0 +1,107 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from datetime import datetime
+from unittest.mock import patch, Mock
+
+from sqlalchemy import and_
+
+from fedlearner_webconsole.composer.composer import ComposerConfig, Composer
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.composer.models import RunnerStatus, SchedulerItem, SchedulerRunner
+from fedlearner_webconsole.composer.composer_service import ComposerService
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, WorkflowCronJobInput
+from fedlearner_webconsole.workflow.cronjob import WorkflowCronJob
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from testing.no_web_server_test_case import NoWebServerTestCase
+from testing.fake_time_patcher import FakeTimePatcher
+
+
+class CronJobTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.time_patcher = FakeTimePatcher()
+ self.time_patcher.start(datetime(2012, 1, 14, 12, 0, 5))
+
+ self.test_id = 8848
+ workflow = Workflow(id=self.test_id, state=WorkflowState.RUNNING)
+ with db.session_scope() as session:
+ session.add(workflow)
+ session.commit()
+
+ def tearDown(self):
+ self.time_patcher.stop()
+
+ super().tearDown()
+
+ def test_run_skip_running_workflow(self):
+ workflow_id = 123
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, state=WorkflowState.RUNNING)
+ session.add(workflow)
+ session.commit()
+
+ context = RunnerContext(0, RunnerInput(workflow_cron_job_input=WorkflowCronJobInput(workflow_id=workflow_id)))
+ runner = WorkflowCronJob()
+ status, output = runner.run(context)
+ self.assertEqual(status, RunnerStatus.DONE)
+ self.assertEqual(output.workflow_cron_job_output.message, 'Skip starting workflow, state is RUNNING')
+
+ @patch('fedlearner_webconsole.workflow.cronjob.start_workflow')
+ def test_run_ready_workflow(self, mock_start_workflow: Mock):
+ workflow_id = 234
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, state=WorkflowState.READY)
+ session.add(workflow)
+ session.commit()
+
+ context = RunnerContext(0, RunnerInput(workflow_cron_job_input=WorkflowCronJobInput(workflow_id=workflow_id)))
+ runner = WorkflowCronJob()
+ status, output = runner.run(context)
+ self.assertEqual(status, RunnerStatus.DONE)
+ self.assertEqual(output.workflow_cron_job_output.message, 'Restarted workflow')
+ mock_start_workflow.assert_called_once_with(workflow_id)
+
+ def test_cronjob_with_composer(self):
+ item_name = f'workflow_cronjob_{self.test_id}'
+ config = ComposerConfig(runner_fn={ItemType.WORKFLOW_CRON_JOB.value: WorkflowCronJob}, name='test_cronjob')
+ composer = Composer(config=config)
+ with db.session_scope() as session:
+ service = ComposerService(session)
+ service.collect_v2(name=item_name,
+ items=[
+ (ItemType.WORKFLOW_CRON_JOB,
+ RunnerInput(workflow_cron_job_input=WorkflowCronJobInput(workflow_id=self.test_id)))
+ ],
+ cron_config='* * * * * */10')
+ session.commit()
+ composer.run(db_engine=db.engine)
+        # Interrupt twice since the composer needs two rounds of ticks
+        # to schedule the items in the fake-time world
+ self.time_patcher.interrupt(10)
+ self.time_patcher.interrupt(10)
+ with db.session_scope() as session:
+ runners = session.query(SchedulerRunner).filter(
+ and_(SchedulerRunner.item_id == SchedulerItem.id, SchedulerItem.name == item_name)).all()
+ self.assertEqual(len(runners), 2)
+ composer.stop()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/models.py b/web_console_v2/api/fedlearner_webconsole/workflow/models.py
index f988f93db..33d226645 100644
--- a/web_console_v2/api/fedlearner_webconsole/workflow/models.py
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/models.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,23 +13,24 @@
# limitations under the License.
# coding: utf-8
-# pylint: disable=broad-except
-import json
-import logging
+# pylint: disable=use-a-generator
import enum
-from datetime import datetime
+from typing import List, Optional
+
+from sqlalchemy.orm import deferred
from sqlalchemy.sql import func
from sqlalchemy import UniqueConstraint
-from envs import Features
-from fedlearner_webconsole.composer.models import SchedulerItem
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.proto.workflow_pb2 import WorkflowRef, WorkflowPb
from fedlearner_webconsole.utils.mixins import to_dict_mixin
from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
from fedlearner_webconsole.proto import (common_pb2, workflow_definition_pb2)
-from fedlearner_webconsole.job.models import (Job, JobState, JobType,
- JobDependency)
-from fedlearner_webconsole.rpc.client import RpcClient
-from fedlearner_webconsole.mmgr.service import ModelService
+from fedlearner_webconsole.job.models import JobState, Job
+from fedlearner_webconsole.utils.pp_datetime import to_timestamp
+from fedlearner_webconsole.workflow.utils import is_local
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate, WorkflowTemplateRevision
class WorkflowState(enum.Enum):
@@ -38,20 +39,60 @@ class WorkflowState(enum.Enum):
READY = 2
RUNNING = 3
STOPPED = 4
-
-
-class RecurType(enum.Enum):
- NONE = 0
- ON_NEW_DATA = 1
- HOURLY = 2
- DAILY = 3
- WEEKLY = 4
-
-
-VALID_TRANSITIONS = [(WorkflowState.NEW, WorkflowState.READY),
- (WorkflowState.READY, WorkflowState.RUNNING),
- (WorkflowState.RUNNING, WorkflowState.STOPPED),
- (WorkflowState.STOPPED, WorkflowState.RUNNING)]
+ COMPLETED = 5
+ FAILED = 6
+
+
+class WorkflowExternalState(enum.Enum):
+ # state of workflow is unknown
+ UNKNOWN = 0
+ # workflow is completed
+ COMPLETED = 1
+ # workflow is failed
+ FAILED = 2
+ # workflow is stopped
+ STOPPED = 3
+ # workflow is running
+ RUNNING = 4
+    # workflow is preparing to run
+ PREPARE_RUN = 5
+    # workflow is preparing to stop
+ PREPARE_STOP = 6
+ # workflow is warming up under the hood
+ WARMUP_UNDERHOOD = 7
+    # workflow is pending the participant's acceptance
+ PENDING_ACCEPT = 8
+ # workflow is ready to run
+ READY_TO_RUN = 9
+    # workflow is waiting for the participant to configure
+ PARTICIPANT_CONFIGURING = 10
+ # workflow is invalid
+ INVALID = 11
+
+
+# yapf: disable
+VALID_TRANSITIONS = [
+ (WorkflowState.NEW, WorkflowState.READY),
+ (WorkflowState.READY, WorkflowState.RUNNING),
+ (WorkflowState.READY, WorkflowState.STOPPED),
+
+ (WorkflowState.RUNNING, WorkflowState.STOPPED),
+    # Transitions below are not used, because the state controller treats COMPLETED and FAILED as STOPPED.
+ # (WorkflowState.RUNNING, WorkflowState.COMPLETED),
+ # (WorkflowState.RUNNING, WorkflowState.FAILED),
+
+
+ (WorkflowState.STOPPED, WorkflowState.RUNNING),
+ (WorkflowState.COMPLETED, WorkflowState.RUNNING),
+ (WorkflowState.FAILED, WorkflowState.RUNNING),
+ (WorkflowState.RUNNING, WorkflowState.RUNNING),
+
+    # This is a hack to make workflow_state_controller's committing stage idempotent.
+ (WorkflowState.STOPPED, WorkflowState.STOPPED),
+ (WorkflowState.COMPLETED, WorkflowState.STOPPED),
+ (WorkflowState.FAILED, WorkflowState.STOPPED)
+]
+# yapf: enable
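+# A minimal sketch of how these transitions are consulted (can_transit_to is defined on
+# Workflow further below); the instances here are purely illustrative:
+#
+#   Workflow(state=WorkflowState.STOPPED).can_transit_to(WorkflowState.RUNNING)  # True, restart allowed
+#   Workflow(state=WorkflowState.STOPPED).can_transit_to(WorkflowState.STOPPED)  # True, idempotent stop
+#   Workflow(state=WorkflowState.NEW).can_transit_to(WorkflowState.RUNNING)      # False, must go through READY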
class TransactionState(enum.Enum):
@@ -75,86 +116,69 @@ class TransactionState(enum.Enum):
(TransactionState.READY, TransactionState.COORDINATOR_PREPARE),
# (TransactionState.COORDINATOR_PREPARE,
# TransactionState.COORDINATOR_COMMITTABLE),
- (TransactionState.COORDINATOR_COMMITTABLE,
- TransactionState.COORDINATOR_COMMITTING),
+ (TransactionState.COORDINATOR_COMMITTABLE, TransactionState.COORDINATOR_COMMITTING),
# (TransactionState.COORDINATOR_PREPARE,
# TransactionState.COORDINATOR_ABORTING),
- (TransactionState.COORDINATOR_COMMITTABLE,
- TransactionState.COORDINATOR_ABORTING),
+ (TransactionState.COORDINATOR_COMMITTABLE, TransactionState.COORDINATOR_ABORTING),
(TransactionState.COORDINATOR_ABORTING, TransactionState.ABORTED),
(TransactionState.READY, TransactionState.PARTICIPANT_PREPARE),
# (TransactionState.PARTICIPANT_PREPARE,
# TransactionState.PARTICIPANT_COMMITTABLE),
- (TransactionState.PARTICIPANT_COMMITTABLE,
- TransactionState.PARTICIPANT_COMMITTING),
+ (TransactionState.PARTICIPANT_COMMITTABLE, TransactionState.PARTICIPANT_COMMITTING),
# (TransactionState.PARTICIPANT_PREPARE,
# TransactionState.PARTICIPANT_ABORTING),
- (TransactionState.PARTICIPANT_COMMITTABLE,
- TransactionState.PARTICIPANT_ABORTING),
+ (TransactionState.PARTICIPANT_COMMITTABLE, TransactionState.PARTICIPANT_ABORTING),
# (TransactionState.PARTICIPANT_ABORTING,
# TransactionState.ABORTED),
]
IGNORED_TRANSACTION_TRANSITIONS = [
- (TransactionState.PARTICIPANT_COMMITTABLE,
- TransactionState.PARTICIPANT_PREPARE),
+ (TransactionState.PARTICIPANT_COMMITTABLE, TransactionState.PARTICIPANT_PREPARE),
]
-def _merge_variables(base, new, access_mode):
- new_dict = {i.name: i.value for i in new}
- for var in base:
- if var.access_mode in access_mode and var.name in new_dict:
- # use json.dumps to escape " in peer's input, a"b ----> "a\"b"
- # and use [1:-1] to remove ", "a\"b" ----> a\"b
- var.value = json.dumps(new_dict[var.name])[1:-1]
-
+def compare_yaml_templates_in_wf(wf_a: workflow_definition_pb2.WorkflowDefinition,
+ wf_b: workflow_definition_pb2.WorkflowDefinition):
+ """"Compare two WorkflowDefinition's each template,
+ return True if any job different"""
+    if len(wf_a.job_definitions) != len(wf_b.job_definitions):
+        # A different number of jobs already means the definitions differ
+        return True
+ job_defs_a = wf_a.job_definitions
+ job_defs_b = wf_b.job_definitions
+ return any([
+ job_defs_a[i].yaml_template != job_defs_b[i].yaml_template or job_defs_a[i].name != job_defs_b[i].name
+ for i in range(len(job_defs_a))
+ ])
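+# A rough usage sketch (assuming the job message in workflow_definition_pb2 is named
+# JobDefinition, as suggested by the fields accessed above):
+#
+#   wf_a = WorkflowDefinition(job_definitions=[
+#       workflow_definition_pb2.JobDefinition(name='raw-data', yaml_template='a: 1')])
+#   wf_b = WorkflowDefinition(job_definitions=[
+#       workflow_definition_pb2.JobDefinition(name='raw-data', yaml_template='a: 2')])
+#   compare_yaml_templates_in_wf(wf_a, wf_b)  # True: yaml_template differs for the same job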
-def _merge_workflow_config(base, new, access_mode):
- _merge_variables(base.variables, new.variables, access_mode)
- if not new.job_definitions:
- return
- assert len(base.job_definitions) == len(new.job_definitions)
- for base_job, new_job in \
- zip(base.job_definitions, new.job_definitions):
- _merge_variables(base_job.variables, new_job.variables, access_mode)
-
-@to_dict_mixin(ignores=['fork_proposal_config', 'config'],
+@to_dict_mixin(ignores=['fork_proposal_config', 'config', 'editor_info'],
extras={
'job_ids': (lambda wf: wf.get_job_ids()),
'create_job_flags': (lambda wf: wf.get_create_job_flags()),
- 'peer_create_job_flags':
- (lambda wf: wf.get_peer_create_job_flags()),
+ 'peer_create_job_flags': (lambda wf: wf.get_peer_create_job_flags()),
'state': (lambda wf: wf.get_state_for_frontend()),
- 'transaction_state':
- (lambda wf: wf.get_transaction_state_for_frontend()),
- 'batch_update_interval':
- (lambda wf: wf.get_batch_update_interval()),
+ 'is_local': (lambda wf: wf.is_local())
})
class Workflow(db.Model):
__tablename__ = 'workflow_v2'
__table_args__ = (UniqueConstraint('uuid', name='uniq_uuid'),
- UniqueConstraint('name', name='uniq_name'), {
+ UniqueConstraint('project_id', 'name', name='uniq_name_in_project'), {
'comment': 'workflow_v2',
'mysql_engine': 'innodb',
'mysql_charset': 'utf8mb4',
})
- id = db.Column(db.Integer, primary_key=True, comment='id')
+ id = db.Column(db.Integer, primary_key=True, comment='id', autoincrement=True)
uuid = db.Column(db.String(64), comment='uuid')
name = db.Column(db.String(255), comment='name')
project_id = db.Column(db.Integer, comment='project_id')
+ template_id = db.Column(db.Integer, comment='template_id', nullable=True)
+ template_revision_id = db.Column(db.Integer, comment='template_revision_id', nullable=True)
+ editor_info = deferred(db.Column(db.LargeBinary(16777215), comment='editor_info', default=b'', nullable=True))
# max store 16777215 bytes (16 MB)
- config = db.Column(db.LargeBinary(16777215), comment='config')
- comment = db.Column('cmt',
- db.String(255),
- key='comment',
- comment='comment')
-
- metric_is_public = db.Column(db.Boolean(),
- default=False,
- nullable=False,
- comment='metric_is_public')
+ config = deferred(db.Column(db.LargeBinary(16777215), comment='config'))
+ comment = db.Column('cmt', db.String(255), key='comment', comment='comment')
+
+ metric_is_public = db.Column(db.Boolean(), default=False, nullable=False, comment='metric_is_public')
create_job_flags = db.Column(db.TEXT(), comment='create_job_flags')
job_ids = db.Column(db.TEXT(), comment='job_ids')
@@ -162,31 +186,22 @@ class Workflow(db.Model):
forkable = db.Column(db.Boolean, default=False, comment='forkable')
forked_from = db.Column(db.Integer, default=None, comment='forked_from')
# index in config.job_defs instead of job's id
- peer_create_job_flags = db.Column(db.TEXT(),
- comment='peer_create_job_flags')
+ peer_create_job_flags = db.Column(db.TEXT(), comment='peer_create_job_flags')
# max store 16777215 bytes (16 MB)
- fork_proposal_config = db.Column(db.LargeBinary(16777215),
- comment='fork_proposal_config')
-
- recur_type = db.Column(db.Enum(RecurType, native_enum=False),
- default=RecurType.NONE,
- comment='recur_type')
- recur_at = db.Column(db.Interval, comment='recur_at')
+ fork_proposal_config = db.Column(db.LargeBinary(16777215), comment='fork_proposal_config')
trigger_dataset = db.Column(db.Integer, comment='trigger_dataset')
- last_triggered_batch = db.Column(db.Integer,
- comment='last_triggered_batch')
+ last_triggered_batch = db.Column(db.Integer, comment='last_triggered_batch')
- state = db.Column(db.Enum(WorkflowState,
- native_enum=False,
- name='workflow_state'),
+ state = db.Column(db.Enum(WorkflowState, native_enum=False, create_constraint=False, name='workflow_state'),
default=WorkflowState.INVALID,
comment='state')
target_state = db.Column(db.Enum(WorkflowState,
native_enum=False,
+ create_constraint=False,
name='workflow_target_state'),
default=WorkflowState.INVALID,
comment='target_state')
- transaction_state = db.Column(db.Enum(TransactionState, native_enum=False),
+ transaction_state = db.Column(db.Enum(TransactionState, native_enum=False, create_constraint=False),
default=TransactionState.READY,
comment='transaction_state')
transaction_err = db.Column(db.Text(), comment='transaction_err')
@@ -194,56 +209,86 @@ class Workflow(db.Model):
start_at = db.Column(db.Integer, comment='start_at')
stop_at = db.Column(db.Integer, comment='stop_at')
- created_at = db.Column(db.DateTime(timezone=True),
- server_default=func.now(),
- comment='created_at')
+ created_at = db.Column(db.DateTime(timezone=True), server_default=func.now(), comment='created_at')
updated_at = db.Column(db.DateTime(timezone=True),
onupdate=func.now(),
server_default=func.now(),
comment='update_at')
- extra = db.Column(db.Text(), comment='extra') # json string
+    extra = db.Column(db.Text(), comment='json string that will be sent to the peer') # deprecated
+    local_extra = db.Column(db.Text(), comment='json string that will only be stored locally') # deprecated
+ cron_config = db.Column('cronjob_config', db.Text(), key='cron_config', comment='cronjob json string')
+
+ creator = db.Column(db.String(255), comment='the username of the creator')
+ favour = db.Column(db.Boolean, default=False, comment='favour')
owned_jobs = db.relationship(
- 'Job', primaryjoin='foreign(Job.workflow_id) == Workflow.id')
- project = db.relationship(
- 'Project', primaryjoin='Project.id == foreign(Workflow.project_id)')
+ 'Job',
+ primaryjoin='foreign(Job.workflow_id) == Workflow.id',
+ # To disable the warning of back_populates
+ overlaps='workflow')
+ project = db.relationship(Project.__name__, primaryjoin='Project.id == foreign(Workflow.project_id)')
+ template = db.relationship(WorkflowTemplate.__name__,
+ primaryjoin='WorkflowTemplate.id == foreign(Workflow.template_id)')
+ template_revision = db.relationship(
+ WorkflowTemplateRevision.__name__,
+ primaryjoin='WorkflowTemplateRevision.id == foreign(Workflow.template_revision_id)',
+ # To disable the warning of back_populates
+ overlaps='workflows')
+
+ def is_finished(self) -> bool:
+ return all([job.is_disabled or job.state == JobState.COMPLETED for job in self.owned_jobs])
+
+ def is_failed(self) -> bool:
+ return any([job.state == JobState.FAILED for job in self.owned_jobs])
+
+ def get_state_for_frontend(self) -> WorkflowExternalState:
+ """Get workflow states that frontend need."""
+
+ # states in workflow creating stage.
+ if self.state == WorkflowState.NEW \
+ and self.target_state == WorkflowState.READY:
+ if self.transaction_state in [
+ TransactionState.PARTICIPANT_COMMITTABLE, TransactionState.PARTICIPANT_COMMITTING,
+ TransactionState.COORDINATOR_COMMITTING
+ ]:
+ return WorkflowExternalState.WARMUP_UNDERHOOD
+ if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
+ return WorkflowExternalState.PENDING_ACCEPT
+ if self.transaction_state in [
+ TransactionState.READY, TransactionState.COORDINATOR_COMMITTABLE,
+ TransactionState.COORDINATOR_PREPARE
+ ]:
+ return WorkflowExternalState.PARTICIPANT_CONFIGURING
+
+ # static state
+ if self.state == WorkflowState.READY:
+ return WorkflowExternalState.READY_TO_RUN
- def get_state_for_frontend(self):
if self.state == WorkflowState.RUNNING:
- is_complete = all([job.is_disabled or
- job.state == JobState.COMPLETED
- for job in self.owned_jobs])
- if is_complete:
- return 'COMPLETED'
- is_failed = any([job.state == JobState.FAILED
- for job in self.owned_jobs])
- if is_failed:
- return 'FAILED'
- return self.state.name
-
- def get_transaction_state_for_frontend(self):
- # TODO(xiangyuxuan): remove this hack by redesign 2pc
- if (self.transaction_state == TransactionState.PARTICIPANT_PREPARE
- and self.config is not None):
- return 'PARTICIPANT_COMMITTABLE'
- return self.transaction_state.name
-
- def set_config(self, proto):
+ return WorkflowExternalState.RUNNING
+
+ if self.state == WorkflowState.STOPPED:
+ return WorkflowExternalState.STOPPED
+ if self.state == WorkflowState.COMPLETED:
+ return WorkflowExternalState.COMPLETED
+ if self.state == WorkflowState.FAILED:
+ return WorkflowExternalState.FAILED
+
+ if self.state == WorkflowState.INVALID:
+ return WorkflowExternalState.INVALID
+
+ return WorkflowExternalState.UNKNOWN
+
+ def set_config(self, proto: WorkflowDefinition):
if proto is not None:
self.config = proto.SerializeToString()
- job_defs = {i.name: i for i in proto.job_definitions}
- for job in self.owned_jobs:
- name = job.get_config().name
- assert name in job_defs, \
- f'Invalid workflow template: job {name} is missing'
- job.set_config(job_defs[name])
else:
self.config = None
- def get_config(self):
+ def get_config(self) -> Optional[WorkflowDefinition]:
if self.config is not None:
- proto = workflow_definition_pb2.WorkflowDefinition()
+ proto = WorkflowDefinition()
proto.ParseFromString(self.config)
return proto
return None
@@ -269,8 +314,9 @@ def get_job_ids(self):
return []
return [int(i) for i in self.job_ids.split(',')]
- def get_jobs(self):
- return [Job.query.get(i) for i in self.get_job_ids()]
+ def get_jobs(self, session) -> List[Job]:
+ job_ids = self.get_job_ids()
+ return session.query(Job).filter(Job.id.in_(job_ids)).all()
def set_create_job_flags(self, create_job_flags):
if not create_job_flags:
@@ -306,256 +352,76 @@ def get_peer_create_job_flags(self):
return None
return [int(i) for i in self.peer_create_job_flags.split(',')]
- def get_batch_update_interval(self):
- item = SchedulerItem.query.filter_by(
- name=f'workflow_cron_job_{self.id}').first()
- if not item:
- return -1
- return int(item.interval_time) / 60
+ def to_workflow_ref(self) -> WorkflowRef:
+ return WorkflowRef(id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ project_id=self.project_id,
+ state=self.get_state_for_frontend().name,
+ created_at=to_timestamp(self.created_at),
+ forkable=self.forkable,
+ metric_is_public=self.metric_is_public,
+ favour=self.favour)
+
+ def to_proto(self) -> WorkflowPb:
+ return WorkflowPb(id=self.id,
+ name=self.name,
+ uuid=self.uuid,
+ project_id=self.project_id,
+ state=self.get_state_for_frontend().name,
+ created_at=to_timestamp(self.created_at),
+ forkable=self.forkable,
+ metric_is_public=self.metric_is_public,
+ favour=self.favour,
+ template_revision_id=self.template_revision_id,
+ template_id=self.template_id,
+ config=self.get_config(),
+ editor_info=self.get_editor_info(),
+ comment=self.comment,
+ job_ids=self.get_job_ids(),
+ create_job_flags=self.get_create_job_flags(),
+ is_local=self.is_local(),
+ forked_from=self.forked_from,
+ peer_create_job_flags=self.get_peer_create_job_flags(),
+ start_at=self.start_at,
+ stop_at=self.stop_at,
+ updated_at=to_timestamp(self.updated_at),
+ cron_config=self.cron_config,
+ creator=self.creator,
+ template_info=self.get_template_info())
+
+ def is_local(self):
+ return is_local(self.get_config(), self.get_create_job_flags())
+
+ def get_template_info(self) -> WorkflowPb.TemplateInfo:
+ template_info = WorkflowPb.TemplateInfo(id=self.template_id, is_modified=True)
+ if self.template is not None:
+ template_info.name = self.template.name
+ template_info.is_modified = compare_yaml_templates_in_wf(self.get_config(), self.template.get_config())
+
+ if self.template_revision is not None:
+ template_info.is_modified = False
+ template_info.revision_index = self.template_revision.revision_index
+ return template_info
+
+ def get_editor_info(self):
+ proto = workflow_definition_pb2.WorkflowTemplateEditorInfo()
+ if self.editor_info is not None:
+ proto.ParseFromString(self.editor_info)
+ return proto
+
+ def is_invalid(self):
+ return self.state == WorkflowState.INVALID
+
+ def can_transit_to(self, target_state: WorkflowState):
+ return (self.state, target_state) in VALID_TRANSITIONS
def update_target_state(self, target_state):
- if self.target_state != target_state \
- and self.target_state != WorkflowState.INVALID:
- raise ValueError(f'Another transaction is in progress [{self.id}]')
- if target_state not in [
- WorkflowState.READY, WorkflowState.RUNNING,
- WorkflowState.STOPPED
- ]:
- raise ValueError(f'Invalid target_state {self.target_state}')
+ if self.target_state not in [target_state, WorkflowState.INVALID]:
+            raise ValueError(f'Another transaction is in progress [{self.id}]')
+ if target_state != WorkflowState.READY:
+            raise ValueError(f'Invalid target_state {target_state}')
if (self.state, target_state) not in VALID_TRANSITIONS:
- raise ValueError(
- f'Invalid transition from {self.state} to {target_state}')
+            raise ValueError(f'Invalid transition from {self.state} to {target_state}')
self.target_state = target_state
-
- def update_state(self, asserted_state, target_state, transaction_state):
- assert asserted_state is None or self.state == asserted_state, \
- 'Cannot change current state directly'
-
- if transaction_state != self.transaction_state:
- if (self.transaction_state, transaction_state) in \
- IGNORED_TRANSACTION_TRANSITIONS:
- return self.transaction_state
- assert (self.transaction_state, transaction_state) in \
- VALID_TRANSACTION_TRANSITIONS, \
- 'Invalid transaction transition from {} to {}'.format(
- self.transaction_state, transaction_state)
- self.transaction_state = transaction_state
-
- # coordinator prepare & rollback
- if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
- self.prepare(target_state)
- if self.transaction_state == TransactionState.COORDINATOR_ABORTING:
- self.rollback()
-
- # participant prepare & rollback & commit
- if self.transaction_state == TransactionState.PARTICIPANT_PREPARE:
- self.prepare(target_state)
- if self.transaction_state == TransactionState.PARTICIPANT_ABORTING:
- self.rollback()
- self.transaction_state = TransactionState.ABORTED
- if self.transaction_state == TransactionState.PARTICIPANT_COMMITTING:
- self.commit()
-
- return self.transaction_state
-
- def prepare(self, target_state):
- assert self.transaction_state in [
- TransactionState.COORDINATOR_PREPARE,
- TransactionState.PARTICIPANT_PREPARE], \
- 'Workflow not in prepare state'
-
- # TODO(tjulinfan): remove this
- if target_state is None:
- # No action
- return
-
- # Validation
- try:
- self.update_target_state(target_state)
- except ValueError as e:
- logging.warning('Error during update target state in prepare: %s',
- str(e))
- self.transaction_state = TransactionState.ABORTED
- return
-
- success = True
- if self.target_state == WorkflowState.READY:
- success = self._prepare_for_ready()
-
- if success:
- if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
- self.transaction_state = \
- TransactionState.COORDINATOR_COMMITTABLE
- else:
- self.transaction_state = \
- TransactionState.PARTICIPANT_COMMITTABLE
-
- def rollback(self):
- self.target_state = WorkflowState.INVALID
-
- def start(self):
- self.start_at = int(datetime.now().timestamp())
- for job in self.owned_jobs:
- if not job.is_disabled:
- job.schedule()
-
- def stop(self):
- self.stop_at = int(datetime.now().timestamp())
- for job in self.owned_jobs:
- job.stop()
-
- # TODO: separate this method to another module
- def commit(self):
- assert self.transaction_state in [
- TransactionState.COORDINATOR_COMMITTING,
- TransactionState.PARTICIPANT_COMMITTING], \
- 'Workflow not in prepare state'
-
- if self.target_state == WorkflowState.STOPPED:
- try:
- self.stop()
- except RuntimeError as e:
- # errors from k8s
- logging.error('Stop workflow %d has error msg: %s',
- self.id, e.args)
- return
- elif self.target_state == WorkflowState.READY:
- self._setup_jobs()
- self.fork_proposal_config = None
- elif self.target_state == WorkflowState.RUNNING:
- self.start()
-
- self.state = self.target_state
- self.target_state = WorkflowState.INVALID
- self.transaction_state = TransactionState.READY
-
- def invalidate(self):
- self.state = WorkflowState.INVALID
- self.target_state = WorkflowState.INVALID
- self.transaction_state = TransactionState.READY
- for job in self.owned_jobs:
- try:
- job.stop()
- except Exception as e: # pylint: disable=broad-except
- logging.warning(
- 'Error while stopping job %s during invalidation: %s',
- job.name, repr(e))
-
- def _setup_jobs(self):
- if self.forked_from is not None:
- trunk = Workflow.query.get(self.forked_from)
- assert trunk is not None, \
- 'Source workflow %d not found' % self.forked_from
- trunk_job_defs = trunk.get_config().job_definitions
- trunk_name2index = {
- job.name: i
- for i, job in enumerate(trunk_job_defs)
- }
-
- job_defs = self.get_config().job_definitions
- flags = self.get_create_job_flags()
- assert len(job_defs) == len(flags), \
- 'Number of job defs does not match number of create_job_flags ' \
- '%d vs %d'%(len(job_defs), len(flags))
- jobs = []
- for i, (job_def, flag) in enumerate(zip(job_defs, flags)):
- if flag == common_pb2.CreateJobFlag.REUSE:
- assert job_def.name in trunk_name2index, \
- f'Job {job_def.name} not found in base workflow'
- j = trunk.get_job_ids()[trunk_name2index[job_def.name]]
- job = Job.query.get(j)
- assert job is not None, \
- 'Job %d not found' % j
- # TODO: check forked jobs does not depend on non-forked jobs
- else:
- job = Job(
- name=f'{self.uuid}-{job_def.name}',
- job_type=JobType(job_def.job_type),
- config=job_def.SerializeToString(),
- workflow_id=self.id,
- project_id=self.project_id,
- state=JobState.NEW,
- is_disabled=(flag == common_pb2.CreateJobFlag.DISABLED))
- db.session.add(job)
- jobs.append(job)
- db.session.flush()
- name2index = {job.name: i for i, job in enumerate(job_defs)}
- for i, (job, flag) in enumerate(zip(jobs, flags)):
- if flag == common_pb2.CreateJobFlag.REUSE:
- continue
- for j, dep_def in enumerate(job.get_config().dependencies):
- dep = JobDependency(
- src_job_id=jobs[name2index[dep_def.source]].id,
- dst_job_id=job.id,
- dep_index=j)
- db.session.add(dep)
-
- self.set_job_ids([job.id for job in jobs])
- if Features.FEATURE_MODEL_WORKFLOW_HOOK:
- for job in jobs:
- ModelService(db.session).workflow_hook(job)
-
-
- def log_states(self):
- logging.debug(
- 'workflow %d updated to state=%s, target_state=%s, '
- 'transaction_state=%s', self.id, self.state.name,
- self.target_state.name, self.transaction_state.name)
-
- def _get_peer_workflow(self):
- project_config = self.project.get_config()
- # TODO: find coordinator for multiparty
- client = RpcClient(project_config, project_config.participants[0])
- return client.get_workflow(self.name)
-
- def _prepare_for_ready(self):
- # This is a hack, if config is not set then
- # no action needed
- if self.transaction_state == TransactionState.COORDINATOR_PREPARE:
- # TODO(tjulinfan): validate if the config is legal or not
- return bool(self.config)
-
- if self.forked_from:
- peer_workflow = self._get_peer_workflow()
- base_workflow = Workflow.query.get(self.forked_from)
- if base_workflow is None or not base_workflow.forkable:
- return False
- self.forked_from = base_workflow.id
- self.forkable = base_workflow.forkable
- self.set_create_job_flags(peer_workflow.peer_create_job_flags)
- self.set_peer_create_job_flags(peer_workflow.create_job_flags)
- config = base_workflow.get_config()
- _merge_workflow_config(config, peer_workflow.fork_proposal_config,
- [common_pb2.Variable.PEER_WRITABLE])
- self.set_config(config)
- return True
-
- return bool(self.config)
-
- def is_local(self):
- # since _setup_jobs has not been called, job_definitions is used
- job_defs = self.get_config().job_definitions
- flags = self.get_create_job_flags()
- for i, (job_def, flag) in enumerate(zip(job_defs, flags)):
- if flag != common_pb2.CreateJobFlag.REUSE and job_def.is_federated:
- return False
- return True
-
- def update_local_state(self):
- if self.target_state == WorkflowState.INVALID:
- return
- if self.target_state == WorkflowState.READY:
- self._setup_jobs()
- elif self.target_state == WorkflowState.RUNNING:
- self.start()
- elif self.target_state == WorkflowState.STOPPED:
- try:
- self.stop()
- except Exception as e:
- # errors from k8s
- logging.error('Stop workflow %d has error msg: %s',
- self.id, e.args)
- return
- self.state = self.target_state
- self.target_state = WorkflowState.INVALID
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/models_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/models_test.py
new file mode 100644
index 000000000..11c89cca0
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/models_test.py
@@ -0,0 +1,212 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from datetime import datetime, timezone
+
+from unittest.mock import patch, PropertyMock
+
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowTemplateEditorInfo
+from fedlearner_webconsole.proto.workflow_pb2 import WorkflowRef, WorkflowPb
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplateRevision, WorkflowTemplate
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.workflow.models import (Workflow, WorkflowState, TransactionState, WorkflowExternalState)
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.project.models import Project
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class WorkflowTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=0)
+
+ with db.session_scope() as session:
+ session.add(project)
+ session.commit()
+
+ def test_get_jobs(self):
+ workflow = Workflow(id=100, job_ids='1,2,3')
+ job1 = Job(id=1, name='job 1', workflow_id=3, project_id=0, job_type=JobType.RAW_DATA)
+ job2 = Job(id=2, name='job 2', workflow_id=3, project_id=0, job_type=JobType.RAW_DATA)
+ job3 = Job(id=3, name='job 3', workflow_id=100, project_id=0, job_type=JobType.RAW_DATA)
+ with db.session_scope() as session:
+ session.add_all([workflow, job1, job2, job3])
+ session.commit()
+ jobs = workflow.get_jobs(session)
+ jobs.sort(key=lambda job: job.name)
+ self.assertEqual(jobs[0].name, 'job 1')
+ self.assertEqual(jobs[1].name, 'job 2')
+ self.assertEqual(jobs[2].name, 'job 3')
+
+ def test_workflow_state(self):
+ with db.session_scope() as session:
+ completed_workflow = Workflow(state=WorkflowState.COMPLETED)
+ failed_workflow = Workflow(state=WorkflowState.FAILED)
+
+ stopped_workflow_1 = Workflow(state=WorkflowState.STOPPED, target_state=WorkflowState.INVALID)
+ stopped_workflow_2 = Workflow(state=WorkflowState.STOPPED)
+
+ running_workflow = Workflow(state=WorkflowState.RUNNING, target_state=WorkflowState.INVALID)
+
+ warmup_underhood_workflow_1 = Workflow(state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.PARTICIPANT_COMMITTABLE)
+ warmup_underhood_workflow_2 = Workflow(state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.COORDINATOR_COMMITTING)
+
+ pending_accept_workflow = Workflow(state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.PARTICIPANT_PREPARE)
+
+ ready_to_run_workflow = Workflow(state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY)
+
+ participant_configuring_workflow_1 = Workflow(state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.READY)
+ participant_configuring_workflow_2 = Workflow(state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.COORDINATOR_COMMITTABLE)
+ participant_configuring_workflow_3 = Workflow(state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.COORDINATOR_PREPARE)
+
+ invalid_workflow = Workflow(state=WorkflowState.INVALID)
+
+ unknown_workflow = Workflow(state=WorkflowState.NEW, target_state=WorkflowState.INVALID)
+ session.add_all([
+ completed_workflow, failed_workflow, stopped_workflow_1, stopped_workflow_2, running_workflow,
+ warmup_underhood_workflow_1, warmup_underhood_workflow_2, pending_accept_workflow,
+ ready_to_run_workflow, participant_configuring_workflow_1, participant_configuring_workflow_2,
+ participant_configuring_workflow_3, invalid_workflow, unknown_workflow
+ ])
+ session.commit()
+
+ completed_job_cw = Job(job_type=JobType.RAW_DATA,
+ workflow_id=completed_workflow.id,
+ project_id=0,
+ state=JobState.COMPLETED)
+ failed_job_fw = Job(job_type=JobType.RAW_DATA,
+ workflow_id=failed_workflow.id,
+ project_id=0,
+ state=JobState.FAILED)
+ running_job_rw = Job(job_type=JobType.RAW_DATA,
+ workflow_id=running_workflow.id,
+ project_id=0,
+ state=JobState.STARTED)
+ session.add_all([completed_job_cw, failed_job_fw, running_job_rw])
+ session.commit()
+
+ self.assertEqual(completed_workflow.get_state_for_frontend(), WorkflowExternalState.COMPLETED)
+ self.assertEqual(failed_workflow.get_state_for_frontend(), WorkflowExternalState.FAILED)
+
+ self.assertEqual(stopped_workflow_1.get_state_for_frontend(), WorkflowExternalState.STOPPED)
+ self.assertEqual(stopped_workflow_2.get_state_for_frontend(), WorkflowExternalState.STOPPED)
+
+ self.assertEqual(running_workflow.get_state_for_frontend(), WorkflowExternalState.RUNNING)
+
+ self.assertEqual(warmup_underhood_workflow_1.get_state_for_frontend(),
+ WorkflowExternalState.WARMUP_UNDERHOOD)
+ self.assertEqual(warmup_underhood_workflow_2.get_state_for_frontend(),
+ WorkflowExternalState.WARMUP_UNDERHOOD)
+
+ self.assertEqual(pending_accept_workflow.get_state_for_frontend(), WorkflowExternalState.PENDING_ACCEPT)
+ self.assertEqual(ready_to_run_workflow.get_state_for_frontend(), WorkflowExternalState.READY_TO_RUN)
+
+ self.assertEqual(participant_configuring_workflow_1.get_state_for_frontend(),
+ WorkflowExternalState.PARTICIPANT_CONFIGURING)
+ self.assertEqual(participant_configuring_workflow_2.get_state_for_frontend(),
+ WorkflowExternalState.PARTICIPANT_CONFIGURING)
+ self.assertEqual(participant_configuring_workflow_3.get_state_for_frontend(),
+ WorkflowExternalState.PARTICIPANT_CONFIGURING)
+
+ self.assertEqual(invalid_workflow.get_state_for_frontend(), WorkflowExternalState.INVALID)
+ self.assertEqual(unknown_workflow.get_state_for_frontend(), WorkflowExternalState.UNKNOWN)
+
+ def test_to_workflow_ref(self):
+ created_at = datetime(2021, 10, 1, 8, 8, 8, tzinfo=timezone.utc)
+ workflow = Workflow(
+ id=123,
+ name='test',
+ uuid='uuid',
+ project_id=1,
+ state=WorkflowState.STOPPED,
+ target_state=WorkflowState.INVALID,
+ created_at=created_at,
+ forkable=True,
+ metric_is_public=False,
+ extra='{}',
+ )
+ workflow_ref = WorkflowRef(
+ id=123,
+ name='test',
+ uuid='uuid',
+ project_id=1,
+ state=WorkflowExternalState.STOPPED.name,
+ created_at=int(created_at.timestamp()),
+ forkable=True,
+ metric_is_public=False,
+ )
+ self.assertEqual(workflow.to_workflow_ref(), workflow_ref)
+
+ def test_to_proto(self):
+ created_at = datetime(2021, 10, 1, 8, 8, 8, tzinfo=timezone.utc)
+ updated_at = datetime(2021, 10, 1, 8, 8, 8, tzinfo=timezone.utc)
+ workflow = Workflow(id=123,
+ name='test',
+ uuid='uuid',
+ project_id=1,
+ state=WorkflowState.STOPPED,
+ target_state=WorkflowState.INVALID,
+ created_at=created_at,
+ forkable=True,
+ metric_is_public=False,
+ extra='{}',
+ updated_at=updated_at)
+ workflow_pb = WorkflowPb(id=123,
+ name='test',
+ uuid='uuid',
+ project_id=1,
+ state=WorkflowExternalState.STOPPED.name,
+ created_at=int(created_at.timestamp()),
+ forkable=True,
+ metric_is_public=False,
+ updated_at=int(updated_at.timestamp()),
+ editor_info=WorkflowTemplateEditorInfo(),
+ template_info=WorkflowPb.TemplateInfo(is_modified=True))
+ self.assertEqual(workflow.to_proto(), workflow_pb)
+
+ @patch('fedlearner_webconsole.workflow.models.Workflow.template_revision', new_callable=PropertyMock)
+ @patch('fedlearner_webconsole.workflow.models.Workflow.template', new_callable=PropertyMock)
+ def test_get_template_info(self, mock_template, mock_template_revision):
+ workflow = Workflow(id=123,
+ name='test',
+ uuid='uuid',
+ project_id=1,
+ template_id=1,
+ template_revision_id=1,
+ config=b'')
+ mock_template.return_value = WorkflowTemplate(id=1, name='test', config=b'')
+ mock_template_revision.return_value = WorkflowTemplateRevision(id=1, revision_index=3, template_id=1)
+ self.assertEqual(workflow.get_template_info(),
+ WorkflowPb.TemplateInfo(name='test', id=1, is_modified=False, revision_index=3))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/resource_manager.py b/web_console_v2/api/fedlearner_webconsole/workflow/resource_manager.py
new file mode 100644
index 000000000..3b9be6080
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/resource_manager.py
@@ -0,0 +1,178 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import logging
+from sqlalchemy.orm import Session
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState, \
+ VALID_TRANSACTION_TRANSITIONS, \
+ IGNORED_TRANSACTION_TRANSITIONS, TransactionState
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.workflow.service import WorkflowService
+
+
+def _merge_variables(base, new, access_mode):
+ new_dict = {i.name: i for i in new}
+ for var in base:
+ if var.access_mode in access_mode and var.name in new_dict:
+ var.typed_value.CopyFrom(new_dict[var.name].typed_value)
+ # TODO(xiangyuxuan.prs): remove when value is deprecated in Variable
+ var.value = new_dict[var.name].value
+
+
+# TODO(hangweiqiang): move it to utils
+def merge_workflow_config(base, new, access_mode):
+ _merge_variables(base.variables, new.variables, access_mode)
+ if not new.job_definitions:
+ return
+ assert len(base.job_definitions) == len(new.job_definitions)
+ for base_job, new_job in \
+ zip(base.job_definitions, new.job_definitions):
+ _merge_variables(base_job.variables, new_job.variables, access_mode)
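+# A small sketch of the merge semantics (variable names are made up for illustration):
+# only variables whose access_mode appears in the given list take the peer's value.
+#
+#   base: contains Variable(name='batch_size', access_mode=common_pb2.Variable.PEER_WRITABLE, value='64')
+#   new:  peer config carrying Variable(name='batch_size', value='128')
+#   merge_workflow_config(base, new, [common_pb2.Variable.PEER_WRITABLE])
+#   # base's 'batch_size' now holds '128'; variables without PEER_WRITABLE access stay untouched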
+
+
+class ResourceManager:
+
+ def __init__(self, session: Session, workflow: Workflow):
+ self._session = session
+ self._workflow = workflow
+
+ def update_state(self, asserted_state, target_state, transaction_state):
+ if self._workflow.is_invalid():
+ return self._workflow.transaction_state
+
+ assert asserted_state is None or \
+ self._workflow.state == asserted_state, \
+ 'Cannot change current state directly'
+
+ if transaction_state != self._workflow.transaction_state:
+ if (self._workflow.transaction_state, transaction_state) in \
+ IGNORED_TRANSACTION_TRANSITIONS:
+ return self._workflow.transaction_state
+ assert (self._workflow.transaction_state, transaction_state) in \
+ VALID_TRANSACTION_TRANSITIONS, \
+ f'Invalid transaction transition from {self._workflow.transaction_state} to {transaction_state}'
+ self._workflow.transaction_state = transaction_state
+
+ # coordinator prepare & rollback
+ if self._workflow.transaction_state == \
+ TransactionState.COORDINATOR_PREPARE:
+ self.prepare(target_state)
+ if self._workflow.transaction_state == \
+ TransactionState.COORDINATOR_ABORTING:
+ self.rollback()
+
+ # participant prepare & rollback & commit
+ if self._workflow.transaction_state == \
+ TransactionState.PARTICIPANT_PREPARE:
+ self.prepare(target_state)
+ if self._workflow.transaction_state == \
+ TransactionState.PARTICIPANT_ABORTING:
+ self.rollback()
+ self._workflow.transaction_state = TransactionState.ABORTED
+ if self._workflow.transaction_state == \
+ TransactionState.PARTICIPANT_COMMITTING:
+ self.commit()
+
+ return self._workflow.transaction_state
+
+ def prepare(self, target_state):
+ assert self._workflow.transaction_state in [
+ TransactionState.COORDINATOR_PREPARE,
+ TransactionState.PARTICIPANT_PREPARE], \
+ 'Workflow not in prepare state'
+
+ # TODO(tjulinfan): remove this
+ if target_state is None:
+ # No action
+ return
+
+ # Validation
+ try:
+ self._workflow.update_target_state(target_state)
+ except ValueError as e:
+ logging.warning('Error during update target state in prepare: %s', str(e))
+ self._workflow.transaction_state = TransactionState.ABORTED
+ return
+
+ success = True
+ if self._workflow.target_state == WorkflowState.READY:
+ success = self._prepare_for_ready()
+
+ if success:
+ if self._workflow.transaction_state == \
+ TransactionState.COORDINATOR_PREPARE:
+ self._workflow.transaction_state = \
+ TransactionState.COORDINATOR_COMMITTABLE
+ else:
+ self._workflow.transaction_state = \
+ TransactionState.PARTICIPANT_COMMITTABLE
+
+ def rollback(self):
+ self._workflow.target_state = WorkflowState.INVALID
+
+ # TODO: separate this method to another module
+ def commit(self):
+ assert self._workflow.transaction_state in [
+ TransactionState.COORDINATOR_COMMITTING,
+ TransactionState.PARTICIPANT_COMMITTING], \
+            'Workflow not in committing state'
+
+ if self._workflow.target_state == WorkflowState.READY:
+ self._workflow.fork_proposal_config = None
+
+ self._workflow.state = self._workflow.target_state
+ self._workflow.target_state = WorkflowState.INVALID
+ self._workflow.transaction_state = TransactionState.READY
+
+ def _prepare_for_ready(self):
+        # This is a hack: if config is not set, then no action is needed
+ if self._workflow.transaction_state == \
+ TransactionState.COORDINATOR_PREPARE:
+ # TODO(tjulinfan): validate if the config is legal or not
+ return bool(self._workflow.config)
+
+ if self._workflow.forked_from:
+ peer_workflow = WorkflowService(self._session).get_peer_workflow(self._workflow)
+ base_workflow = self._session.query(Workflow).get(self._workflow.forked_from)
+ if base_workflow is None or not base_workflow.forkable:
+ return False
+ self._workflow.forked_from = base_workflow.id
+ self._workflow.forkable = base_workflow.forkable
+ self._workflow.set_create_job_flags(peer_workflow.peer_create_job_flags)
+ self._workflow.set_peer_create_job_flags(peer_workflow.create_job_flags)
+ config = base_workflow.get_config()
+ merge_workflow_config(config, peer_workflow.fork_proposal_config, [common_pb2.Variable.PEER_WRITABLE])
+ WorkflowService(self._session).update_config(self._workflow, config)
+ logging.error(base_workflow.to_dict())
+ self._workflow.template_id = base_workflow.template_id
+ self._workflow.editor_info = base_workflow.editor_info
+ # TODO: set forked workflow in grpc server
+ WorkflowService(self._session).setup_jobs(self._workflow)
+ return True
+
+ return bool(self._workflow.config)
+
+ def log_states(self):
+ workflow = self._workflow
+ logging.debug('workflow %d updated to state=%s, target_state=%s, '
+ 'transaction_state=%s', workflow.id, workflow.state.name, workflow.target_state.name,
+ workflow.transaction_state.name)
+
+ def update_local_state(self):
+ if self._workflow.target_state == WorkflowState.INVALID:
+ return
+ self._workflow.state = self._workflow.target_state
+ self._workflow.target_state = WorkflowState.INVALID
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/service.py b/web_console_v2/api/fedlearner_webconsole/workflow/service.py
new file mode 100644
index 000000000..e75f8ab4e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/service.py
@@ -0,0 +1,372 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import List, Tuple, Union, Optional
+
+from sqlalchemy import or_, and_
+from sqlalchemy.orm import Session, Query
+
+from fedlearner_webconsole.composer.interface import ItemType
+from fedlearner_webconsole.job.service import JobService
+from fedlearner_webconsole.proto.composer_pb2 import RunnerInput, WorkflowCronJobInput
+from fedlearner_webconsole.proto.filtering_pb2 import FilterExpression, FilterOp, SimpleExpression
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.utils.const import SYSTEM_WORKFLOW_CREATOR_USERNAME
+from fedlearner_webconsole.utils.filtering import SupportedField, FilterBuilder, FieldType
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.composer.composer_service import CronJobService
+from fedlearner_webconsole.exceptions import InvalidArgumentException, ResourceConflictException
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.job.yaml_formatter import YamlFormatterService
+from fedlearner_webconsole.utils.workflow import build_job_name
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState, TransactionState
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from fedlearner_webconsole.workflow.utils import is_local, is_peer_job_inheritance_matched
+from fedlearner_webconsole.job.models import Job, JobType, JobState, JobDependency
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.utils.resource_name import resource_uuid
+
+
+def update_cronjob_config(workflow_id: int, cron_config: str, session: Session):
+ """Starts a cronjob for workflow if cron_config is valid.
+
+ Args:
+ workflow_id (int): id of the workflow
+ cron_config (str): cron expression;
+ if cron_config is None or '', cancel previous cron setting
+ session: db session
+ Raises:
+        InvalidArgumentException: if any check is violated
+ """
+ item_name = f'workflow_cron_job_{workflow_id}'
+ if cron_config:
+ rinput = RunnerInput(workflow_cron_job_input=WorkflowCronJobInput(workflow_id=workflow_id))
+ items = [(ItemType.WORKFLOW_CRON_JOB, rinput)]
+ CronJobService(session).start_cronjob(item_name=item_name, items=items, cron_config=cron_config)
+ else:
+ CronJobService(session).stop_cronjob(item_name=item_name)
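+# A minimal usage sketch (assuming a session from db.session_scope(); the cron expression
+# is only an example of whatever format CronJobService accepts):
+#
+#   with db.session_scope() as session:
+#       update_cronjob_config(workflow_id=123, cron_config='0 0 * * *', session=session)  # schedule daily
+#       update_cronjob_config(workflow_id=123, cron_config='', session=session)           # cancel the cron
+#       session.commit()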
+
+
+class ForkWorkflowParams(object):
+
+ def __init__(self, fork_from_id: int, fork_proposal_config: WorkflowDefinition, peer_create_job_flags: List[int]):
+ self.fork_from_id = fork_from_id
+ self.fork_proposal_config = fork_proposal_config
+ self.peer_create_job_flags = peer_create_job_flags
+
+
+class CreateNewWorkflowParams(object):
+
+ def __init__(self, project_id: int, template_id: Optional[int], template_revision_id: Optional[int] = None):
+ self.project_id = project_id
+ self.template_id = template_id
+ self.template_revision_id = template_revision_id
+
+
+def _filter_system_workflow(exp: SimpleExpression):
+ if exp.bool_value:
+ return Workflow.creator == SYSTEM_WORKFLOW_CREATOR_USERNAME
+    # In MySQL, != NULL and == NULL always evaluate to NULL, so check IS NULL explicitly.
+ return or_(Workflow.creator != SYSTEM_WORKFLOW_CREATOR_USERNAME, Workflow.creator.is_(None))
+
+
+class WorkflowService:
+ FILTER_FIELDS = {'system': SupportedField(type=FieldType.BOOL, ops={FilterOp.EQUAL: _filter_system_workflow})}
+
+ def __init__(self, session):
+ self._session = session
+ self._filter_builder = FilterBuilder(model_class=Workflow, supported_fields=self.FILTER_FIELDS)
+
+ def build_filter_query(self, query: Query, exp: FilterExpression) -> Query:
+ return self._filter_builder.build_query(query, exp)
+
+ def validate_workflow(self, workflow: Workflow) -> Tuple[bool, tuple]:
+ for job in workflow.owned_jobs:
+ try:
+ YamlFormatterService(self._session).generate_job_run_yaml(job)
+ except Exception as e: # pylint: disable=broad-except
+ return False, (job.name, e)
+ return True, ()
+
+ @staticmethod
+ def filter_workflows(query: Query, states: List[str]) -> Query:
+ query_states = []
+ filters = []
+ for state in states:
+ query_states.append(state.upper())
+ # TODO(xiangyuxuan.prs): simplify Workflow create to remove the specific process for states below.
+            # The filtering logic below mirrors get_state_for_frontend.
+ if state == 'warmup':
+ filters.append(
+ and_(
+ Workflow.state == WorkflowState.NEW, Workflow.target_state == WorkflowState.READY,
+ Workflow.transaction_state.in_([
+ TransactionState.PARTICIPANT_COMMITTABLE, TransactionState.PARTICIPANT_COMMITTING,
+ TransactionState.COORDINATOR_COMMITTING
+ ])))
+ if state == 'pending':
+ filters.append(
+ and_(Workflow.state == WorkflowState.NEW, Workflow.target_state == WorkflowState.READY,
+ Workflow.transaction_state == TransactionState.PARTICIPANT_PREPARE))
+ if state == 'configuring':
+ filters.append(
+ and_(
+ Workflow.state == WorkflowState.NEW, Workflow.target_state == WorkflowState.READY,
+ Workflow.transaction_state.in_([
+ TransactionState.READY, TransactionState.COORDINATOR_COMMITTABLE,
+ TransactionState.COORDINATOR_PREPARE
+ ])))
+ filters.append(Workflow.state.in_(query_states))
+ query = query.filter(or_(*filters))
+ return query
+
+ def _check_conflict(self, workflow_name: str, project_id: int):
+ if self._session.query(Workflow).filter_by(name=workflow_name).filter_by(
+ project_id=project_id).first() is not None:
+ raise ResourceConflictException(f'Workflow {workflow_name} already exists in project: {project_id}.')
+
+ def create_workflow(self,
+ name: str,
+ config: WorkflowDefinition,
+ params: Union[CreateNewWorkflowParams, ForkWorkflowParams],
+ forkable: bool = False,
+ comment: Optional[str] = None,
+ create_job_flags: Optional[List[int]] = None,
+ cron_config: Optional[str] = None,
+ creator_username: str = None,
+ uuid: Optional[str] = None,
+ state: WorkflowState = WorkflowState.NEW,
+ target_state: WorkflowState = WorkflowState.READY):
+ # Parameter validations
+ parent_workflow = None
+ template = None
+ project_id = None
+ if isinstance(params, ForkWorkflowParams):
+ # Fork mode
+ parent_workflow = self._session.query(Workflow).get(params.fork_from_id)
+ if parent_workflow is None:
+ raise InvalidArgumentException('fork_from_id is not valid')
+ if not parent_workflow.forkable:
+ raise InvalidArgumentException('workflow not forkable')
+ project_id = parent_workflow.project_id
+ self._check_conflict(name, project_id)
+ # it is possible that parent_workflow.template is None
+ template = parent_workflow.template
+ if not is_local(config, create_job_flags):
+ participants = ParticipantService(self._session).get_platform_participants_by_project(
+ parent_workflow.project.id)
+ if not is_peer_job_inheritance_matched(project=parent_workflow.project,
+ workflow_definition=config,
+ job_flags=create_job_flags,
+ peer_job_flags=params.peer_create_job_flags,
+ parent_uuid=parent_workflow.uuid,
+ parent_name=parent_workflow.name,
+ participants=participants):
+                    raise ValueError('Forked workflow has a federated job with unmatched inheritance')
+ else:
+ # Create new mode
+ project_id = params.project_id
+ self._check_conflict(name, project_id)
+ if params.template_id:
+ template = self._session.query(WorkflowTemplate).get(params.template_id)
+ assert template is not None
+ assert project_id is not None
+ if uuid is None:
+ uuid = resource_uuid()
+ workflow = Workflow(name=name,
+ uuid=uuid,
+ comment=comment,
+ project_id=project_id,
+ forkable=forkable,
+ forked_from=None if parent_workflow is None else parent_workflow.id,
+ state=state,
+ target_state=target_state,
+ transaction_state=TransactionState.READY,
+ creator=creator_username,
+ template_revision_id=parent_workflow.template_revision_id
+ if parent_workflow else params.template_revision_id)
+ if template:
+ workflow.template_id = template.id
+ workflow.editor_info = template.editor_info
+ self.update_config(workflow, config)
+ workflow.set_create_job_flags(create_job_flags)
+ if isinstance(params, ForkWorkflowParams):
+ # Fork mode
+ # TODO(hangweiqiang): more validations
+ workflow.set_fork_proposal_config(params.fork_proposal_config)
+ workflow.set_peer_create_job_flags(params.peer_create_job_flags)
+ self._session.add(workflow)
+ # To get workflow id
+ self._session.flush()
+ if cron_config is not None:
+ workflow.cron_config = cron_config
+ update_cronjob_config(workflow.id, cron_config, self._session)
+ self.setup_jobs(workflow)
+ return workflow
+
+ def config_workflow(self,
+ workflow: Workflow,
+ template_id: int,
+ config: Optional[WorkflowDefinition] = None,
+ forkable: bool = False,
+ comment: Optional[str] = None,
+ cron_config: Optional[str] = None,
+ create_job_flags: Optional[List[int]] = None,
+ creator_username: Optional[str] = None,
+ template_revision_id: Optional[int] = None) -> Workflow:
+ if workflow.config:
+ raise ValueError('Resetting workflow is not allowed')
+ workflow.comment = comment
+ workflow.forkable = forkable
+ workflow.creator = creator_username
+ workflow.set_config(config)
+ workflow.set_create_job_flags(create_job_flags)
+ workflow.update_target_state(WorkflowState.READY)
+ workflow.template_id = template_id
+ workflow.template_revision_id = template_revision_id
+ self._session.flush()
+ if workflow.template is None:
+ emit_store('template_not_found', 1)
+ raise ValueError('template not found')
+ workflow.editor_info = workflow.template.editor_info
+ if cron_config is not None:
+ workflow.cron_config = cron_config
+ update_cronjob_config(workflow.id, cron_config, self._session)
+ self.setup_jobs(workflow)
+ return workflow
+
+ def patch_workflow(self,
+ workflow: Workflow,
+ forkable: Optional[bool] = None,
+ metric_is_public: Optional[bool] = None,
+ config: Optional[WorkflowDefinition] = None,
+ template_id: Optional[int] = None,
+ create_job_flags: List[int] = None,
+ cron_config: Optional[str] = None,
+ favour: Optional[bool] = None,
+ template_revision_id: Optional[int] = None):
+ if forkable is not None:
+ workflow.forkable = forkable
+ if metric_is_public is not None:
+ workflow.metric_is_public = metric_is_public
+
+ if config:
+ if workflow.target_state != WorkflowState.INVALID or \
+ workflow.state not in \
+ [WorkflowState.READY, WorkflowState.STOPPED, WorkflowState.COMPLETED,
+ WorkflowState.FAILED]:
+ raise ValueError('Cannot edit running workflow')
+ self.update_config(workflow, config)
+ workflow.template_id = template_id
+ self._session.flush()
+ if workflow.template is not None:
+ workflow.editor_info = workflow.template.editor_info
+ self._session.flush()
+
+ if create_job_flags:
+ jobs = [self._session.query(Job).get(i) for i in workflow.get_job_ids()]
+ if len(create_job_flags) != len(jobs):
+ raise ValueError(f'Number of job defs does not match number of '
+ f'create_job_flags {len(jobs)} vs {len(create_job_flags)}')
+ workflow.set_create_job_flags(create_job_flags)
+ flags = workflow.get_create_job_flags()
+ for i, job in enumerate(jobs):
+ if job.workflow_id == workflow.id:
+ job.is_disabled = flags[i] == \
+ common_pb2.CreateJobFlag.DISABLED
+
+ # Start the workflow periodically.
+ # Session.commit happens inside, so this part must be the last part of the
+ # API to guarantee atomicity.
+ if cron_config is not None:
+ workflow.cron_config = cron_config
+ update_cronjob_config(workflow.id, cron_config, self._session)
+
+ if favour is not None:
+ workflow.favour = favour
+
+ if template_revision_id is not None:
+ workflow.template_revision_id = template_revision_id
+
+ def setup_jobs(self, workflow: Workflow):
+ if workflow.forked_from is not None:
+ trunk = self._session.query(Workflow).get(workflow.forked_from)
+ assert trunk is not None, \
+ f'Source workflow {workflow.forked_from} not found'
+ trunk_job_defs = trunk.get_config().job_definitions
+ trunk_name2index = {job.name: i for i, job in enumerate(trunk_job_defs)}
+
+ job_defs = workflow.get_config().job_definitions
+ flags = workflow.get_create_job_flags()
+ assert len(job_defs) == len(flags), \
+ f'Number of job defs does not match number of create_job_flags {len(job_defs)} vs {len(flags)}'
+ jobs = []
+ for i, (job_def, flag) in enumerate(zip(job_defs, flags)):
+ if flag == common_pb2.CreateJobFlag.REUSE:
+ assert job_def.name in trunk_name2index, \
+ f'Job {job_def.name} not found in base workflow'
+ j = trunk.get_job_ids()[trunk_name2index[job_def.name]]
+ job = self._session.query(Job).get(j)
+ assert job is not None, f'Job {j} not found'
+ # TODO: check that forked jobs do not depend on non-forked jobs
+ else:
+ job = Job(name=build_job_name(workflow.uuid, job_def.name),
+ job_type=JobType(job_def.job_type),
+ workflow_id=workflow.id,
+ project_id=workflow.project_id,
+ state=JobState.NEW,
+ is_disabled=(flag == common_pb2.CreateJobFlag.DISABLED))
+ self._session.add(job)
+ self._session.flush()
+ JobService(self._session).set_config_and_crd_info(job, job_def)
+ jobs.append(job)
+ self._session.refresh(workflow)
+ name2index = {job.name: i for i, job in enumerate(job_defs)}
+ for i, (job, flag) in enumerate(zip(jobs, flags)):
+ if flag == common_pb2.CreateJobFlag.REUSE:
+ continue
+ for j, dep_def in enumerate(job.get_config().dependencies):
+ dep = JobDependency(src_job_id=jobs[name2index[dep_def.source]].id, dst_job_id=job.id, dep_index=j)
+ self._session.add(dep)
+
+ workflow.set_job_ids([job.id for job in jobs])
+
+ def get_peer_workflow(self, workflow: Workflow):
+ service = ParticipantService(self._session)
+ participants = service.get_platform_participants_by_project(workflow.project.id)
+ # TODO: find coordinator for multiparty
+ client = RpcClient.from_project_and_participant(workflow.project.name, workflow.project.token,
+ participants[0].domain_name)
+ return client.get_workflow(workflow.uuid, workflow.name)
+
+ def is_federated_workflow_finished(self, workflow: Workflow):
+ if not workflow.is_finished():
+ return False
+ return workflow.is_local() or self.get_peer_workflow(workflow).is_finished
+
+ def should_auto_stop(self, workflow: Workflow):
+ return workflow.is_failed() or self.is_federated_workflow_finished(workflow)
+
+ def update_config(self, workflow: Workflow, proto: WorkflowDefinition):
+ workflow.set_config(proto)
+ if proto is not None:
+ job_defs = {i.name: i for i in proto.job_definitions}
+ for job in workflow.owned_jobs:
+ name = job.get_config().name
+ assert name in job_defs, \
+ f'Invalid workflow template: job {name} is missing'
+ JobService(self._session).set_config_and_crd_info(job, job_defs[name])
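A minimal sketch of how a caller might drive `WorkflowService.create_workflow` in the create-new mode (fork mode differs only in passing a `ForkWorkflowParams`); the workflow name, the session handling and the state values here are illustrative assumptions, not part of the patch:

```python
from fedlearner_webconsole.db import db
from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
from fedlearner_webconsole.workflow.models import WorkflowState
from fedlearner_webconsole.workflow.service import CreateNewWorkflowParams, WorkflowService


def create_example_workflow(config: WorkflowDefinition, project_id: int, uuid: str) -> int:
    """Illustrative only: creates a workflow in the configuring (NEW -> READY) state."""
    with db.session_scope() as session:
        workflow = WorkflowService(session).create_workflow(
            name='example-workflow',  # _check_conflict rejects duplicate names in the project
            config=config,
            params=CreateNewWorkflowParams(project_id=project_id, template_id=None),
            uuid=uuid,
            state=WorkflowState.NEW,
            target_state=WorkflowState.READY,
        )
        # create_workflow only flushes the session; the caller is responsible for committing.
        session.commit()
        return workflow.id
```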
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/service_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/service_test.py
new file mode 100644
index 000000000..a2dce2033
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/service_test.py
@@ -0,0 +1,123 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import patch
+
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.workflow.service import WorkflowService
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition
+from fedlearner_webconsole.job.models import Job, JobType, JobState
+from fedlearner_webconsole.workflow.models import (Workflow, WorkflowState, TransactionState)
+from fedlearner_webconsole.workflow.service import update_cronjob_config
+from fedlearner_webconsole.composer.models import SchedulerItem, ItemStatus
+
+
+class WorkflowServiceTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=0)
+
+ with db.session_scope() as session:
+ session.add(project)
+ session.commit()
+
+ @patch('fedlearner_webconsole.workflow.service.YamlFormatterService.generate_job_run_yaml')
+ def test_valid_workflow(self, mock_generate_job_run_yaml):
+ with db.session_scope() as session:
+ workflow = Workflow(id=0, project_id=99)
+ session.add(workflow)
+ session.flush()
+ job = Job(id=0,
+ name='test-job-0',
+ job_type=JobType.RAW_DATA,
+ workflow_id=0,
+ project_id=99,
+ config=JobDefinition(name='test-job').SerializeToString())
+ session.add(job)
+ session.flush()
+ sample_json = {'apiVersion': 'v1', 'kind': 'FLApp', 'metadata': {}}
+ mock_generate_job_run_yaml.return_value = sample_json
+
+ workflow_valid = WorkflowService(session).validate_workflow(workflow)
+ self.assertTrue(workflow_valid[0])
+ mock_generate_job_run_yaml.side_effect = ValueError
+ workflow_valid = WorkflowService(session).validate_workflow(workflow)
+ self.assertFalse(workflow_valid[0])
+
+ def test_filter_workflow_state(self):
+ with db.session_scope() as session:
+ configuring_workflow = Workflow(id=1,
+ state=WorkflowState.NEW,
+ target_state=WorkflowState.READY,
+ transaction_state=TransactionState.READY)
+ ready_workflow = Workflow(id=2,
+ state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID,
+ transaction_state=TransactionState.READY)
+ completed_workflow = Workflow(id=3, state=WorkflowState.COMPLETED)
+ failed_workflow = Workflow(id=4, state=WorkflowState.FAILED)
+ running_workflow = Workflow(id=5, state=WorkflowState.RUNNING)
+ session.add_all(
+ [configuring_workflow, ready_workflow, running_workflow, completed_workflow, failed_workflow])
+ session.flush()
+ completed_job = Job(id=1, job_type=JobType.RAW_DATA, workflow_id=3, project_id=99, state=JobState.COMPLETED)
+ failed_job = Job(id=2, job_type=JobType.RAW_DATA, workflow_id=4, project_id=99, state=JobState.FAILED)
+ running_job = Job(id=3, job_type=JobType.RAW_DATA, workflow_id=5, project_id=99, state=JobState.STARTED)
+ session.add_all([completed_job, failed_job, running_job])
+ session.flush()
+ all_workflows = session.query(Workflow)
+ self.assertEqual(
+ WorkflowService.filter_workflows(all_workflows, ['configuring', 'ready']).all(),
+ [configuring_workflow, ready_workflow])
+ self.assertEqual(WorkflowService.filter_workflows(all_workflows, ['failed']).all(), [failed_workflow])
+ self.assertEqual(WorkflowService.filter_workflows(all_workflows, ['completed']).all(), [completed_workflow])
+ self.assertEqual(
+ WorkflowService.filter_workflows(all_workflows, ['running', 'completed']).all(),
+ [completed_workflow, running_workflow])
+
+ def _get_scheduler_item(self, session) -> SchedulerItem:
+ item: SchedulerItem = session.query(SchedulerItem).filter_by(name='workflow_cron_job_1').first()
+ return item
+
+ def test_update_cronjob_config(self):
+ with db.session_scope() as session:
+ # test for collect
+ update_cronjob_config(1, '1 2 3 4 5', session)
+ session.commit()
+ item = self._get_scheduler_item(session)
+ self.assertEqual(item.cron_config, '1 2 3 4 5')
+ self.assertEqual(item.status, ItemStatus.ON.value)
+ item.status = ItemStatus.OFF.value
+ session.commit()
+ with db.session_scope() as session:
+ update_cronjob_config(1, '1 2 3 4 6', session)
+ session.commit()
+ item = self._get_scheduler_item(session)
+ self.assertEqual(item.status, ItemStatus.ON.value)
+ self.assertEqual(item.cron_config, '1 2 3 4 6')
+ with db.session_scope() as session:
+ update_cronjob_config(1, None, session)
+ session.commit()
+ item = self._get_scheduler_item(session)
+ self.assertEqual(item.status, ItemStatus.OFF.value)
+ self.assertEqual(item.cron_config, '1 2 3 4 6')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/utils.py b/web_console_v2/api/fedlearner_webconsole/workflow/utils.py
new file mode 100644
index 000000000..bad3153a6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/utils.py
@@ -0,0 +1,65 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+from typing import List, Optional
+
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.exceptions import InternalException
+from fedlearner_webconsole.utils.metrics import emit_store
+from fedlearner_webconsole.proto import common_pb2
+from fedlearner_webconsole.proto.common_pb2 import CreateJobFlag
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.workflow_definition_pb2 import \
+ WorkflowDefinition
+
+
+def is_local(config: WorkflowDefinition, job_flags: Optional[List[int]] = None) -> bool:
+ # if config is None, the workflow must have been created by the opposite side
+ if config is None:
+ return False
+ # since _setup_jobs has not been called, job_definitions is used
+ job_defs = config.job_definitions
+ if job_flags is None:
+ num_jobs = len(job_defs)
+ job_flags = [common_pb2.CreateJobFlag.NEW] * num_jobs
+ for i, (job_def, flag) in enumerate(zip(job_defs, job_flags)):
+ if flag != CreateJobFlag.REUSE and job_def.is_federated:
+ return False
+ return True
+
+
+def is_peer_job_inheritance_matched(project: Project, workflow_definition: WorkflowDefinition, job_flags: List[int],
+ peer_job_flags: List[int], parent_uuid: str, parent_name: str,
+ participants: List) -> bool:
+ """Checks if the job inheritance is matched with peer workflow.
+
+ We should make sure the federated jobs should have the same job flag
+ (inherit from parent or not)."""
+ # TODO(hangweiqiang): Fix for multi-peer
+ client = RpcClient.from_project_and_participant(project.name, project.token, participants[0].domain_name)
+ # Gets peer parent workflow
+ resp = client.get_workflow(parent_uuid, parent_name)
+ if resp.status.code != common_pb2.STATUS_SUCCESS:
+ emit_store('get_peer_workflow_failed', 1)
+ raise InternalException(resp.status.msg)
+ job_defs = workflow_definition.job_definitions
+ peer_job_defs = resp.config.job_definitions
+ for i, job_def in enumerate(job_defs):
+ if job_def.is_federated:
+ for j, peer_job_def in enumerate(peer_job_defs):
+ if job_def.name == peer_job_def.name:
+ if job_flags[i] != peer_job_flags[j]:
+ return False
+ return True
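To make the semantics of `is_local` concrete, a small sketch mirroring the unit test below: a workflow counts as local only if every federated job is either absent or reused from the parent, so none of its own jobs needs a peer.

```python
from fedlearner_webconsole.proto.common_pb2 import CreateJobFlag
from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition, WorkflowDefinition
from fedlearner_webconsole.workflow.utils import is_local

config = WorkflowDefinition(job_definitions=[
    JobDefinition(name='raw-data', is_federated=False),
    JobDefinition(name='train-job', is_federated=True),
])
# Without flags every job defaults to NEW, so the federated job makes it non-local.
assert not is_local(config)
# Reusing the federated job from the parent leaves only local work.
assert is_local(config, [CreateJobFlag.NEW, CreateJobFlag.REUSE])
```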
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/utils_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/utils_test.py
new file mode 100644
index 000000000..7790cecb4
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/utils_test.py
@@ -0,0 +1,111 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+import unittest
+from unittest.mock import patch, MagicMock
+from google.protobuf.json_format import ParseDict
+from testing.no_web_server_test_case import NoWebServerTestCase
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.proto.common_pb2 import CreateJobFlag
+from fedlearner_webconsole.proto.service_pb2 import GetWorkflowResponse
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition, \
+ WorkflowDefinition
+from fedlearner_webconsole.workflow.models import Workflow
+from fedlearner_webconsole.workflow.utils import \
+ is_peer_job_inheritance_matched, is_local
+
+
+class UtilsTest(NoWebServerTestCase):
+
+ def setUp(self):
+ super().setUp()
+ project = Project(id=0)
+
+ with db.session_scope() as session:
+ session.add(project)
+ session.commit()
+
+ def test_is_local(self):
+ config = {
+ 'job_definitions': [
+ {
+ 'name': 'raw-data',
+ 'is_federated': False
+ },
+ {
+ 'name': 'raw-data',
+ 'is_federated': True
+ },
+ ]
+ }
+ config = ParseDict(config, WorkflowDefinition())
+ self.assertFalse(is_local(config))
+ job_flags = [CreateJobFlag.NEW, CreateJobFlag.NEW]
+ self.assertFalse(is_local(config, job_flags))
+ job_flags = [CreateJobFlag.NEW, CreateJobFlag.REUSE]
+ self.assertTrue(is_local(config, job_flags))
+ job_flags = [CreateJobFlag.REUSE, CreateJobFlag.REUSE]
+ self.assertTrue(is_local(config, job_flags))
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient' '.from_project_and_participant')
+ def test_is_peer_job_inheritance_matched(self, mock_rpc_client_factory):
+ # Mock RPC
+ peer_job_0 = JobDefinition(name='raw-data-job')
+ peer_job_1 = JobDefinition(name='train-job', is_federated=True)
+ peer_config = WorkflowDefinition(job_definitions=[peer_job_0, peer_job_1])
+ resp = GetWorkflowResponse(config=peer_config)
+ mock_rpc_client = MagicMock()
+ mock_rpc_client.get_workflow = MagicMock(return_value=resp)
+ mock_rpc_client_factory.return_value = mock_rpc_client
+
+ job_0 = JobDefinition(name='train-job', is_federated=True)
+ workflow_definition = WorkflowDefinition(job_definitions=[job_0])
+
+ participant = Participant(domain_name='fl-test.com')
+
+ project = Project(name='test-project', token='test-token')
+ parent_workflow = Workflow(project=project, uuid='workflow-uuid-0000', name='workflow-0')
+ self.assertTrue(
+ is_peer_job_inheritance_matched(project=project,
+ workflow_definition=workflow_definition,
+ job_flags=[CreateJobFlag.REUSE],
+ peer_job_flags=[CreateJobFlag.NEW, CreateJobFlag.REUSE],
+ parent_uuid=parent_workflow.uuid,
+ parent_name=parent_workflow.name,
+ participants=[participant]))
+ mock_rpc_client.get_workflow.assert_called_once_with(parent_workflow.uuid, parent_workflow.name)
+ mock_rpc_client_factory.assert_called_once()
+ args, kwargs = mock_rpc_client_factory.call_args_list[0]
+ # Comparing call args one by one because a message list
+ # cannot be compared directly
+ self.assertEqual(len(args), 3)
+ self.assertEqual(args[0], 'test-project')
+ self.assertEqual(args[1], 'test-token')
+ self.assertEqual(args[2], 'fl-test.com')
+
+ self.assertFalse(
+ is_peer_job_inheritance_matched(project=project,
+ workflow_definition=workflow_definition,
+ job_flags=[CreateJobFlag.NEW],
+ peer_job_flags=[CreateJobFlag.NEW, CreateJobFlag.REUSE],
+ parent_uuid=parent_workflow.uuid,
+ parent_name=parent_workflow.name,
+ participants=[participant]))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/workflow_controller.py b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_controller.py
new file mode 100644
index 000000000..371c614b6
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_controller.py
@@ -0,0 +1,135 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import Optional
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.auth.services import UserService
+from fedlearner_webconsole.job.controller import stop_job, schedule_job, \
+ start_job_if_ready
+from fedlearner_webconsole.notification.email import send_email
+from fedlearner_webconsole.notification.template import NotificationTemplateName
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
+from fedlearner_webconsole.utils import pp_datetime
+from fedlearner_webconsole.utils.const import SYSTEM_WORKFLOW_CREATOR_USERNAME
+from fedlearner_webconsole.utils.flask_utils import get_link
+from fedlearner_webconsole.workflow.models import WorkflowState, Workflow, TransactionState
+from fedlearner_webconsole.workflow.service import WorkflowService, CreateNewWorkflowParams, update_cronjob_config
+
+
+# TODO(xiangyuxuan.prs): uses system workflow template revision instead of template directly
+def create_ready_workflow(
+ session: Session,
+ name: str,
+ config: WorkflowDefinition,
+ project_id: int,
+ uuid: str,
+ template_id: Optional[int] = None,
+ comment: Optional[str] = None,
+) -> Workflow:
+ """Creates a workflow in ready(configured) state, this is for our internal usage, such as dataset module.
+
+ Args:
+ session: DB session of the transaction.
+ name: Workflow name.
+ config: Workflow configurations.
+ project_id: Which project this workflow belongs to.
+ uuid: Globally unique id of this workflow; in practice we use it for pairing.
+ template_id: Which template this workflow will use.
+ comment: Optional comment of the workflow.
+ """
+ return WorkflowService(session).create_workflow(
+ name=name,
+ config=config,
+ params=CreateNewWorkflowParams(project_id=project_id, template_id=template_id),
+ uuid=uuid,
+ comment=comment,
+ creator_username=SYSTEM_WORKFLOW_CREATOR_USERNAME,
+ state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID,
+ )
+
+
+def start_workflow_locally(session: Session, workflow: Workflow):
+ """Starts the workflow locally, it does not affect other participants."""
+ if not workflow.can_transit_to(WorkflowState.RUNNING):
+ raise RuntimeError(f'invalid workflow state {workflow.state} when try to start')
+ is_valid, info = WorkflowService(session).validate_workflow(workflow)
+ if not is_valid:
+ job_name, validate_e = info
+ raise ValueError(f'Invalid Variable when try to format the job {job_name}: {str(validate_e)}')
+
+ workflow.start_at = int(pp_datetime.now().timestamp())
+ workflow.state = WorkflowState.RUNNING
+ # Schedules all jobs to make them executable
+ for job in workflow.owned_jobs:
+ schedule_job(session, job)
+ # A workaround to speed up the workflow execution: manually trigger the start of jobs
+ for job in workflow.owned_jobs:
+ start_job_if_ready(session, job)
+
+
+def _notify_if_finished(session: Session, workflow: Workflow):
+ if workflow.state not in [WorkflowState.FAILED, WorkflowState.STOPPED, WorkflowState.COMPLETED]:
+ return
+ creator = UserService(session).get_user_by_username(workflow.creator)
+ email_address = None
+ if creator:
+ email_address = creator.email
+ send_email(email_address,
+ NotificationTemplateName.WORKFLOW_COMPLETE,
+ name=workflow.name,
+ state=workflow.state.name,
+ link=get_link(f'/v2/workflow-center/workflows/{workflow.id}'))
+
+
+def stop_workflow_locally(session: Session, workflow: Workflow):
+ """Stops the workflow locally, it does not affect other participants."""
+ if not workflow.can_transit_to(WorkflowState.STOPPED):
+ raise RuntimeError(f'invalid workflow state {workflow.state} when try to stop')
+
+ workflow.stop_at = int(pp_datetime.now().timestamp())
+ if workflow.is_failed():
+ workflow.state = WorkflowState.FAILED
+ elif workflow.is_finished():
+ workflow.state = WorkflowState.COMPLETED
+ else:
+ workflow.state = WorkflowState.STOPPED
+ try:
+ for job in workflow.owned_jobs:
+ stop_job(session, job)
+ _notify_if_finished(session, workflow)
+ except RuntimeError as e:
+ logging.error(f'Failed to stop workflow {workflow.id}: {str(e)}')
+ raise
+
+
+def invalidate_workflow_locally(session: Session, workflow: Workflow):
+ """Invalidates workflow locally and stops related jobs."""
+ logging.info(f'Invalidating workflow {workflow.id}')
+ # Stops the related cron jobs
+ update_cronjob_config(workflow.id, None, session)
+ # Marks the workflow's state
+ workflow.state = WorkflowState.INVALID
+ workflow.target_state = WorkflowState.INVALID
+ workflow.transaction_state = TransactionState.READY
+ # Stops owned jobs
+ for job in workflow.owned_jobs:
+ try:
+ stop_job(session, job)
+ except Exception as e: # pylint: disable=broad-except
+ logging.warning('Error while stopping job %s during invalidation: %s', job.name, repr(e))
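Taken together, a sketch of how these helpers are meant to be combined by internal callers; the workflow name and uuid below are placeholders, and the two-session split is only one reasonable way to structure the calls:

```python
from fedlearner_webconsole.db import db
from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition
from fedlearner_webconsole.workflow.models import Workflow
from fedlearner_webconsole.workflow.workflow_controller import (create_ready_workflow, start_workflow_locally)


def run_internal_workflow(config: WorkflowDefinition, project_id: int) -> int:
    # Creates the workflow directly in the READY state (no configuring round-trip).
    with db.session_scope() as session:
        workflow = create_ready_workflow(
            session,
            name='internal-workflow',
            config=config,
            project_id=project_id,
            uuid='placeholder-uuid',  # callers normally generate this, e.g. via resource_uuid()
        )
        session.commit()
        workflow_id = workflow.id
    # Starts it locally; raises ValueError if a job variable fails validation.
    with db.session_scope() as session:
        workflow = session.query(Workflow).get(workflow_id)
        start_workflow_locally(session, workflow)
        session.commit()
    return workflow_id
```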
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/workflow_controller_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_controller_test.py
new file mode 100644
index 000000000..a3e1c6807
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_controller_test.py
@@ -0,0 +1,444 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime, timezone
+from unittest.mock import patch, Mock
+
+from fedlearner_webconsole.auth.models import User
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.models import JobType, Job, JobState, JobDependency
+from fedlearner_webconsole.notification.template import NotificationTemplateName
+from fedlearner_webconsole.proto.common_pb2 import Variable
+from fedlearner_webconsole.proto.workflow_definition_pb2 import WorkflowDefinition, JobDefinition
+from fedlearner_webconsole.utils.const import SYSTEM_WORKFLOW_CREATOR_USERNAME
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.workflow.models import WorkflowState, Workflow
+from fedlearner_webconsole.workflow.workflow_controller import create_ready_workflow, start_workflow_locally, \
+ stop_workflow_locally, \
+ invalidate_workflow_locally, _notify_if_finished
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class CreateReadyWorkflowTest(NoWebServerTestCase):
+
+ def test_create_ready_workflow_with_template(self):
+ with db.session_scope() as session:
+ workflow_template = WorkflowTemplate(
+ id=123,
+ name='t123',
+ group_alias='test group',
+ )
+ workflow_template.set_config(
+ WorkflowDefinition(
+ group_alias='test group',
+ variables=[
+ Variable(name='var'),
+ ],
+ ))
+ session.add(workflow_template)
+ session.commit()
+
+ # Changes one variable
+ config = WorkflowDefinition()
+ config.CopyFrom(workflow_template.get_config())
+ config.variables[0].value = 'new_value'
+ workflow = create_ready_workflow(
+ session,
+ name='workflow1',
+ config=config,
+ project_id=2333,
+ uuid='uuid',
+ template_id=workflow_template.id,
+ )
+ session.commit()
+ workflow_id = workflow.id
+ with db.session_scope() as session:
+ workflow: Workflow = session.query(Workflow).get(workflow_id)
+ self.assertPartiallyEqual(
+ to_dict(workflow.to_proto()),
+ {
+ 'id': workflow_id,
+ 'name': 'workflow1',
+ 'comment': '',
+ 'uuid': 'uuid',
+ 'project_id': 2333,
+ 'creator': SYSTEM_WORKFLOW_CREATOR_USERNAME,
+ 'state': 'READY_TO_RUN',
+ 'forkable': False,
+ 'forked_from': 0,
+ 'template_id': 123,
+ 'template_revision_id': 0,
+ 'template_info': {
+ 'id': 123,
+ 'is_modified': False,
+ 'name': 't123',
+ 'revision_index': 0,
+ },
+ 'config': {
+ 'group_alias':
+ 'test group',
+ 'job_definitions': [],
+ 'variables': [{
+ 'access_mode': 'UNSPECIFIED',
+ 'name': 'var',
+ 'tag': '',
+ 'value': 'new_value',
+ 'value_type': 'STRING',
+ 'widget_schema': '',
+ }],
+ },
+ 'is_local': True,
+ 'favour': False,
+ 'job_ids': [],
+ 'jobs': [],
+ 'metric_is_public': False,
+ 'create_job_flags': [],
+ 'peer_create_job_flags': [],
+ 'cron_config': '',
+ 'editor_info': {
+ 'yaml_editor_infos': {},
+ },
+ },
+ ignore_fields=['created_at', 'updated_at', 'start_at', 'stop_at'],
+ )
+
+ def test_create_ready_workflow_without_template(self):
+ with db.session_scope() as session:
+ # Changes one variable
+ config = WorkflowDefinition(
+ group_alias='test group',
+ variables=[
+ Variable(name='var', value='cofff'),
+ ],
+ )
+ workflow = create_ready_workflow(
+ session,
+ name='workflow222',
+ config=config,
+ project_id=23334,
+ uuid='uuid222',
+ )
+ session.commit()
+ workflow_id = workflow.id
+ with db.session_scope() as session:
+ workflow: Workflow = session.query(Workflow).get(workflow_id)
+ self.maxDiff = None
+ self.assertPartiallyEqual(
+ to_dict(workflow.to_proto()),
+ {
+ 'id': workflow_id,
+ 'name': 'workflow222',
+ 'comment': '',
+ 'uuid': 'uuid222',
+ 'project_id': 23334,
+ 'creator': SYSTEM_WORKFLOW_CREATOR_USERNAME,
+ 'state': 'READY_TO_RUN',
+ 'forkable': False,
+ 'forked_from': 0,
+ 'template_id': 0,
+ 'template_revision_id': 0,
+ 'template_info': {
+ 'id': 0,
+ 'is_modified': True,
+ 'name': '',
+ 'revision_index': 0,
+ },
+ 'config': {
+ 'group_alias':
+ 'test group',
+ 'job_definitions': [],
+ 'variables': [{
+ 'access_mode': 'UNSPECIFIED',
+ 'name': 'var',
+ 'tag': '',
+ 'value': 'cofff',
+ 'value_type': 'STRING',
+ 'widget_schema': '',
+ }],
+ },
+ 'is_local': True,
+ 'favour': False,
+ 'job_ids': [],
+ 'jobs': [],
+ 'metric_is_public': False,
+ 'create_job_flags': [],
+ 'peer_create_job_flags': [],
+ 'cron_config': '',
+ 'editor_info': {
+ 'yaml_editor_infos': {},
+ },
+ },
+ ignore_fields=['created_at', 'updated_at', 'start_at', 'stop_at'],
+ )
+
+
+class StartWorkflowLocallyTest(NoWebServerTestCase):
+
+ def test_start_workflow_locally_invalid_template(self):
+ running_workflow = Workflow(id=1, state=WorkflowState.RUNNING)
+ with db.session_scope() as session:
+ start_workflow_locally(session, running_workflow)
+
+ @patch('fedlearner_webconsole.workflow.workflow_controller.WorkflowService.validate_workflow')
+ def test_start_workflow_locally_invalid_state(self, mock_validate_workflow: Mock):
+ mock_validate_workflow.return_value = False, ('test_job', 'fake error')
+ workflow = Workflow(id=1, state=WorkflowState.READY)
+ with db.session_scope() as session:
+ self.assertRaisesRegex(ValueError, 'Invalid Variable when try to format the job test_job: fake error',
+ lambda: start_workflow_locally(session, workflow))
+
+ @patch('fedlearner_webconsole.workflow.workflow_controller.pp_datetime.now')
+ @patch('fedlearner_webconsole.job.controller.YamlFormatterService.generate_job_run_yaml')
+ @patch('fedlearner_webconsole.workflow.workflow_controller.WorkflowService.validate_workflow')
+ def test_start_workflow_locally_successfully(self, mock_validate_workflow: Mock, mock_gen_yaml: Mock,
+ mock_now: Mock):
+ mock_validate_workflow.return_value = True, None
+ now_dt = datetime(2021, 9, 1, 10, 20, tzinfo=timezone.utc)
+ mock_now.return_value = now_dt
+
+ workflow_id = 123
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, state=WorkflowState.READY)
+ job1 = Job(id=1,
+ name='test job 1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=workflow_id,
+ project_id=1)
+ job2 = Job(id=2,
+ name='test job 2',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=workflow_id,
+ project_id=1)
+ job_def = JobDependency(src_job_id=2, dst_job_id=1, dep_index=0)
+ config = JobDefinition(is_federated=False)
+ job1.set_config(config)
+ job2.set_config(config)
+ session.add_all([workflow, job1, job2, job_def])
+ session.commit()
+ mock_gen_yaml.return_value = {}
+ start_workflow_locally(session, workflow)
+ session.commit()
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertEqual(workflow.start_at, now_dt.timestamp())
+ self.assertEqual(workflow.state, WorkflowState.RUNNING)
+ job1 = session.query(Job).get(1)
+ job2 = session.query(Job).get(2)
+ self.assertEqual(job1.state, JobState.WAITING)
+ self.assertEqual(job2.state, JobState.STARTED)
+ mock_gen_yaml.assert_called_once()
+
+
+class StopWorkflowLocallyTest(NoWebServerTestCase):
+
+ def test_stop_workflow_locally_invalid_state(self):
+ with db.session_scope() as session:
+ new_workflow = Workflow(id=1, state=WorkflowState.NEW)
+ self.assertRaisesRegex(RuntimeError, 'invalid workflow state WorkflowState.NEW when try to stop',
+ lambda: stop_workflow_locally(session, new_workflow))
+
+ @patch('fedlearner_webconsole.workflow.workflow_controller.pp_datetime.now')
+ @patch('fedlearner_webconsole.workflow.workflow_controller.stop_job')
+ def test_stop_workflow_locally_successfully(self, mock_stop_job: Mock, mock_now: Mock):
+ now_dt = datetime(2021, 9, 1, 10, 20, tzinfo=timezone.utc)
+ mock_now.return_value = now_dt
+
+ workflow_id = 123
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, state=WorkflowState.RUNNING)
+ job1 = Job(id=1,
+ name='test job 1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=workflow_id,
+ project_id=1)
+ job2 = Job(id=2,
+ name='test job 2',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=workflow_id,
+ project_id=1)
+ session.add_all([workflow, job1, job2])
+ session.commit()
+ stop_workflow_locally(session, workflow)
+ session.commit()
+ # Stopped 2 jobs
+ self.assertEqual(mock_stop_job.call_count, 2)
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertEqual(workflow.stop_at, now_dt.timestamp())
+ self.assertEqual(workflow.state, WorkflowState.STOPPED)
+
+ def test_stop_ready_workflow(self):
+ with db.session_scope() as session:
+ ready_workflow = Workflow(id=1, state=WorkflowState.READY)
+ job1 = Job(id=1,
+ name='test job 1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=1,
+ project_id=1)
+ job2 = Job(id=2,
+ name='test job 2',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=1,
+ project_id=1)
+ session.add_all([ready_workflow, job1, job2])
+ session.commit()
+ stop_workflow_locally(session, ready_workflow)
+ self.assertEqual(ready_workflow.state, WorkflowState.STOPPED)
+
+ @patch('fedlearner_webconsole.workflow.workflow_controller.pp_datetime.now')
+ @patch('fedlearner_webconsole.workflow.workflow_controller.stop_job')
+ def test_stop_workflow_to_completed(self, mock_stop_job: Mock, mock_now: Mock):
+ now_dt = datetime(2021, 9, 1, 10, 20, tzinfo=timezone.utc)
+ mock_now.return_value = now_dt
+ workflow_id = 123
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, state=WorkflowState.RUNNING)
+ job1 = Job(id=1,
+ name='test job 1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.COMPLETED,
+ workflow_id=workflow_id,
+ project_id=1)
+ job2 = Job(id=2,
+ name='test job 2',
+ job_type=JobType.RAW_DATA,
+ state=JobState.COMPLETED,
+ workflow_id=workflow_id,
+ project_id=1)
+ session.add_all([workflow, job1, job2])
+ session.commit()
+ stop_workflow_locally(session, workflow)
+ session.commit()
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertEqual(workflow.state, WorkflowState.COMPLETED)
+
+ @patch('fedlearner_webconsole.workflow.workflow_controller.stop_job')
+ def test_stop_workflow_to_failed(self, mock_stop_job: Mock):
+ workflow_id = 123
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, state=WorkflowState.RUNNING)
+ job1 = Job(id=1,
+ name='test job 1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.COMPLETED,
+ workflow_id=workflow_id,
+ project_id=1)
+ job2 = Job(id=2,
+ name='test job 2',
+ job_type=JobType.RAW_DATA,
+ state=JobState.FAILED,
+ workflow_id=workflow_id,
+ project_id=1)
+ session.add_all([workflow, job1, job2])
+ session.commit()
+ stop_workflow_locally(session, workflow)
+ session.commit()
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertEqual(workflow.state, WorkflowState.FAILED)
+
+ @patch('fedlearner_webconsole.workflow.workflow_controller.stop_job')
+ def test_stop_workflow_locally_failed(self, mock_stop_job: Mock):
+ mock_stop_job.side_effect = RuntimeError('fake error')
+
+ workflow_id = 123
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, state=WorkflowState.RUNNING)
+ job1 = Job(id=1,
+ name='test job 1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=workflow_id,
+ project_id=1)
+ session.add_all([workflow, job1])
+ session.commit()
+ # Simulates the normal flow, in which a session commit follows the call
+ with self.assertRaises(RuntimeError):
+ stop_workflow_locally(session, workflow)
+ session.commit()
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertIsNone(workflow.stop_at)
+ self.assertEqual(workflow.state, WorkflowState.RUNNING)
+
+
+class InvalidateWorkflowLocallyTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.workflow.workflow_controller.stop_job')
+ @patch('fedlearner_webconsole.workflow.workflow_controller.update_cronjob_config')
+ def test_invalidate_workflow_locally(self, mock_update_cronjob_config: Mock, mock_stop_job: Mock):
+ workflow_id = 6
+ project_id = 99
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, project_id=project_id, state=WorkflowState.RUNNING)
+ job1 = Job(id=1,
+ name='test job 1',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=workflow_id,
+ project_id=project_id)
+ job2 = Job(id=2,
+ name='test job 2',
+ job_type=JobType.RAW_DATA,
+ state=JobState.NEW,
+ workflow_id=workflow_id,
+ project_id=project_id)
+ session.add_all([workflow, job1, job2])
+ session.commit()
+ invalidate_workflow_locally(session, workflow)
+ session.commit()
+ mock_update_cronjob_config.assert_called_with(workflow_id, None, session)
+ self.assertEqual(mock_stop_job.call_count, 2)
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertEqual(workflow.state, WorkflowState.INVALID)
+ self.assertEqual(workflow.target_state, WorkflowState.INVALID)
+
+
+@patch('fedlearner_webconsole.workflow.workflow_controller.send_email')
+class NotifyIfFinishedTest(NoWebServerTestCase):
+
+ def test_running_workflow(self, mock_send_email: Mock):
+ with db.session_scope() as session:
+ workflow = Workflow(state=WorkflowState.RUNNING)
+ _notify_if_finished(session, workflow)
+ mock_send_email.assert_not_called()
+
+ def test_notify(self, mock_send_email: Mock):
+ with db.session_scope() as session:
+ user = User(username='test_user', email='a@b.com')
+ session.add(user)
+ session.commit()
+ with db.session_scope() as session:
+ workflow = Workflow(id=1234, name='test-workflow', state=WorkflowState.FAILED, creator='test_user')
+ _notify_if_finished(session, workflow)
+ mock_send_email.assert_called_with('a@b.com',
+ NotificationTemplateName.WORKFLOW_COMPLETE,
+ name='test-workflow',
+ state='FAILED',
+ link='http://localhost:666/v2/workflow-center/workflows/1234')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/workflow_job_controller.py b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_job_controller.py
new file mode 100644
index 000000000..177cbff65
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_job_controller.py
@@ -0,0 +1,115 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from typing import List
+
+from sqlalchemy.orm import Session
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.rpc.client import RpcClient
+from fedlearner_webconsole.two_pc.transaction_manager import TransactionManager
+from fedlearner_webconsole.proto.two_pc_pb2 import (TwoPcType, TransactionData, TransitWorkflowStateData)
+from fedlearner_webconsole.exceptions import InternalException, InvalidArgumentException
+from fedlearner_webconsole.participant.services import ParticipantService
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.workflow.workflow_controller import start_workflow_locally, stop_workflow_locally, \
+ invalidate_workflow_locally
+
+
+def _change_workflow_status(target_state: WorkflowState, project: Project, uuid: str, participants_domain_name: List):
+ assert target_state in (WorkflowState.RUNNING, WorkflowState.STOPPED)
+ tm = TransactionManager(project_name=project.name,
+ project_token=project.token,
+ participants=participants_domain_name,
+ two_pc_type=TwoPcType.CONTROL_WORKFLOW_STATE)
+ data = TransitWorkflowStateData(target_state=target_state.name, workflow_uuid=uuid)
+ succeeded, message = tm.run(data=TransactionData(transit_workflow_state_data=data))
+ if not succeeded:
+ raise InternalException(f'error when converting workflow state by 2PC: {message}')
+
+
+def _start_2pc_workflow(project: Project, uuid: str, participants_domain_name: List):
+ return _change_workflow_status(WorkflowState.RUNNING, project, uuid, participants_domain_name)
+
+
+def _stop_2pc_workflow(project: Project, uuid: str, participants_domain_name: List):
+ return _change_workflow_status(WorkflowState.STOPPED, project, uuid, participants_domain_name)
+
+
+# start workflow main entry
+def start_workflow(workflow_id: int):
+ # TODO(liuhehan): add uuid as entrypoint
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).filter_by(id=workflow_id).first()
+ if workflow.is_local():
+ # local entry
+ try:
+ # TODO(linfan.fine): gets rid of the session, the controller should be in a separate session
+ start_workflow_locally(session, workflow)
+ session.commit()
+ return
+ except RuntimeError as e:
+ raise InternalException(e) from e
+ except ValueError as e:
+ raise InvalidArgumentException(str(e)) from e
+ participants = ParticipantService(session).get_platform_participants_by_project(workflow.project.id)
+ project = session.query(Project).filter_by(id=workflow.project.id).first()
+ participants_domain_name = [participant.domain_name for participant in participants]
+ # new version fed entry
+ _start_2pc_workflow(project, workflow.uuid, participants_domain_name)
+
+
+# stop workflow main entry
+def stop_workflow(workflow_id: int):
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ if workflow.is_local():
+ # local entry
+ try:
+ # TODO(linfan.fine): gets rid of the session, the controller should be in a separate session
+ stop_workflow_locally(session, workflow)
+ session.commit()
+ return
+ except RuntimeError as e:
+ raise InternalException(e) from e
+ participants = ParticipantService(session).get_platform_participants_by_project(workflow.project.id)
+ project = session.query(Project).filter_by(id=workflow.project.id).first()
+ participants_domain_name = [participant.domain_name for participant in participants]
+ # new version fed entry
+ _stop_2pc_workflow(project, workflow.uuid, participants_domain_name)
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+
+
+def invalidate_workflow_job(session: Session, workflow: Workflow):
+ """Invalidates workflow job across all participants."""
+ invalidate_workflow_locally(session, workflow)
+ if workflow.is_local():
+ # No actions needed
+ return
+
+ service = ParticipantService(session)
+ participants = service.get_platform_participants_by_project(workflow.project.id)
+ # Invalidates peer's workflow
+ for participant in participants:
+ client = RpcClient.from_project_and_participant(workflow.project.name, workflow.project.token,
+ participant.domain_name)
+ resp = client.invalidate_workflow(workflow.uuid)
+ if not resp.succeeded:
+ # Ignores these errors as they will be handled by the participants' workflow schedulers
+ logging.warning(
+ f'failed to invalidate peer workflow, workflow id: {workflow.id}, participant name: {participant.name}')
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/workflow_job_controller_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_job_controller_test.py
new file mode 100644
index 000000000..affaf02bd
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_job_controller_test.py
@@ -0,0 +1,73 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch, Mock
+
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.participant.models import ProjectParticipant, Participant
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.workflow_definition_pb2 import JobDefinition, WorkflowDefinition
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.workflow.workflow_job_controller import invalidate_workflow_job
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class InvalidateWorkflowJobTest(NoWebServerTestCase):
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.invalidate_workflow')
+ def test_invalidate_workflow_job_local(self, mock_invalidate_workflow: Mock):
+ workflow_id = 777
+ with db.session_scope() as session:
+ workflow = Workflow(id=workflow_id, project_id=1, state=WorkflowState.RUNNING)
+ workflow.set_config(
+ WorkflowDefinition(job_definitions=[JobDefinition(name='raw-data', is_federated=False)]))
+ session.add(workflow)
+ session.commit()
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ invalidate_workflow_job(session, workflow)
+ session.commit()
+ mock_invalidate_workflow.assert_not_called()
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertTrue(workflow.is_invalid())
+
+ @patch('fedlearner_webconsole.rpc.client.RpcClient.invalidate_workflow')
+ def test_invalidate_workflow_job_across_participants(self, mock_invalidate_workflow: Mock):
+ workflow_id = 6789
+ with db.session_scope() as session:
+ project = Project(id=1)
+ participant = Participant(id=123, name='testp', domain_name='fl-test.com')
+ project_participant = ProjectParticipant(project_id=project.id, participant_id=participant.id)
+ session.add_all([project, participant, project_participant])
+
+ workflow = Workflow(id=workflow_id, project_id=1, state=WorkflowState.RUNNING, uuid='test_uuid')
+ workflow.set_config(
+ WorkflowDefinition(job_definitions=[JobDefinition(name='data-join', is_federated=True)]))
+ session.add(workflow)
+ session.commit()
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ invalidate_workflow_job(session, workflow)
+ session.commit()
+ mock_invalidate_workflow.assert_called_once_with('test_uuid')
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(workflow_id)
+ self.assertTrue(workflow.is_invalid())
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/workflow_scheduler.py b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_scheduler.py
new file mode 100644
index 000000000..cdc0b94e3
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_scheduler.py
@@ -0,0 +1,95 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import traceback
+from typing import Tuple
+
+from sqlalchemy.orm import load_only, joinedload
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.job.models import Job
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, WorkflowSchedulerOutput
+from fedlearner_webconsole.utils import const
+from fedlearner_webconsole.workflow.service import WorkflowService
+from fedlearner_webconsole.workflow.workflow_job_controller import start_workflow, stop_workflow
+from fedlearner_webconsole.composer.models import RunnerStatus
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from fedlearner_webconsole.composer.interface import IRunnerV2
+
+
+class ScheduleWorkflowRunner(IRunnerV2):
+
+ def auto_run_workflows(self) -> WorkflowSchedulerOutput:
+ with db.session_scope() as session:
+ # Workflows (with system preset template) whose state is ready will auto run.
+ query = session.query(Workflow.id).join(Workflow.template).filter(
+ Workflow.state == WorkflowState.READY, WorkflowTemplate.name.in_(const.SYS_PRESET_TEMPLATE))
+ workflow_ids = [result[0] for result in query.all()]
+ output = WorkflowSchedulerOutput()
+ for workflow_id in workflow_ids:
+ execution = output.executions.add()
+ execution.id = workflow_id
+ try:
+ logging.info(f'[WorkflowScheduler] auto start workflow {workflow_id}')
+ start_workflow(workflow_id)
+ except Exception as e: # pylint: disable=broad-except
+ error = str(e)
+ logging.warning(f'[WorkflowScheduler] auto start workflow {workflow_id} with error {error}')
+ execution.error_message = error
+ return output
+
+ def auto_stop_workflows(self) -> WorkflowSchedulerOutput:
+ with db.session_scope() as session:
+ # only query fields necessary for is_finished and is_failed.
+ q = session.query(Workflow).options(
+ load_only(Workflow.id, Workflow.name, Workflow.target_state, Workflow.state),
+ joinedload(Workflow.owned_jobs).load_only(Job.state,
+ Job.is_disabled)).filter_by(state=WorkflowState.RUNNING)
+ workflow_ids = [w.id for w in q.all() if WorkflowService(session).should_auto_stop(w)]
+ output = WorkflowSchedulerOutput()
+ for workflow_id in workflow_ids:
+ execution = output.executions.add()
+ execution.id = workflow_id
+ try:
+ stop_workflow(workflow_id)
+ except Exception as e: # pylint: disable=broad-except
+ error = f'Error while auto-stop workflow {workflow_id}:\n{traceback.format_exc()}'
+ logging.warning(error)
+ execution.error_message = error
+ return output
+
+ def run(self, context: RunnerContext) -> Tuple[RunnerStatus, RunnerOutput]:
+ output = RunnerOutput()
+ try:
+ output.workflow_scheduler_output.MergeFrom(self.auto_stop_workflows())
+ except Exception as e: # pylint: disable=broad-except
+ error_message = str(e)
+ output.error_message = error_message
+ logging.warning(f'[SchedulerWorkflowRunner] auto stop workflow with error {error_message}')
+
+ try:
+ # TODO(xiangyuxuan.prs): remove in the future when the model module no longer needs to configure workflows.
+ output.workflow_scheduler_output.MergeFrom(self.auto_run_workflows())
+ except Exception as e: # pylint: disable=broad-except
+ error_message = str(e)
+ output.error_message = f'{output.error_message} {error_message}'
+ logging.warning(f'[SchedulerWorkflowRunner] auto run workflow with error {error_message}')
+
+ if output.error_message:
+ return RunnerStatus.FAILED, output
+ return RunnerStatus.DONE, output
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow/workflow_scheduler_test.py b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_scheduler_test.py
new file mode 100644
index 000000000..80e56cf3b
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow/workflow_scheduler_test.py
@@ -0,0 +1,149 @@
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest.mock import patch
+
+from fedlearner_webconsole.composer.context import RunnerContext
+from fedlearner_webconsole.job.models import Job, JobState, JobType
+from fedlearner_webconsole.project.models import Project
+from fedlearner_webconsole.proto.composer_pb2 import RunnerOutput, WorkflowSchedulerOutput, RunnerInput
+from fedlearner_webconsole.workflow.workflow_scheduler import ScheduleWorkflowRunner
+from fedlearner_webconsole.db import db
+from fedlearner_webconsole.initial_db import _insert_schedule_workflow_item
+from fedlearner_webconsole.composer.models import SchedulerItem, RunnerStatus, ItemStatus
+from fedlearner_webconsole.workflow.models import Workflow, WorkflowState
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate
+from testing.no_web_server_test_case import NoWebServerTestCase
+
+
+class WorkflowSchedulerTest(NoWebServerTestCase):
+
+ def test_get_workflows_need_auto_run(self):
+ with db.session_scope() as session:
+ template_1 = WorkflowTemplate(name='local-test', group_alias='local-test', config=b'')
+ template_2 = WorkflowTemplate(name='sys-preset-nn-model', group_alias='nn', config=b'')
+ session.add_all([template_1, template_2])
+ session.flush()
+ workflow_1 = Workflow(name='w1')
+ workflow_2 = Workflow(name='w2',
+ template_id=template_1.id,
+ state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID)
+ workflow_3 = Workflow(name='w3', template_id=template_2.id)
+ workflow_5 = Workflow(name='w5',
+ template_id=template_2.id,
+ state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID)
+ workflow_6 = Workflow(name='w6',
+ template_id=template_2.id,
+ state=WorkflowState.READY,
+ target_state=WorkflowState.INVALID)
+ session.add_all([workflow_1, workflow_2, workflow_3, workflow_5, workflow_6])
+ session.commit()
+
+ def fake_start_workflow(workflow_id):
+ if workflow_id == workflow_6.id:
+ raise RuntimeError('error workflow_6')
+
+ with patch('fedlearner_webconsole.workflow.workflow_scheduler.start_workflow') as mock_start_workflow:
+ mock_start_workflow.side_effect = fake_start_workflow
+ # workflow_5 and workflow_6 will be auto-run
+ runner = ScheduleWorkflowRunner()
+ status, output = runner.run(RunnerContext(0, RunnerInput()))
+ self.assertEqual(status, RunnerStatus.DONE)
+ self.assertEqual(
+ output,
+ RunnerOutput(workflow_scheduler_output=WorkflowSchedulerOutput(executions=[
+ WorkflowSchedulerOutput.WorkflowExecution(id=workflow_5.id),
+ WorkflowSchedulerOutput.WorkflowExecution(id=workflow_6.id, error_message='error workflow_6'),
+ ])))
+
+ def test_insert_schedule_workflow_item(self):
+ with db.session_scope() as session:
+ item = SchedulerItem(name='workflow_scheduler', cron_config='* * * * * */30', status=ItemStatus.ON.value)
+ session.add(item)
+ session.commit()
+ _insert_schedule_workflow_item(session)
+ session.commit()
+ with db.session_scope() as session:
+ old_item = session.query(SchedulerItem).filter_by(name='workflow_scheduler').first()
+ self.assertEqual(old_item.status, ItemStatus.OFF.value)
+ new_item = session.query(SchedulerItem).filter_by(name='workflow_scheduler_v2').first()
+ self.assertEqual(new_item.status, ItemStatus.ON.value)
+
+ @patch('fedlearner_webconsole.workflow.workflow_scheduler.Workflow.is_local')
+ def test_auto_stop(self, mock_is_local):
+ mock_is_local.return_value = True
+ with db.session_scope() as session:
+ session.add(Project(name='test'))
+ session.add(
+ Job(name='testtes', state=JobState.COMPLETED, job_type=JobType.DATA_JOIN, workflow_id=30, project_id=1))
+
+ session.add(
+ Job(name='testtest', state=JobState.COMPLETED, job_type=JobType.DATA_JOIN, workflow_id=30,
+ project_id=1))
+ session.add(Workflow(name='test_complete', id=30, project_id=1, state=WorkflowState.RUNNING))
+ session.commit()
+ output = ScheduleWorkflowRunner().auto_stop_workflows()
+ self.assertEqual(output, WorkflowSchedulerOutput(executions=[WorkflowSchedulerOutput.WorkflowExecution(id=30)]))
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(30)
+ self.assertEqual(workflow.state, WorkflowState.COMPLETED)
+ with db.session_scope() as session:
+ session.add(
+ Job(name='testtes_failed',
+ state=JobState.FAILED,
+ job_type=JobType.DATA_JOIN,
+ workflow_id=31,
+ project_id=1))
+ session.add(Workflow(name='test_failed', id=31, project_id=1, state=WorkflowState.RUNNING))
+ session.commit()
+ output = ScheduleWorkflowRunner().auto_stop_workflows()
+ self.assertEqual(output, WorkflowSchedulerOutput(executions=[WorkflowSchedulerOutput.WorkflowExecution(id=31)]))
+ with db.session_scope() as session:
+ workflow = session.query(Workflow).get(31)
+ self.assertEqual(workflow.state, WorkflowState.FAILED)
+
+ @patch('fedlearner_webconsole.workflow.workflow_scheduler.ScheduleWorkflowRunner.auto_run_workflows')
+ @patch('fedlearner_webconsole.workflow.workflow_scheduler.ScheduleWorkflowRunner.auto_stop_workflows')
+ def test_run_workflow_scheduler(self, mock_auto_stop, mock_auto_run):
+ # test all succeeded
+ mock_auto_run.return_value = WorkflowSchedulerOutput(
+ executions=[WorkflowSchedulerOutput.WorkflowExecution(id=1)])
+ mock_auto_stop.return_value = WorkflowSchedulerOutput(
+ executions=[WorkflowSchedulerOutput.WorkflowExecution(id=2)])
+ expected_result = (RunnerStatus.DONE,
+ RunnerOutput(workflow_scheduler_output=WorkflowSchedulerOutput(executions=[
+ WorkflowSchedulerOutput.WorkflowExecution(id=2),
+ WorkflowSchedulerOutput.WorkflowExecution(id=1)
+ ])))
+ self.assertEqual(ScheduleWorkflowRunner().run(RunnerContext(1, RunnerInput())), expected_result)
+
+        # test auto stop failed
+ mock_auto_stop.side_effect = Exception('Test')
+ expected_result = (RunnerStatus.FAILED,
+ RunnerOutput(error_message='Test',
+ workflow_scheduler_output=WorkflowSchedulerOutput(
+ executions=[WorkflowSchedulerOutput.WorkflowExecution(id=1)])))
+ self.assertEqual(ScheduleWorkflowRunner().run(RunnerContext(1, RunnerInput())), expected_result)
+ # test all failed
+ mock_auto_run.side_effect = Exception('Test')
+ expected_result = (RunnerStatus.FAILED, RunnerOutput(error_message='Test Test'))
+ self.assertEqual(ScheduleWorkflowRunner().run(RunnerContext(1, RunnerInput())), expected_result)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow_template/BUILD.bazel b/web_console_v2/api/fedlearner_webconsole/workflow_template/BUILD.bazel
new file mode 100644
index 000000000..e7a98a55e
--- /dev/null
+++ b/web_console_v2/api/fedlearner_webconsole/workflow_template/BUILD.bazel
@@ -0,0 +1,196 @@
+load("@rules_python//python:defs.bzl", "py_library")
+
+package(default_visibility = ["//web_console_v2/api:console_api_package"])
+
+py_library(
+ name = "models_lib",
+ srcs = ["models.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:mixins_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "models_lib_test",
+ size = "small",
+ srcs = [
+ "models_test.py",
+ ],
+ imports = ["../.."],
+ main = "models_test.py",
+ deps = [
+ ":models_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_datetime_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ ],
+)
+
+py_library(
+ name = "service_lib",
+ srcs = ["service.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":slots_formatter_lib",
+ ":template_validator_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:filtering_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "service_lib_test",
+ size = "small",
+ srcs = [
+ "service_test.py",
+ ],
+ imports = ["../.."],
+ main = "service_test.py",
+ deps = [
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "slots_formatter_lib",
+ srcs = ["slots_formatter.py"],
+ imports = ["../.."],
+ deps = [
+ ":template_validator_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_flatten_dict_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ ],
+)
+
+py_test(
+ name = "slots_formater_lib_test",
+ size = "small",
+ srcs = [
+ "slots_formater_test.py",
+ ],
+ imports = ["../.."],
+ main = "slots_formater_test.py",
+ deps = [
+ ":slots_formatter_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "template_validator_lib",
+ srcs = ["template_validaor.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/fedlearner_webconsole/job:yaml_formatter_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_flatten_dict_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:pp_yaml_lib",
+ ],
+)
+
+py_test(
+ name = "template_validator_lib_test",
+ size = "small",
+ srcs = [
+ "template_validator_test.py",
+ ],
+ imports = ["../.."],
+ main = "template_validator_test.py",
+ deps = [
+ ":template_validator_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:no_web_server_test_case_lib",
+ "//web_console_v2/api/testing/workflow_template",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "utils_lib",
+ srcs = ["utils.py"],
+ imports = ["../.."],
+ deps = [
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_test(
+ name = "utils_lib_test",
+ size = "small",
+ srcs = [
+ "utils_test.py",
+ ],
+ imports = ["../.."],
+ main = "utils_test.py",
+ deps = [
+ ":utils_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
+
+py_library(
+ name = "apis_lib",
+ srcs = ["apis.py"],
+ imports = ["../.."],
+ deps = [
+ ":models_lib",
+ ":service_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole:exceptions_lib",
+ "//web_console_v2/api/fedlearner_webconsole/audit:decorators_lib",
+ "//web_console_v2/api/fedlearner_webconsole/auth:third_party_sso_lib",
+ "//web_console_v2/api/fedlearner_webconsole/rpc/v2:project_service_client_lib",
+ "//web_console_v2/api/fedlearner_webconsole/swagger:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:flask_utils_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:paginate_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils:proto_lib",
+ "//web_console_v2/api/fedlearner_webconsole/utils/decorators:decorators_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "@common_flask_restful//:pkg",
+ "@common_marshmallow//:pkg",
+ "@common_sqlalchemy//:pkg",
+ ],
+)
+
+py_test(
+ name = "apis_lib_test",
+ size = "medium",
+ srcs = [
+ "apis_test.py",
+ ],
+ imports = ["../.."],
+ main = "apis_test.py",
+ deps = [
+ ":apis_lib",
+ "//web_console_v2/api/fedlearner_webconsole:db_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:models_lib",
+ "//web_console_v2/api/fedlearner_webconsole/workflow_template:service_lib",
+ "//web_console_v2/api/protocols/fedlearner_webconsole/proto:py_proto",
+ "//web_console_v2/api/testing:common_lib",
+ "@com_google_protobuf//:protobuf_python",
+ ],
+)
diff --git a/web_console_v2/api/fedlearner_webconsole/workflow_template/apis.py b/web_console_v2/api/fedlearner_webconsole/workflow_template/apis.py
index 791f2ba89..36f9cdc6e 100644
--- a/web_console_v2/api/fedlearner_webconsole/workflow_template/apis.py
+++ b/web_console_v2/api/fedlearner_webconsole/workflow_template/apis.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The FedLearner Authors. All Rights Reserved.
+# Copyright 2023 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,281 +13,504 @@
# limitations under the License.
# coding: utf-8
-import io
-import json
-import re
from http import HTTPStatus
-import logging
-import tarfile
-from flask import send_file
-from flask_restful import Resource, reqparse, request
-from google.protobuf.json_format import ParseDict, ParseError
-
-from fedlearner_webconsole.utils.decorators import jwt_required
-from fedlearner_webconsole.workflow_template.models import WorkflowTemplate, \
+import grpc
+from flask_restful import Resource
+from sqlalchemy.orm import undefer
+from marshmallow import fields, Schema, post_load
+from fedlearner_webconsole.audit.decorators import emits_event
+from fedlearner_webconsole.participant.models import Participant
+from fedlearner_webconsole.proto.workflow_template_pb2 import WorkflowTemplateRevisionJson
+from fedlearner_webconsole.rpc.v2.project_service_client import ProjectServiceClient
+from fedlearner_webconsole.swagger.models import schema_manager
+from fedlearner_webconsole.utils.decorators.pp_flask import input_validator, use_args, use_kwargs
+from fedlearner_webconsole.auth.third_party_sso import credentials_required
+from fedlearner_webconsole.utils.flask_utils import download_json, make_flask_response, get_current_user, FilterExpField
+from fedlearner_webconsole.utils.paginate import paginate
+from fedlearner_webconsole.utils.proto import to_dict
+from fedlearner_webconsole.workflow_template.models import WorkflowTemplate, WorkflowTemplateRevision, \
WorkflowTemplateKind
-from fedlearner_webconsole.proto import workflow_definition_pb2
+from fedlearner_webconsole.workflow_template.service import (WorkflowTemplateService, _format_template_with_yaml_editor,
+ _check_config_and_editor_info,
+ WorkflowTemplateRevisionService)
from fedlearner_webconsole.db import db
-from fedlearner_webconsole.exceptions import (NotFoundException,
- InvalidArgumentException,
- ResourceConflictException)
-from fedlearner_webconsole.workflow_template.slots_formatter import \
- generate_yaml_template
-from fedlearner_webconsole.workflow_template.template_validaor\
- import check_workflow_definition
-
-
-def _classify_variable(variable):
- if variable.value_type == 'CODE':
- try:
- json.loads(variable.value)
- except json.JSONDecodeError as e:
- raise InvalidArgumentException(str(e))
- return variable
-
-
-def dict_to_workflow_definition(config):
- try:
- template_proto = ParseDict(
- config, workflow_definition_pb2.WorkflowDefinition())
- for variable in template_proto.variables:
- _classify_variable(variable)
- for job in template_proto.job_definitions:
- for variable in job.variables:
- _classify_variable(variable)
- except ParseError as e:
- raise InvalidArgumentException(details={'config': str(e)})
- return template_proto
-
-
-def dict_to_editor_info(editor_info):
- try:
- editor_info_proto = ParseDict(
- editor_info, workflow_definition_pb2.WorkflowTemplateEditorInfo())
- except ParseError as e:
- raise InvalidArgumentException(details={'editor_info': str(e)})
- return editor_info_proto
-
-
-def _dic_without_key(d, keys):
- result = dict(d)
- for key in keys:
- del result[key]
- return result
+from fedlearner_webconsole.exceptions import NotFoundException, InvalidArgumentException, ResourceConflictException, \
+ NetworkException
+from fedlearner_webconsole.proto.workflow_template_pb2 import WorkflowTemplateJson
+
+
+class PostWorkflowTemplatesParams(Schema):
+ config = fields.Dict(required=True)
+ editor_info = fields.Dict(required=False, load_default={})
+ name = fields.String(required=True)
+ comment = fields.String(required=False, load_default=None)
+ kind = fields.Integer(required=False, load_default=0)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['config'], data['editor_info'] = _check_config_and_editor_info(data['config'], data['editor_info'])
+ return data
+
+
+class PutWorkflowTemplatesParams(Schema):
+ config = fields.Dict(required=True)
+ editor_info = fields.Dict(required=False, load_default={})
+ name = fields.String(required=True)
+ comment = fields.String(required=False, load_default=None)
+
+ @post_load()
+ def make(self, data, **kwargs):
+ data['config'], data['editor_info'] = _check_config_and_editor_info(data['config'], data['editor_info'])
+ return data
+
+
+class GetWorkflowTemplatesParams(Schema):
+ filter = FilterExpField(required=False, load_default=None)
+ page = fields.Integer(required=False, load_default=None)
+ page_size = fields.Integer(required=False, load_default=None)
class WorkflowTemplatesApi(Resource):
- @jwt_required()
- def get(self):
- preset_datajoin = request.args.get('from', '') == 'preset_datajoin'
- templates = WorkflowTemplate.query
- if 'group_alias' in request.args:
- templates = templates.filter_by(
- group_alias=request.args['group_alias'])
- if 'is_left' in request.args:
- is_left = request.args.get(key='is_left', type=int)
- if is_left is None:
- raise InvalidArgumentException('is_left must be 0 or 1')
- templates = templates.filter_by(is_left=is_left)
- if preset_datajoin:
- templates = templates.filter_by(
- kind=WorkflowTemplateKind.PRESET_DATAJOIN.value)
- # remove config from dicts to reduce the size of the list
- return {
- 'data': [
- _dic_without_key(t.to_dict(), ['config', 'editor_info'])
- for t in templates.all()
- ]
- }, HTTPStatus.OK
-
- @jwt_required()
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('name', required=True, help='name is empty')
- parser.add_argument('comment')
- parser.add_argument('config',
- type=dict,
- required=True,
- help='config is empty')
- parser.add_argument('editor_info', type=dict, default={})
- parser.add_argument('kind', type=int, default=0)
- data = parser.parse_args()
- name = data['name']
- comment = data['comment']
- config = data['config']
- editor_info = data['editor_info']
- kind = data['kind']
- if WorkflowTemplate.query.filter_by(name=name).first() is not None:
- raise ResourceConflictException(
- 'Workflow template {} already exists'.format(name))
- template_proto, editor_info_proto = _check_config_and_editor_info(
- config, editor_info)
- template_proto = _format_template_with_yaml_editor(
- template_proto, editor_info_proto)
- template = WorkflowTemplate(name=name,
- comment=comment,
- group_alias=template_proto.group_alias,
- is_left=template_proto.is_left,
- kind=kind)
- template.set_config(template_proto)
- template.set_editor_info(editor_info_proto)
- db.session.add(template)
- db.session.commit()
- logging.info('Inserted a workflow_template to db')
- result = template.to_dict()
- return {'data': result}, HTTPStatus.CREATED
+
+ @credentials_required
+ @use_args(GetWorkflowTemplatesParams(), location='query')
+ def get(self, params: dict):
+ """Get templates.
+ ---
+ tags:
+ - workflow_template
+ description: Get templates list.
+ parameters:
+ - in: query
+ name: filter
+ schema:
+ type: string
+              required: false
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: list of workflow templates.
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplateRef'
+ """
+ with db.session_scope() as session:
+ try:
+ pagination = WorkflowTemplateService(session).list_workflow_templates(
+ filter_exp=params['filter'],
+ page=params['page'],
+ page_size=params['page_size'],
+ )
+ except ValueError as e:
+ raise InvalidArgumentException(details=f'Invalid filter: {str(e)}') from e
+ data = [t.to_ref() for t in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+ @input_validator
+ @credentials_required
+ @emits_event(audit_fields=['name'])
+ @use_args(PostWorkflowTemplatesParams(), location='json')
+ def post(self, params: dict):
+ """Create a workflow_template.
+ ---
+ tags:
+ - workflow_template
+ description: Create a template.
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/PostWorkflowTemplatesParams'
+ required: true
+ responses:
+ 201:
+ description: detail of workflow template.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplatePb'
+ """
+ with db.session_scope() as session:
+ template = WorkflowTemplateService(session).post_workflow_template(
+ name=params['name'],
+ comment=params['comment'],
+ config=params['config'],
+ editor_info=params['editor_info'],
+ kind=params['kind'],
+ creator_username=get_current_user().username)
+ session.commit()
+ return make_flask_response(data=template.to_proto(), status=HTTPStatus.CREATED)
class WorkflowTemplateApi(Resource):
- @jwt_required()
- def get(self, template_id):
- download = request.args.get('download', 'false') == 'true'
-
- template = WorkflowTemplate.query.filter_by(id=template_id).first()
- if template is None:
- raise NotFoundException(f'Failed to find template: {template_id}')
-
- result = template.to_dict()
- if download:
- in_memory_file = io.BytesIO()
- in_memory_file.write(json.dumps(result).encode('utf-8'))
- in_memory_file.seek(0)
- return send_file(in_memory_file,
- as_attachment=True,
- attachment_filename=f'{template.name}.json',
- mimetype='application/json; charset=UTF-8',
- cache_timeout=0)
- return {'data': result}, HTTPStatus.OK
-
- @jwt_required()
+
+ @credentials_required
+ @use_args({'download': fields.Bool(required=False, load_default=False)}, location='query')
+ def get(self, params: dict, template_id: int):
+ """Get template by id.
+ ---
+ tags:
+ - workflow_template
+ description: Get a template.
+ parameters:
+ - in: path
+ name: template_id
+ schema:
+ type: integer
+ required: true
+ - in: query
+ name: download
+ schema:
+ type: boolean
+ responses:
+ 200:
+ description: detail of workflow template.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplatePb'
+ """
+ with db.session_scope() as session:
+ template = session.query(WorkflowTemplate).filter_by(id=template_id).first()
+ if template is None:
+ raise NotFoundException(f'Failed to find template: {template_id}')
+ template_proto = template.to_proto()
+ if params['download']:
+                # Note: this is a workaround to remove some fields from the proto.
+ # WorkflowTemplateJson and WorkflowTemplatePb are compatible.
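+                # Fields present in WorkflowTemplatePb but not in WorkflowTemplateJson end up as
+                # unknown fields after ParseFromString and are skipped by to_dict (presumably
+                # MessageToDict), which is what strips them from the downloaded JSON.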
+ template_json_pb = WorkflowTemplateJson()
+ template_json_pb.ParseFromString(template_proto.SerializeToString())
+ return download_json(content=to_dict(template_json_pb), filename=template.name)
+ return make_flask_response(template_proto)
+
+ @credentials_required
+ @emits_event()
def delete(self, template_id):
- result = WorkflowTemplate.query.filter_by(id=template_id)
- if result.first() is None:
- raise NotFoundException(f'Failed to find template: {template_id}')
- result.delete()
- db.session.commit()
- return {'data': {}}, HTTPStatus.OK
-
- @jwt_required()
- def put(self, template_id):
- parser = reqparse.RequestParser()
- parser.add_argument('name', required=True, help='name is empty')
- parser.add_argument('comment')
- parser.add_argument('config',
- type=dict,
- required=True,
- help='config is empty')
- parser.add_argument('editor_info', type=dict, default={})
- parser.add_argument('kind', type=int, default=0)
- data = parser.parse_args()
- name = data['name']
- comment = data['comment']
- config = data['config']
- editor_info = data['editor_info']
- kind = data['kind']
- tmp = WorkflowTemplate.query.filter_by(name=name).first()
- if tmp is not None and tmp.id != template_id:
- raise ResourceConflictException(
- 'Workflow template {} already exists'.format(name))
- template = WorkflowTemplate.query.filter_by(id=template_id).first()
- if template is None:
- raise NotFoundException(f'Failed to find template: {template_id}')
- template_proto, editor_info_proto = _check_config_and_editor_info(
- config, editor_info)
- template_proto = _format_template_with_yaml_editor(
- template_proto, editor_info_proto)
- template.set_config(template_proto)
- template.set_editor_info(editor_info_proto)
- template.name = name
- template.comment = comment
- template.group_alias = template_proto.group_alias
- template.is_left = template_proto.is_left
- template.kind = kind
- db.session.commit()
- result = template.to_dict()
- return {'data': result}, HTTPStatus.OK
-
-
-def _format_template_with_yaml_editor(template_proto, editor_info_proto):
- for job_def in template_proto.job_definitions:
- # if job is in editor_info, than use meta_yaml format with
- # slots instead of yaml_template
- yaml_editor_infos = editor_info_proto.yaml_editor_infos
- if not job_def.expert_mode and job_def.name in yaml_editor_infos:
- yaml_editor_info = yaml_editor_infos[job_def.name]
- job_def.yaml_template = generate_yaml_template(
- yaml_editor_info.meta_yaml,
- yaml_editor_info.slots)
- try:
- check_workflow_definition(template_proto)
- except ValueError as e:
- raise InvalidArgumentException(
- details={'config.yaml_template': str(e)})
- return template_proto
-
-
-def _check_config_and_editor_info(config, editor_info):
- # TODO: needs tests
- if 'group_alias' not in config:
- raise InvalidArgumentException(
- details={'config.group_alias': 'config.group_alias is required'})
- if 'is_left' not in config:
- raise InvalidArgumentException(
- details={'config.is_left': 'config.is_left is required'})
-
- # form to proto buffer
- editor_info_proto = dict_to_editor_info(editor_info)
- template_proto = dict_to_workflow_definition(config)
- for index, job_def in enumerate(template_proto.job_definitions):
- # pod label name must be no more than 63 characters.
- # workflow.uuid is 20 characters, pod name suffix such as
- # '-follower-master-0' is less than 19 characters, so the
- # job name must be no more than 24
- if len(job_def.name) > 24:
- raise InvalidArgumentException(
- details={
- f'config.job_definitions[{index}].job_name':
- 'job_name must be no more than 24 characters'
- })
- # limit from k8s
- if not re.match('[a-z0-9-]*', job_def.name):
- raise InvalidArgumentException(
- details={
- f'config.job_definitions[{index}].job_name':
- 'Only letters(a-z), numbers(0-9) '
- 'and dashes(-) are supported.'
- })
- return template_proto, editor_info_proto
-
-
-class CodeApi(Resource):
- @jwt_required()
- def get(self):
- parser = reqparse.RequestParser()
- parser.add_argument('code_path',
- type=str,
- location='args',
- required=True,
- help='code_path is required')
- data = parser.parse_args()
- code_path = data['code_path']
- try:
- with tarfile.open(code_path) as tar:
- code_dict = {}
- for file in tar.getmembers():
- if tar.extractfile(file) is not None:
- if '._' not in file.name and file.isfile():
- code_dict[file.name] = str(
- tar.extractfile(file).read(), encoding='utf-8')
- return {'data': code_dict}, HTTPStatus.OK
- except Exception as e:
- logging.error(f'Get code, code_path: {code_path}, exception: {e}')
- raise InvalidArgumentException(details={'code_path': 'wrong path'})
+        """Delete template by id.
+ ---
+ tags:
+ - workflow_template
+ description: Delete a template.
+ parameters:
+ - in: path
+ name: template_id
+ schema:
+ type: integer
+ required: true
+ responses:
+ 204:
+ description: Successfully deleted.
+ """
+ with db.session_scope() as session:
+ result = session.query(WorkflowTemplate).filter_by(id=template_id)
+ if result.first() is None:
+ raise NotFoundException(f'Failed to find template: {template_id}')
+ result.delete()
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+ @input_validator
+ @credentials_required
+ @emits_event(audit_fields=['name'])
+ @use_args(PutWorkflowTemplatesParams(), location='json')
+ def put(self, params: dict, template_id: int):
+ """Put a workflow_template.
+ ---
+ tags:
+ - workflow_template
+        description: Edit a template.
+ parameters:
+ - in: path
+ name: template_id
+ schema:
+ type: integer
+ required: true
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+                  $ref: '#/definitions/PutWorkflowTemplatesParams'
+ required: true
+ responses:
+ 200:
+ description: detail of workflow template.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplatePb'
+ """
+ with db.session_scope() as session:
+ tmp = session.query(WorkflowTemplate).filter_by(name=params['name']).first()
+ if tmp is not None and tmp.id != template_id:
+ raise ResourceConflictException(f'Workflow template {params["name"]} already exists')
+ template = session.query(WorkflowTemplate).filter_by(id=template_id).first()
+ if template is None:
+ raise NotFoundException(f'Failed to find template: {template_id}')
+ template_proto = _format_template_with_yaml_editor(params['config'], params['editor_info'], session)
+ template.set_config(template_proto)
+ template.set_editor_info(params['editor_info'])
+ template.name = params['name']
+ template.comment = params['comment']
+ template.group_alias = template_proto.group_alias
+ session.commit()
+ return make_flask_response(template.to_proto())
+
+
+class WorkflowTemplateRevisionsApi(Resource):
+
+ @credentials_required
+ @use_args(
+ {
+ 'page': fields.Integer(required=False, load_default=None),
+ 'page_size': fields.Integer(required=False, load_default=None)
+ },
+ location='query')
+ def get(self, params: dict, template_id: int):
+        """Get all template revisions for a specific template.
+ ---
+ tags:
+ - workflow_template
+        description: Get all template revisions for a specific template.
+ parameters:
+ - in: path
+ name: template_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the template
+ - in: query
+ name: page
+ schema:
+ type: integer
+ - in: query
+ name: page_size
+ schema:
+ type: integer
+ responses:
+ 200:
+ description: list of workflow template revisions.
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplateRevisionRef'
+ """
+ with db.session_scope() as session:
+ query = session.query(WorkflowTemplateRevision).filter_by(template_id=template_id)
+ query = query.order_by(WorkflowTemplateRevision.revision_index.desc())
+ pagination = paginate(query, params['page'], params['page_size'])
+ data = [t.to_ref() for t in pagination.get_items()]
+ return make_flask_response(data=data, page_meta=pagination.get_metadata())
+
+
+class WorkflowTemplateRevisionsCreateApi(Resource):
+
+ @credentials_required
+ def post(self, template_id: int):
+        """Create a new template revision for a specific template if its config has changed.
+ ---
+ tags:
+ - workflow_template
+        description: Create a new template revision for a specific template if its config has changed.
+ parameters:
+ - in: path
+ name: template_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the template
+ responses:
+ 200:
+ description: detail of workflow template revision.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplateRevisionPb'
+ """
+ with db.session_scope() as session:
+ revision = WorkflowTemplateRevisionService(session).create_new_revision_if_template_updated(
+ template_id=template_id)
+ session.commit()
+ return make_flask_response(data=revision.to_proto())
+
+
+class WorkflowTemplateRevisionApi(Resource):
+
+ @credentials_required
+ @use_args({'download': fields.Boolean(required=False, load_default=None)}, location='query')
+ def get(self, params: dict, revision_id: int):
+ """Get template revision by id.
+ ---
+ tags:
+ - workflow_template
+ description: Get template revision.
+ parameters:
+ - in: path
+ name: revision_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the template revision
+ - in: query
+ name: download
+ schema:
+ type: boolean
+ responses:
+ 200:
+ description: detail of workflow template revision.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplateRevisionPb'
+ """
+ with db.session_scope() as session:
+ template_revision = session.query(WorkflowTemplateRevision).options(
+ undefer(WorkflowTemplateRevision.config),
+ undefer(WorkflowTemplateRevision.editor_info)).get(revision_id)
+ if template_revision is None:
+                raise NotFoundException(f'Cannot find template revision {revision_id}')
+ if params['download']:
+                # Note: this is a workaround to remove some fields from the proto.
+ # WorkflowTemplateRevisionJson and WorkflowTemplateRevisionPb are compatible.
+ revision_proto = template_revision.to_proto()
+ revision_json_pb = WorkflowTemplateRevisionJson()
+ revision_json_pb.ParseFromString(revision_proto.SerializeToString())
+ return download_json(content=to_dict(revision_json_pb), filename=template_revision.id)
+ return make_flask_response(data=template_revision.to_proto())
+
+ @credentials_required
+ def delete(self, revision_id: int):
+ """Delete template revision by id.
+ ---
+ tags:
+ - workflow_template
+ description: Delete template revision.
+ parameters:
+ - in: path
+ name: revision_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the template revision
+ responses:
+ 204:
+ description: No content.
+ """
+ with db.session_scope() as session:
+ WorkflowTemplateRevisionService(session).delete_revision(revision_id=revision_id)
+ session.commit()
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
+
+ @credentials_required
+ @use_args({'comment': fields.String(required=False, load_default=None)})
+ def patch(self, params: dict, revision_id: int):
+ """Patch template revision by id.
+ ---
+ tags:
+ - workflow_template
+ description: Patch template revision.
+ parameters:
+ - in: path
+ name: revision_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the template revision
+ - in: body
+ name: comment
+ schema:
+ type: string
+ required: false
+ responses:
+ 200:
+ description: detail of workflow template revision.
+ content:
+ application/json:
+ schema:
+ $ref: '#/definitions/fedlearner_webconsole.proto.WorkflowTemplateRevisionPb'
+ """
+ with db.session_scope() as session:
+ template_revision = session.query(WorkflowTemplateRevision).options(
+ undefer(WorkflowTemplateRevision.config),
+ undefer(WorkflowTemplateRevision.editor_info)).get(revision_id)
+ if template_revision is None:
+                raise NotFoundException(f'Cannot find template revision {revision_id}')
+ if params['comment']:
+ template_revision.comment = params['comment']
+ session.commit()
+ return make_flask_response(data=template_revision.to_proto())
+
+
+class WorkflowTemplateRevisionSendApi(Resource):
+
+ @use_kwargs({
+ 'participant_id': fields.Integer(required=True),
+ }, location='query')
+ def post(self, revision_id: int, participant_id: int):
+        """Send a template revision to a participant.
+ ---
+ tags:
+ - workflow_template
+        description: Send a template revision to a participant.
+ parameters:
+ - in: path
+ name: revision_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the template revision
+ - in: query
+ name: participant_id
+ required: true
+ schema:
+ type: integer
+ description: The ID of the participant
+ responses:
+ 204:
+ description: No content.
+ """
+ with db.session_scope() as session:
+ part: Participant = session.query(Participant).get(participant_id)
+ if part is None:
+                raise NotFoundException(f'participant {participant_id} does not exist')
+ revision: WorkflowTemplateRevision = session.query(WorkflowTemplateRevision).get(revision_id)
+ if revision is None:
+                raise NotFoundException(f'revision {revision_id} does not exist')
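+            # Forward the revision config to the participant over gRPC; RPC failures
+            # are surfaced to the client as a NetworkException.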
+ try:
+ ProjectServiceClient.from_participant(part.domain_name).send_template_revision(
+ config=revision.get_config(),
+ name=revision.template.name,
+ comment=revision.comment,
+ kind=WorkflowTemplateKind.PEER,
+ revision_index=revision.revision_index)
+ except grpc.RpcError as e:
+ raise NetworkException(str(e)) from e
+
+ return make_flask_response(status=HTTPStatus.NO_CONTENT)
def initialize_workflow_template_apis(api):
api.add_resource(WorkflowTemplatesApi, '/workflow_templates')
- api.add_resource(WorkflowTemplateApi,
- '/workflow_templates/