1 # Copyright 2021 Spirent Communications.
2 # sridhar.rao@spirent.com
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
Tool to suggest which ML approach is most applicable for
a particular dataset and use case.
24 2. Add Informative data to the user.
27 from __future__ import print_function
30 from pypsi import wizard as wiz
31 from pypsi.shell import Shell
33 # pylint: disable=line-too-long,too-few-public-methods,too-many-instance-attributes, too-many-nested-blocks, too-many-return-statements, too-many-branches
48 class AlgoSelectorWizard():
50 Class to create wizards
54 Perform Initialization.
58 self.main_l1_values = {}
59 self.main_l2a_values = {}
60 self.main_l2b_values = {}
61 self.main_l3_values = {}
62 self.main_l4_values = {}
63 self.unsup_values = {}
67 self.wiz_main_l1 = None
68 self.wiz_main_l2_a = None
69 self.wiz_main_l2_b = None
70 self.wiz_main_l3 = None
71 self.wiz_main_l4 = None
72 self.wiz_generic = None
73 self.wiz_unsupervised = None
74 self.wiz_reinforcement = None
75 self.ml_needed = False
76 self.supervised = False
77 self.unsupervised = False
78 self.reinforcement = False
79 self.data_size = 'high'
80 self.interpretability = False
82 self.ftod_ratio = 'low'
83 self.reproducibility = False
86 ############# All the Wizards ##################################
88 ### GENERIC Wizards - Need for ML ##############################
89 def main_wizard_l1(self):
93 self.wiz_main_l1 = wiz.PromptWizard(
94 name=Bcolors.OKBLUE+"Do you Need ML - Data Availability"+Bcolors.ENDC,
97 # The list of input prompts to ask the user.
99 # ID where the value will be stored
100 id="data_availability",
102 name=Bcolors.HEADER+"Do you have access to data about different situations, or that describes a lot of examples of situations"+Bcolors.ENDC,
104 help="Y/N/U - Yes/No/Unknown",
105 validators=(wiz.required_validator),
111 def main_wizard_l2_a(self):
115 self.wiz_main_l2_a = wiz.PromptWizard(
116 name=Bcolors.OKBLUE+"Do you Need ML - Data Creation"+Bcolors.ENDC,
119 # The list of input prompts to ask the user.
121 # ID where the value will be stored
122 id="data_creativity",
124 name=Bcolors.HEADER+"Will a system be able to gather a lot of data by trying sequences of actions in many different situations and seeing the results"+Bcolors.ENDC,
126 help="Y/N/U - Yes/No/Unknown",
127 validators=(wiz.required_validator),
133 def main_wizard_l2_b(self):
137 label = """ One or more meaningful and informative 'tag' to provide context so that a machine learning model can learn from it. For example, labels might indicate whether a photo contains a bird or car, which words were uttered in an audio recording, or if an x-ray contains a tumor. Data labeling is required for a variety of use cases including computer vision, natural language processing, and speech recognition."""
138 self.wiz_main_l2_b = wiz.PromptWizard(
139 name=Bcolors.OKBLUE+"Do you Need ML - Data Programmability"+Bcolors.ENDC,
142 # The list of input prompts to ask the user.
144 # ID where the value will be stored
147 name=Bcolors.HEADER+" Do you have Labelled data? (Type Y/N/U - Yes/No/Unknown). Type help for description of label. "+Bcolors.ENDC,
150 validators=(wiz.required_validator),
154 # ID where the value will be stored
155 id="data_programmability",
157 name=Bcolors.HEADER+"Can a program or set of rules decide what actions to take based on the data you have about the situations"+Bcolors.ENDC,
159 help="Y/N/U - Yes/No/Unknown",
160 validators=(wiz.required_validator),
167 def main_wizard_l3(self):
171 self.wiz_main_l3 = wiz.PromptWizard(
172 name=Bcolors.OKBLUE+"Do you Need ML - Data Knowledge"+Bcolors.ENDC,
175 # The list of input prompts to ask the user.
177 # ID where the value will be stored
180 name=Bcolors.HEADER+"Could a knowledgeable human decide what actions to take based on the data you have about the situations"+Bcolors.ENDC,
182 help="Y/N/U - Yes/No/Unknown",
183 validators=(wiz.required_validator),
189 def main_wizard_l4(self):
193 self.wiz_main_l4 = wiz.PromptWizard(
194 name=Bcolors.OKBLUE+"Do you Need ML - Data Pattern"+Bcolors.ENDC,
197 # The list of input prompts to ask the user.
199 # ID where the value will be stored
202 name=Bcolors.HEADER+"Could there be patterns in these situations that the humans haven't recognized before"+Bcolors.ENDC,
204 help="Y/N/U - Yes/No/Unknown",
205 validators=(wiz.required_validator),
210 ### GENERIC Wizards - GOAL, METRICS, DATA ##############################
211 def gen_wizard(self):
213 Generic Wizard - Goal, metrics, data
215 self.wiz_generic = wiz.PromptWizard(
216 name=Bcolors.OKBLUE+"Understanding Goal, Metrics, Data and Output Type"+Bcolors.ENDC,
219 # The list of input prompts to ask the user.
221 # ID where the value will be stored
224 name=Bcolors.HEADER+" What is your goal with the data? Predict, Describe or Explore"+Bcolors.ENDC,
226 help="Enter one of Predict/Describe/Explore",
227 validators=(wiz.required_validator),
231 # ID where the value will be stored
232 id="metric_accuracy",
234 name=Bcolors.HEADER+" How important the metric 'Accuracy' is for you? 1-5: 1- Least important 5- Most Important"+Bcolors.ENDC,
236 help="Enter 1-5: 1 being least important, and 5 being most important",
237 validators=(wiz.required_validator),
241 # ID where the value will be stored
244 name=Bcolors.HEADER+" How important the metric 'Speed' is for you? 1-5: 1- Least important 5- Most Important"+Bcolors.ENDC,
246 help="Enter 1-5: 1 being least important, and 5 being most important",
247 validators=(wiz.required_validator),
251 # ID where the value will be stored
252 id="metric_interpretability",
254 name=Bcolors.HEADER+" How important the metric 'Interpretability' is for you? 1-5: 1- Least important 5- Most Important"+Bcolors.ENDC,
256 help="Enter 1-5: 1 being least important, and 5 being most important",
257 validators=(wiz.required_validator),
261 # ID where the value will be stored
262 id="metric_reproducibility",
264 name=Bcolors.HEADER+" How important the metric 'Reproducibility' is for you? 1-5: 1- Least important 5- Most Important"+Bcolors.ENDC,
266 help="Enter 1-5: 1 being least important, and 5 being most important",
267 validators=(wiz.required_validator),
271 # ID where the value will be stored
272 id="metric_implementation",
274 name=Bcolors.HEADER+" How important the metric 'Ease of Implementation and Maintenance' is for you? 1-5: 1- Least important 5- Most Important"+Bcolors.ENDC,
276 help="Enter 1-5: 1 being least important, and 5 being most important",
277 validators=(wiz.required_validator),
281 # ID where the value will be stored
284 name=Bcolors.HEADER+" What does the data (columns) represent? well defined 'Features', 'signals' (Timeseries, pixels, etc) or Text - (Please type the associated number)"+Bcolors.ENDC,
286 help="1. Well Defined Features\n 2. Signals\n 3. Text - Unstructured\n 4. None of the above\n",
287 validators=(wiz.required_validator),
291 # ID where the value will be stored
292 id="data_signal_type",
294 name=Bcolors.HEADER+" If Signals, can you choose any one from the below list? "+Bcolors.ENDC,
296 help="1. Image\n 2. Audio\n 3. Timeseries\n 4. None of the above\n 5. Not Applicable\n ",
297 validators=(wiz.required_validator),
301 # ID where the value will be stored
304 name=Bcolors.HEADER+" If Text, can you choose any one from the below list? "+Bcolors.ENDC,
306 help="1. Webpages\n 2. Emails\n 3. Social-Media Posts\n 4. Books\n 5. Formal Articles\n 6. Speech converted to text\n 7. None of the above\n 8. Not Applicable\n ",
307 validators=(wiz.required_validator),
311 # ID where the value will be stored
314 name=Bcolors.HEADER+" If features, are they well defined? i.e., are all the variables well understood? "+Bcolors.ENDC,
317 validators=(wiz.required_validator),
321 # ID where the value will be stored
322 id="data_features_count",
324 name=Bcolors.HEADER+" If features, How many are there? "+Bcolors.ENDC,
327 validators=(wiz.required_validator),
331 # ID where the value will be stored
332 id="data_distribution",
334 name=Bcolors.HEADER+" Are you aware of any 'Distribution' that is inherent to the data, we can take advantage of?"+Bcolors.ENDC,
337 validators=(wiz.required_validator),
341 # ID where the value will be stored
342 id="data_io_relation",
344 name=Bcolors.HEADER+" Is the probability of 'Linear Relation' between input and the output is high?"+Bcolors.ENDC,
347 validators=(wiz.required_validator),
351 # ID where the value will be stored
352 id="data_correlation",
354 name=Bcolors.HEADER+" Are you confident that there is NO high correlation among the independent variables in your day?"+Bcolors.ENDC,
356 help="Y/N/U. Change in one ",
357 validators=(wiz.required_validator),
361 # ID where the value will be stored
362 id="data_cond_indep",
364 name=Bcolors.HEADER+" Are you confident that the variables are conditionally independent?"+Bcolors.ENDC,
366 help="Y/N/U. If probability that it rains given lightining and thunder is same as probability that it rains given lightining, then rain and thunder are conditionally independent",
367 validators=(wiz.required_validator),
371 # ID where the value will be stored
374 name=Bcolors.HEADER+" Are there any missing values in the data? "+Bcolors.ENDC,
377 validators=(wiz.required_validator),
381 # ID where the value will be stored
382 id="data_size_bytes",
384 name=Bcolors.HEADER+" How big is the data in terms of size? (Use K/M/G Bytes unit) "+Bcolors.ENDC,
386 help="Number(integer) and unit: K for Kilo, M for Mega and G for Giga. Ex: 10G for 10 Giga bytes",
387 validators=(wiz.required_validator),
391 # ID where the value will be stored
392 id="data_size_samples",
394 name=Bcolors.HEADER+" How big is the data in terms of samples? (Use T/M/B Samples) "+Bcolors.ENDC,
396 help="Number(integer) and unit: T for Thousand, M for Million and B for Billion. Ex: 1M for 1 Million Samples",
397 validators=(wiz.required_validator),
401 # ID where the value will be stored
402 id="data_type_output",
404 name=Bcolors.HEADER+" What is the expected output data type ? (Please type number associated with type in 'help') "+Bcolors.ENDC,
406 help=" 1:Numerical-Discrete\n 2:Numerical-Continuous\n 3:Ordinal\n 4:Categorical-Binary\n 5:Categorical-Multiclass",
407 validators=(wiz.required_validator),
411 # ID where the value will be stored
412 id="data_output_prob",
414 name=Bcolors.HEADER+" Is the expected output data a probability value ? "+Bcolors.ENDC,
417 validators=(wiz.required_validator),
424 def unsupervised_wizard(self):
426 The Un-Supervized Learning Wizard
428 self.wiz_generic = wiz.PromptWizard(
429 name=Bcolors.OKBLUE+"Understanding Goal, Metrics, Data and Output Type"+Bcolors.ENDC,
432 # The list of input prompts to ask the user.
434 # ID where the value will be stored
437 name=Bcolors.HEADER+" What is the main goal? (Please type number associated with type in 'help')"+Bcolors.ENDC,
439 help="1: Explore Similar Groups (clustering) \n 2: Perform Dimensionality Reduction\n 3: Others\n",
440 validators=(wiz.required_validator),
444 # ID where the value will be stored
445 id="unsup_dr_topic_mod",
447 name=Bcolors.HEADER+" If dimensionality reduction, do you prefer topic modelling ? (Please type NA is you are not sure)"+Bcolors.ENDC,
450 validators=(wiz.required_validator),
454 # ID where the value will be stored
457 name=Bcolors.HEADER+" Are you aware of density variations in your data ? (Please type NA is you are not sure)"+Bcolors.ENDC,
460 validators=(wiz.required_validator),
464 # ID where the value will be stored
465 id="unsup_clus_outliers",
467 name=Bcolors.HEADER+" Are there too many outliers in your data ? (Please type NA is you are not sure)"+Bcolors.ENDC,
470 validators=(wiz.required_validator),
474 # ID where the value will be stored
475 id="unsup_clus_groups",
477 name=Bcolors.HEADER+" If clustering, do you know how many groups to form? (Please type NA is you are not sure)"+Bcolors.ENDC,
480 validators=(wiz.required_validator),
487 def reinforcement_wizard(self):
489 The Reinforced Learning Wizard
493 |-------| Agent | Action
499 | |----|Environment| |
503 self.wiz_reinforcement = wiz.PromptWizard(
504 name=Bcolors.OKBLUE+"Reinforcement Specific"+Bcolors.ENDC,
507 # The list of input prompts to ask the user.
509 # ID where the value will be stored
512 name=Bcolors.HEADER+" Type help for reference diagram for reinforcement-learning"+Bcolors.ENDC,
515 validators=(wiz.required_validator),
516 default='Type Help or Press Enter'
519 # ID where the value will be stored
520 id="ri_model_preference",
522 name=Bcolors.HEADER+" Do you prefer model-based approach? (Type NA if you are not sure) "+Bcolors.ENDC,
525 validators=(wiz.required_validator),
529 # ID where the value will be stored
530 id="ri_model_availability",
532 name=Bcolors.HEADER+" Do you have a model for model-based approach? (Type NA if not applicable) "+Bcolors.ENDC,
535 validators=(wiz.required_validator),
539 # ID where the value will be stored
540 id="ri_modelfree_value",
542 name=Bcolors.HEADER+" In Model-Free approach, do you prefer value-based approach? (Type NA if not applicable) "+Bcolors.ENDC,
545 validators=(wiz.required_validator),
549 # ID where the value will be stored
550 id="ri_modelfree_value_state",
552 name=Bcolors.HEADER+" In Model-Free Value-Based approach, do you prefer state-only model? (Type NA if not applicable) "+Bcolors.ENDC,
555 validators=(wiz.required_validator),
559 # ID where the value will be stored
562 name=Bcolors.HEADER+" What is the application domain ? (Please type number associated with type in 'help') "+Bcolors.ENDC,
564 help=" 1:Computer Resource Mgmt.\n 2:Robotics\n 3:Traffic-Control\n 4:Reccommenders\n 5:Autonomous Vehicles\n 6:Games\n 7:Chemistry\n 8:Others\n",
565 validators=(wiz.required_validator),
571 ############### All the Run Operations ######################
def run_mainwiz(self):
    """
    Run the "Do you need ML" wizards and record the outcome.

    The wizards are asked one level at a time; each answer decides which
    wizard (if any) is asked next. Depending on the answers this sets
    self.supervised / self.unsupervised / self.reinforcement and
    self.ml_needed, and prints whether ML is required.
    """
    def said_yes(answers, key):
        # An answer counts as "yes" only when it is exactly y/Y.
        return answers[key].lower() == 'y'

    need_ml_msg = Bcolors.OKGREEN+"Looks like you need ML, let's continue"+Bcolors.ENDC
    no_ml_msg = Bcolors.FAIL+"ML is not required - Please consider alternate approaches\n"+Bcolors.ENDC

    # Level 1: is data available at all?
    self.main_wizard_l1()
    self.main_l1_values = self.wiz_main_l1.run(self.shell)

    if not said_yes(self.main_l1_values, 'data_availability'):
        # No data at hand: ask whether a system could gather it by acting.
        self.main_wizard_l2_a()
        self.main_l2a_values = self.wiz_main_l2_a.run(self.shell)
        if said_yes(self.main_l2a_values, 'data_creativity'):
            print(need_ml_msg)
            self.ml_needed = True
            self.reinforcement = True
        else:
            print(no_ml_msg)
        return

    # Level 2b: labelled data decides supervised vs. unsupervised.
    self.main_wizard_l2_b()
    self.main_l2b_values = self.wiz_main_l2_b.run(self.shell)
    # NOTE(review): the key 'data_labe' must match the id of the
    # labelled-data step built in main_wizard_l2_b - confirm the spelling.
    if said_yes(self.main_l2b_values, 'data_labe'):
        self.supervised = True
    else:
        self.unsupervised = True

    if said_yes(self.main_l2b_values, 'data_programmability'):
        # A program or rule set can already decide: no ML needed.
        print(no_ml_msg)
        return

    # Level 3: could a knowledgeable human decide from the data?
    self.main_wizard_l3()
    self.main_l3_values = self.wiz_main_l3.run(self.shell)
    if said_yes(self.main_l3_values, 'data_knowledge'):
        print(need_ml_msg)
        self.ml_needed = True
        return

    # Level 4: could there be patterns humans have not recognized?
    self.main_wizard_l4()
    self.main_l4_values = self.wiz_main_l4.run(self.shell)
    if said_yes(self.main_l4_values, 'data_pattern'):
        print(need_ml_msg)
        self.ml_needed = True
    else:
        print(no_ml_msg)
def run_generic_wizard(self):
    """
    Run the generic (goal, metrics, data) wizard.

    Builds the wizard with gen_wizard(), runs it against self.shell and
    stores whatever run() returns in self.gen_values.
    """
    self.gen_wizard()
    generic_wizard = self.wiz_generic
    answers = generic_wizard.run(self.shell)
    self.gen_values = answers
def run_unsupervised_wizard(self):
    """
    Run UnSupervised Learning Wizard.

    Builds the wizard with unsupervised_wizard(), runs it against
    self.shell and stores whatever run() returns in self.unsup_values.
    """
    self.unsupervised_wizard()
    # unsupervised_wizard() assigns the wizard it builds to
    # self.wiz_generic, which leaves self.wiz_unsupervised at None.
    # Fall back to wiz_generic instead of calling run() on None.
    wizard = self.wiz_unsupervised
    if wizard is None:
        wizard = self.wiz_generic
    self.unsup_values = wizard.run(self.shell)
def run_reinforcement_wizard(self):
    """
    Run the reinforcement-learning wizard.

    Builds the wizard with reinforcement_wizard(), runs it against
    self.shell and stores whatever run() returns in self.ri_values.
    """
    self.reinforcement_wizard()
    reinforcement_wiz = self.wiz_reinforcement
    answers = reinforcement_wiz.run(self.shell)
    self.ri_values = answers
632 def decide_unsupervised(self):
634 Decide which Unsupervised-learning to use
638 if int(self.unsup_values['unsup_goal']) == 1:
640 if 'high' in self.data_size:
641 if not self.reproducibility:
646 if 'y' in self.unsup_values['unsup_clus_dv'].tolower():
647 if 'y' in self.unsup_values['unsup_clus_groups'].tolower():
650 print("Unsupervised Learning model to consider: Hierarchical Clustering")
655 if 'y' in self.unsup_values['unsup_clus_outliers'].tolower():
656 print("Unsupervised Learning model to consider: Hierarchical Clustering")
658 print("Unsupervised Learning model to consider: DBSCAN")
661 if 'y' in self.gen_values['data_output_prob'].tolower():
662 print("Unsupervised Learning model to consider: Gaussian Mixture")
664 print("Unsupervised Learning model to consider: KMeans")
666 elif int(self.unsup_values['unsup_goal']) == 2:
667 # Dimensionality Reduction
668 if 'y' in self.unsup_values['unsup_dr_topic_mod'].tolower():
669 if 'y' in self.gen_values['data_output_prob'].tolower():
670 print("Unsupervised Learning model to consider: SVD")
672 print("Unsupervised Learning model to consider: LDA")
674 print("Unsupervised Learning model to consider: PCA")
676 print("Sorry. We need to discuss, please connect with Anuket Thoth Project <sridhar.rao@spirent.com>")
def decide_reinforcement(self):
    """
    Decide which reinforcement learning to use.

    Reads 'data_type_output' from self.gen_values and the
    'ri_model_preference', 'ri_model_availability', 'ri_modelfree_value'
    and 'ri_modelfree_value_state' replies from self.ri_values, then
    prints the suggested models. Replies are matched case-insensitively
    by looking for the letter 'y' (or 'n') anywhere in the reply.

    Raises ValueError if 'data_type_output' is not an integer string.
    """
    def reply(key):
        # str has no tolower() method; lower() does the case folding.
        return self.ri_values[key].lower()

    continuous_output = int(self.gen_values['data_type_output']) == 2
    if continuous_output or 'y' in reply('ri_model_preference'):
        # Model based approach.
        if 'y' in reply('ri_model_availability'):
            print("Reinforcement Learning model to consider - AlphaZero")
        else:
            print("Reinforcement Learning models to consider - World Models, I2A, MBMF, and MBVE")
    elif 'n' in reply('ri_model_preference'):
        # Model-Free based approach.
        # NOTE: a reply of 'NA' contains 'n' and therefore lands here too.
        if 'y' not in reply('ri_modelfree_value'):
            print("Reinforcement Learning models to consider: Policy Gradient and Actor Critic")
        elif 'y' in reply('ri_modelfree_value_state'):
            print("Reinforcement Learning models to consider - Monte Carlo, TD(0), and TD(Lambda)")
        else:
            print("Reinforcement Learning models to consider - SARSA, QLearning, Deep Queue Nets")
    else:
        print("Sorry. We need to discuss, please connect with Anuket Thoth Project <sridhar.rao@spirent.com>")
def perform_inference(self):
    """
    Perform Inferences. Used across all 3 types.

    Derives the following from the generic wizard answers in
    self.gen_values:
      - self.data_size: 'low' when the size answers carry a K (bytes) or
        T (samples) unit, otherwise 'unknown'.
      - self.interpretability / self.speed / self.reproducibility: set to
        True when the matching 1-5 rating is 3 or more.
      - self.ftod_ratio: set to 'high' when the feature count exceeds the
        limit for the data-size tier (50 / 5000 / 500000).

    Raises ValueError if a rating or the feature count is not an integer
    string.
    """
    # Lower-case both answers: the prompts ask for upper-case units
    # (K/M/G and T/M/B), so a case-sensitive test would miss e.g. '5T'.
    size_bytes = self.gen_values['data_size_bytes'].lower()
    size_samples = self.gen_values['data_size_samples'].lower()
    small_data = 'k' in size_bytes or 't' in size_samples
    medium_data = 'm' in size_bytes or 'm' in size_samples

    # Decide whether data is Low or High
    # NOTE(review): 'high' is never assigned here, although the decide_*
    # methods test for it - confirm the intended rule for large data.
    self.data_size = 'low' if small_data else 'unknown'

    if int(self.gen_values['metric_interpretability']) >= 3:
        self.interpretability = True
    if int(self.gen_values['metric_speed']) >= 3:
        self.speed = True
    if int(self.gen_values['metric_reproducibility']) >= 3:
        self.reproducibility = True

    # Decide Features relative to Data (ftod_ratio) - high/low
    if small_data:
        feature_limit = 50
    elif medium_data:
        feature_limit = 5000
    else:
        feature_limit = 500000
    if int(self.gen_values['data_features_count']) > feature_limit:
        self.ftod_ratio = 'high'
733 def decide_supervised(self):
735 Decide which Supervised learning to use.
737 if 'high' in self.data_size:
738 # Cover: DT, RF, RNN, CNN, ANN and Naive Bayes
739 if self.interpretability:
741 print("Supervised Learning model to consider - Decision Tree")
743 print("Supervised Learning model to consider - Random Forest")
745 if int(self.gen_values['data_column']) == 3:
746 print("Supervised Learning model to consider - RNN")
747 elif (int(self.gen_values['data_column']) == 2 and
748 int(self.gen_values['data_signal_type']) == 1):
749 print("Supervised Learning model to consider - CNN")
750 elif (int(self.gen_values['data_column']) == 2 and
751 (int(self.gen_values['data_signal_type']) == 2 or
752 int(self.gen_values['data_signal_type']) == 3)):
753 if 'y' in self.gen_values['data_output_prob'].tolower():
754 print("Supervised Learning model to consider - Naive Bayes")
756 print("Supervised Learning model to consider - ANN")
758 print("Supervised model to consider Learning - ANN")
759 elif 'low' in self.data_size:
762 if 'high' in self.ftod_ratio:
765 print("Supervised Learning model to consider - SVN with Gaussian Kernel")
767 if int(self.gen_values['data_type_output']) != 2:
770 if 'y' in self.gen_values['data_io_relation'].tolower():
771 print("Supervised Learning model to consider - Linear Regression or Linear SVM")
773 print("Supervised Learning model to consider - Polynomial Regression or nonLinear SVM")
776 if int(self.gen_values['data_output_type']) == 4:
777 if 'y' in self.gen_values['data_output_prob'].tolower():
778 if 'y' in self.gen_values['data_cond_indep'].tolower():
779 print("Supervised Learning model to consider - Naive Bayes")
781 if 'y' in self.gen_values['data_correlation'].tolower():
782 print("Supervised Learning model to consider - LASSO or Ridge Regression")
784 print("Supervised Learning model to consider - Logistic Regression")
786 print("Supervised Learning model to consider - Polynomial Regression or nonLinear SVM")
789 print("Supervised Learning model to consider - KNN")
792 print("Sorry. We need to discuss, please connect with Anuket Thoth Project <sridhar.rao@spirent.com>")
def ask_and_decide(self):
    """
    Ask the questions and print a suggestion.

    Runs the "Do you need ML" wizards first. Only when they set
    self.ml_needed does it run the generic wizard and then the decision
    step for the first learning type that is flagged, checked in the
    order supervised, unsupervised, reinforcement.

    NOTE(review): perform_inference() is not called from this method -
    confirm whether it should run after the generic wizard.
    """
    self.run_mainwiz()
    if not self.ml_needed:
        return

    self.run_generic_wizard()

    if self.supervised:
        self.decide_supervised()
        return
    if self.unsupervised:
        self.run_unsupervised_wizard()
        self.decide_unsupervised()
        return
    if self.reinforcement:
        self.run_reinforcement_wizard()
        self.decide_reinforcement()
811 def signal_handler(signum, frame):
815 print("\n You interrupted, No Suggestion will be provided!")
824 algowiz = AlgoSelectorWizard()
825 algowiz.ask_and_decide()
826 except(KeyboardInterrupt, MemoryError):
827 print("Some Error Occured - No Suggestion can be provided")
829 print("Thanks for using the Algoselector-Wizard, " +
830 "Hope our suggestion will be useful")
832 if __name__ == "__main__":
833 signal.signal(signal.SIGINT, signal_handler)