%22%22%22How%20to%20Route%20Sample%20Weights%20Through%20OptunaSearchCV.%0A%0APass%20sample_weight%20through%20OptunaSearchCV%20to%20model%20fitting%20and%20scoring.%0A%22%22%22%0A%0A%23%20%2F%2F%2F%20script%0A%23%20requires-python%20%3D%20%22%3E%3D3.11%22%0A%23%20dependencies%20%3D%20%5B%0A%23%20%20%20%20%20%22numpy%22%2C%0A%23%20%20%20%20%20%22optuna%22%2C%0A%23%20%20%20%20%20%22scikit-learn%22%2C%0A%23%20%20%20%20%20%22sklearn-optuna%22%2C%0A%23%20%5D%0A%23%20%2F%2F%2F%0A%0Aimport%20marimo%0A%0A__generated_with%20%3D%20%220.19.9%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20sklearn%0A%20%20%20%20from%20optuna.distributions%20import%20FloatDistribution%0A%20%20%20%20from%20sklearn.datasets%20import%20make_classification%0A%20%20%20%20from%20sklearn.linear_model%20import%20LogisticRegression%0A%20%20%20%20from%20sklearn.metrics%20import%20accuracy_score%2C%20make_scorer%0A%20%20%20%20from%20sklearn.pipeline%20import%20Pipeline%0A%20%20%20%20from%20sklearn.preprocessing%20import%20StandardScaler%0A%20%20%20%20from%20sklearn.utils.class_weight%20import%20compute_sample_weight%0A%0A%20%20%20%20from%20sklearn_optuna%20import%20OptunaSearchCV%0A%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20FloatDistribution%2C%0A%20%20%20%20%20%20%20%20LogisticRegression%2C%0A%20%20%20%20%20%20%20%20OptunaSearchCV%2C%0A%20%20%20%20%20%20%20%20Pipeline%2C%0A%20%20%20%20%20%20%20%20StandardScaler%2C%0A%20%20%20%20%20%20%20%20accuracy_score%2C%0A%20%20%20%20%20%20%20%20compute_sample_weight%2C%0A%20%20%20%20%20%20%20%20make_classification%2C%0A%20%20%20%20%20%20%20%20make_scorer%2C%0A%20%20%20%20%20%20%20%20sklearn%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%20How%20to%20Route%20Sample%20Weights%20Through%20OptunaSearchCV%0A%0A%20%20%20%20This%20notebook%20shows%20how%20to%20enable%20scikit-learn's%20metadata%20routing%0A%20%20%20%20and%20pass%20%60sample_weight%60%20through%20%60OptunaSearchCV%60%20to%20fitting%2C%0A%20%20%20%20scoring%2C%20and%20pipelines.%0A%0A%20%20%20%20**Prerequisites%3A**%20Familiarity%20with%20the%0A%20%20%20%20OptunaSearchCV%20quickstart%0A%20%20%20%20(%5BView%5D(%2Fexamples%2Fquickstart%2F)%20%C2%B7%20%5BOpen%20in%20marimo%5D(%2Fexamples%2Fquickstart%2Fedit%2F))%0A%20%20%20%20and%20scikit-learn%20metadata%20routing%20(requires%20sklearn%20%3E%3D%201.4).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%201.%20Create%20an%20Imbalanced%20Dataset%0A%0A%20%20%20%20Generate%20a%20dataset%20with%2090%2F10%20class%20imbalance%20and%20compute%0A%20%20%20%20balanced%20sample%20weights.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(compute_sample_weight%2C%20make_classification)%3A%0A%20%20%20%20X%2C%20y%20%3D%20make_classification(%0A%20%20%20%20%20%20%20%20n_samples%3D200%2C%0A%20%20%20%20%20%20%20%20n_features%3D10%2C%0A%20%20%20%20%20%20%20%20n_informative%3D5%2C%0A%20%20%20%20%20%20%20%20n_redundant%3D0%2C%0A%20%20%20%20%20%20%20%20n_classes%3D2%2C%0A%20%20%20%20%20%20%20%20weights%3D%5B0.9%2C%200.1%5D%2C%20%20%23%20Imbalanced%3A%2090%25%20class%200%2C%2010%25%20class%201%0A%20%20%20%20%20%20%20%20flip_y%3D0.01%2C%0A%20%20%20%20%20%20%20%20random_state%3D42%2C%0A%20%20%20%20)%0A%0A%20%20%20%20%23%20Compute%20balanced%20sample%20weights%0A%20%20%20%20sample_weight%20%3D%20compute_sample_weight(%22balanced%22%2C%20y)%0A%20%20%20%20return%20X%2C%20sample_weight%2C%20y%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%202.%20Enable%20Routing%20and%20Run%20a%20Search%0A%0A%20%20%20%20Enable%20metadata%20routing%2C%20configure%20the%20estimator%20to%20request%0A%20%20%20%20%60sample_weight%60%2C%20and%20pass%20it%20to%20%60fit()%60.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20FloatDistribution%2C%0A%20%20%20%20LogisticRegression%2C%0A%20%20%20%20OptunaSearchCV%2C%0A%20%20%20%20X%2C%0A%20%20%20%20sample_weight%2C%0A%20%20%20%20sklearn%2C%0A%20%20%20%20y%2C%0A)%3A%0A%20%20%20%20with%20sklearn.config_context(enable_metadata_routing%3DTrue)%3A%0A%20%20%20%20%20%20%20%20%23%20Configure%20estimator%20to%20request%20sample_weight%0A%20%20%20%20%20%20%20%20lr%20%3D%20LogisticRegression(max_iter%3D300%2C%20random_state%3D42)%0A%20%20%20%20%20%20%20%20lr.set_fit_request(sample_weight%3DTrue)%0A%20%20%20%20%20%20%20%20lr.set_score_request(sample_weight%3DTrue)%0A%0A%20%20%20%20%20%20%20%20%23%20Create%20search%20with%20parameter%20distributions%0A%20%20%20%20%20%20%20%20search%20%3D%20OptunaSearchCV(%0A%20%20%20%20%20%20%20%20%20%20%20%20lr%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20param_distributions%3D%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22C%22%3A%20FloatDistribution(0.01%2C%2010.0%2C%20log%3DTrue)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%7D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20n_trials%3D10%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20cv%3D3%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%23%20Fit%20with%20sample_weight%20-%20it%20will%20be%20routed%20to%20fit()%20and%20score()%0A%20%20%20%20%20%20%20%20search.fit(X%2C%20y%2C%20sample_weight%3Dsample_weight)%0A%20%20%20%20return%20(search%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo%2C%20search)%3A%0A%20%20%20%20mo.md(f%22%22%22%0A%20%20%20%20**Best%20Parameters%3A**%20%60C%20%3D%20%7Bsearch.best_params_%5B'C'%5D%3A.4f%7D%60%0A%20%20%20%20**Best%20Weighted%20Score%3A**%20%60%7Bsearch.best_score_%3A.3f%7D%60%0A%20%20%20%20**Trials%20run%3A**%20%60%7Blen(search.cv_results_%5B'params'%5D)%7D%60%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%203.%20Multi-Metric%20Scoring%20with%20Mixed%20Routing%0A%0A%20%20%20%20Create%20scorers%20with%20different%20routing%20preferences%3A%20one%20weighted%2C%0A%20%20%20%20one%20unweighted.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20FloatDistribution%2C%0A%20%20%20%20LogisticRegression%2C%0A%20%20%20%20OptunaSearchCV%2C%0A%20%20%20%20X%2C%0A%20%20%20%20accuracy_score%2C%0A%20%20%20%20make_scorer%2C%0A%20%20%20%20sample_weight%2C%0A%20%20%20%20sklearn%2C%0A%20%20%20%20y%2C%0A)%3A%0A%20%20%20%20with%20sklearn.config_context(enable_metadata_routing%3DTrue)%3A%0A%20%20%20%20%20%20%20%20%23%20Configure%20estimator%0A%20%20%20%20%20%20%20%20lr_multi%20%3D%20LogisticRegression(max_iter%3D300%2C%20random_state%3D42)%0A%20%20%20%20%20%20%20%20lr_multi.set_fit_request(sample_weight%3DTrue)%0A%20%20%20%20%20%20%20%20lr_multi.set_score_request(sample_weight%3DTrue)%0A%0A%20%20%20%20%20%20%20%20%23%20Create%20scorers%20with%20different%20routing%0A%20%20%20%20%20%20%20%20weighted_scorer%20%3D%20make_scorer(accuracy_score)%0A%20%20%20%20%20%20%20%20weighted_scorer.set_score_request(sample_weight%3DTrue)%0A%0A%20%20%20%20%20%20%20%20unweighted_scorer%20%3D%20make_scorer(accuracy_score)%0A%20%20%20%20%20%20%20%20unweighted_scorer.set_score_request(sample_weight%3DFalse)%0A%0A%20%20%20%20%20%20%20%20scoring%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%22weighted_accuracy%22%3A%20weighted_scorer%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22unweighted_accuracy%22%3A%20unweighted_scorer%2C%0A%20%20%20%20%20%20%20%20%7D%0A%0A%20%20%20%20%20%20%20%20search_multi%20%3D%20OptunaSearchCV(%0A%20%20%20%20%20%20%20%20%20%20%20%20lr_multi%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20param_distributions%3D%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22C%22%3A%20FloatDistribution(0.01%2C%2010.0%2C%20log%3DTrue)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%7D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20n_trials%3D10%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20cv%3D3%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20scoring%3Dscoring%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20refit%3D%22weighted_accuracy%22%2C%20%20%23%20Optimize%20for%20weighted%20metric%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20search_multi.fit(X%2C%20y%2C%20sample_weight%3Dsample_weight)%0A%20%20%20%20return%20(search_multi%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo%2C%20search_multi)%3A%0A%20%20%20%20mo.md(f%22%22%22%0A%20%20%20%20**Best%20Parameters%3A**%20%60C%20%3D%20%7Bsearch_multi.best_params_%5B'C'%5D%3A.4f%7D%60%0A%20%20%20%20**Weighted%20Accuracy%3A**%20%60%7Bsearch_multi.cv_results_%5B'mean_test_weighted_accuracy'%5D%5Bsearch_multi.best_index_%5D%3A.3f%7D%60%0A%20%20%20%20**Unweighted%20Accuracy%3A**%20%60%7Bsearch_multi.cv_results_%5B'mean_test_unweighted_accuracy'%5D%5Bsearch_multi.best_index_%5D%3A.3f%7D%60%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%204.%20Route%20Metadata%20in%20a%20Pipeline%0A%0A%20%20%20%20Configure%20each%20pipeline%20step%20independently%3A%20the%20scaler%20ignores%0A%20%20%20%20weights%20while%20the%20classifier%20uses%20them.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20FloatDistribution%2C%0A%20%20%20%20LogisticRegression%2C%0A%20%20%20%20OptunaSearchCV%2C%0A%20%20%20%20Pipeline%2C%0A%20%20%20%20StandardScaler%2C%0A%20%20%20%20X%2C%0A%20%20%20%20sample_weight%2C%0A%20%20%20%20sklearn%2C%0A%20%20%20%20y%2C%0A)%3A%0A%20%20%20%20with%20sklearn.config_context(enable_metadata_routing%3DTrue)%3A%0A%20%20%20%20%20%20%20%20%23%20Configure%20pipeline%20components%0A%20%20%20%20%20%20%20%20scaler%20%3D%20StandardScaler()%0A%20%20%20%20%20%20%20%20scaler.set_fit_request(sample_weight%3DFalse)%20%20%23%20Scaler%20ignores%20weights%0A%0A%20%20%20%20%20%20%20%20lr_pipe%20%3D%20LogisticRegression(max_iter%3D300%2C%20random_state%3D42)%0A%20%20%20%20%20%20%20%20lr_pipe.set_fit_request(sample_weight%3DTrue)%20%20%23%20Classifier%20uses%20weights%0A%20%20%20%20%20%20%20%20lr_pipe.set_score_request(sample_weight%3DTrue)%0A%0A%20%20%20%20%20%20%20%20pipe%20%3D%20Pipeline(%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20(%22scaler%22%2C%20scaler)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20(%22classifier%22%2C%20lr_pipe)%2C%0A%20%20%20%20%20%20%20%20%5D)%0A%20%20%20%20%20%20%20%20pipe.set_score_request(sample_weight%3DTrue)%20%20%23%20Route%20weights%20to%20pipeline%20scoring%0A%0A%20%20%20%20%20%20%20%20search_pipe%20%3D%20OptunaSearchCV(%0A%20%20%20%20%20%20%20%20%20%20%20%20pipe%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20param_distributions%3D%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22classifier__C%22%3A%20FloatDistribution(0.01%2C%2010.0%2C%20log%3DTrue)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%7D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20n_trials%3D10%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20cv%3D3%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20search_pipe.fit(X%2C%20y%2C%20sample_weight%3Dsample_weight)%0A%20%20%20%20return%20(search_pipe%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo%2C%20search_pipe)%3A%0A%20%20%20%20mo.md(f%22%22%22%0A%20%20%20%20**Best%20Parameters%3A**%20%60C%20%3D%20%7Bsearch_pipe.best_params_%5B'classifier__C'%5D%3A.4f%7D%60%0A%20%20%20%20**Best%20Score%3A**%20%60%7Bsearch_pipe.best_score_%3A.3f%7D%60%0A%0A%20%20%20%20The%20sample%20weights%20were%20correctly%20routed%20only%20to%20the%20classifier%2C%20not%20the%20scaler.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
081e883d46ceb25e89a886f57e00ad5b